Day to_day

[논문 리뷰] DeiT: Training data-efficient image transformers & distillation through attention 본문

논문 리뷰

[논문 리뷰] DeiT: Training data-efficient image transformers & distillation through attention

m_inglet 2024. 2. 8. 19:48
728x90
반응형

 

들어가며

ViT모델 다음 DeiT가 나왔고 그 이후에 Swin Transformer가 나왔는데, 더 흥미로운 논문 먼저 읽다 보니 DeiT를 놓칠뻔했다!

ViT모델과는 어떤 점이 다른지, ViT의 한계를 DeiT는 어떤 방법으로 극복했는지를 중심으로 논문리뷰를 남기려고한다.

 

미리 알고 있어야 할 개념

이전 포스팅으로 한번 정리했듯이 Knowledge Distillation, KL divergence, Cross Entropy에 대한 개념을 한번 정리하고 이 논문을 보면 더 쉽게 이해할 수 있을 것이다.

 

Cross Entropy 개념 / KL divergence 정리

 

Cross Entropy 개념 / KL divergence 정리 (추가)

포스팅 개요 그동안 크로스 엔트로피에 대해서 자주 들었지만 내가 설명하려면 정확하게 말이 잘 안 나왔었다. 이번 기회에 정확하게 정리하고 정보이론에서 정보란 무엇이며, 엔트로피는 어떻

day-to-day.tistory.com

 

Knowledge Distillation 이해하기

 

Knowledge Distillation 이해하기

들어가며 cross entropy의 개념에 대해 이미 알고 있고, DeiT 논문을 읽기 시작하면서 Knowledge Distillation의 개념에 대해 들어는 봤지만 정확하게 정리가 되지 않은 상태였다. 그래서 이번 기회에 Knowledg

day-to-day.tistory.com

 

 

DeiT 이것이 핵심이다!

1. ViT와 동일한 Transformer 모델 구조이다.

2. Knowledge Distillation 기법으로 CNN 구조의 Teacher 모델(RegNet)의 지식을 학습한 Student 모델이다.

3. 결과적으로 ViT의 inductive bias 문제를 해결함과 동시에 적은 데이터로도 높은 효율을 달성하였다.

 

DeiT 배경

비전 작업을 해결하기 위해 트랜스포머의 구성 요소를 컨볼루션 신경망에 이식하는 혼합 아키텍처를 제안하는 연구(ViT)가 나오면서 "데이터 양이 부족한 상태에서 훈련되면 일반화가 잘 되지 않는다"라고 결론 내리고, 해당 모델의 훈련에는 많은 컴퓨팅 리소스가 사용되는 점을 한계점으로 짚었다.

그래서 DeiT 논문에서는 특별한 아이디어 설계가 들어갔다기보다는 ViT를 data-effiecient 하게 학습시키기 위해 다양한 실험들을 통해 기존의 연구의 효과를 분석하면서 ablation study를 통해 인사이트를 제시하고 있다.

 

DeiT architecture

DeiT의 아키텍처는 ViT와 동일한 구조이며 다른 점은 패치를 나눌 때 달라진다.

  • ViT 모델의 input으로 class token, image patch가 들어가는데 DeiT에서는 추가로 distillation 토큰이 들어간다. distillation 토큰은 class token처럼 랜덤으로 초기화되어서 들어간다.

다음은 class token과 distillation token이 학습되는 과정을 그린 것인데 class token과 distillation token은 모두 트랜스포머에 입력되어 역전파를 통해 학습이 되지만 오차를 계산하는 대상이 다르다.

class token의 목표는 실제 label과 일치하도록 오차를 줄여 나아가는 방향으로 학습이 되고, distillation token은 teacher 모델의 추론 값과 오차를 줄여나가는 방향으로 학습이 된다.

 

Class token

class 토큰은 ViT 논문 리뷰에서도 설명했듯이 학습 가능한 벡터로, 첫 번째 레이어 이전 패치 토큰에 추가된다. 이 토큰은 트랜스포머 레이어를 거치고, 그다음 선형 레이어로 투영되어 클래스 예측하게 된다. 트랜스포머에서는 차원 D의 (N+1) 토큰 배치를 처리하며, 이 중에서 클래스 벡터만이 출력을 예측하는 데 사용된다.

class 토큰은 매번 랜덤 한 값으로 초기화가 되고, Transformer의 self-attention이 패치 토큰과 클래스 토큰 간에 정보를 전파하도록 하는 역할을 한다.

 

Distillation token

미리 knowledge distillation에 대한 사전 지식이 있다면 soft label과 hard label에 대해 숙지를 했을 것이다. distillation label에 따라 달라지는 전체 loss를 구하는 식은 아래와 같다.

Soft distillation

아래의 식에서 첫 번째 항은 class token과 실제 label의 cross entropy loss이고, 두 번째 항을 비교하며 보면 좋을 것 같다. soft distillation에서는 Teacher 모델의 소프트맥스와 Student 모델의 소프트맥스 간의 쿨백-라이블러 발산(Kullback-Leibler divergence)을 최소화하는 것이 목적이다.

  • 교사 모델의 로짓을 Zt, 학생 모델의 로짓을 Zs로 표기하고,
  • 증류를 위한 온도를 τ로, 쿨백-라이블러 발산 손실(KL) 및 실제 레이블 y에 대한 교차 엔트로피 손실 (LCE)을 균형 조절하는 계수를 λ로, 소프트맥스 함수를 ψ로 나타낸다.

 

Hard label distillation

Hard label distillation은 Teacher Model에서 가장 큰 Softmax 값을 가진 Label(yt)을 True Label로 처리하여 Cross Entropy를 구한다. 그리고 Ground Truth와의 Cross Entropy를 구하고, 이 둘을 평균 내는 방식으로 Global Loss를 얻어내게 되며 추가적으로 hard label을 사용하고 label smoothing을 해줌으로써 soft label과 비슷한 효과를 가져올 수 있어 실제 실험결과 soft distillation보다 좋은 성능을 얻어냈다고 한다

soft label이 예측이 떨어질 수 있는 이유는 데이터 증강(예: crop augmentation) 과정에서 예측하기 어려운 이미지가 만들어진다면, teacher model의 예측 확률이 변경되어 다른 class들의 확률 값이 높아져서 오히려 왜곡될 가능성이 있기 때문이다. 예를 들어 사진에 고양이로 라벨이 되어있지만 augmetation 과정에서 고양이가 없는 이미지가 crop 되어 높은 예측 확률을 반환하지 못할 수 있다.

 

그러면 Class token과 Distillation token이 같아지지 않나?

저자는 실험을 통해 학습된 class 및 distillation 토큰이 다른 벡터로 수렴했다고 한다. 이러한 토큰 간의 평균 코사인 유사도는 0.06이다. 그 이유는 class 및 distillation 임베딩이 각 레이어에서 계산되기 때문에 이들은 네트워크를 통해 서서히 유사해지며, 마지막 레이어에서 그 유사성이 높아지지만 (cos=0.93), 여전히 1보다 낮다고 한다. 결국 두 개의 토큰이 유사하지만 동일하지 않은 목표를 생성하기 때문에 같아지진 않는다고 한다.

 

Experiments

DeiT-B는 ViT-B와 동일한 아키텍처이고, 임베딩 차원과 헤드 수를 조절함으로써 DeiT-Ti, DeiT-S를 생성했다고 한다. (헤드 당 차원은 64로 일정하게 유지)

실험에 대한 내용은 아래와 같다.

  • 해상도 224x224 이미지에 대해서 처리량 측정
  • Test dataset은 ImageNet으로 진행
  • fine tunning 할 때 larger resolution으로 진행했으며 모델마다 옆에 DeiT-B↑384 이런 식으로 해상도가 적혀있음
  • distillation 과정을 거친 모델에 대해서는 DeiT⚗라고 표시

최종적으로,

ImageNet-1k에서 DeiT의 최고 모델은 85.2%의 상위 1위 정확도를 갖고 ⇒ 384 해상도에서 JFT-300M에서 pre-train 된 Vit-B 모델 (84.15%)을 능가하는 수준이라고 한다.

저자는 Convnet 모델은 inductive bias 때문에 좋은 teacher model로 선택하였고, 이후에 모든 실험에서는 기본 teacher model로 RegNetY-16GF (84M 매개변수)를 사용했다고 한다. 그리고 이 RegNetY-16GF은 ImageNet에서 82.9%의 상위 1위 정확도를 달성한 모델이다.

 

 

 

또 distillation의 효과를 보기 위해 위쪽 세줄은 soft distillation 기법보다 hard distillation이 더 성능이 향상됨을 증명하고,

아래쪽 세줄은 

  1. 클래스 토큰만 사용
  2. distillation 토큰만 사용
  3. 클래스 + distillation 토큰 모두 사용

위와 같이 조건을 바꾸어 클래스와 distillation 토큰 모두를 사용하는 것이 좋은 성능을 낸다는 것을 보여주었다.

 

 

Conclusion

다시 한번 정리를 하자면, DeiT는 이미지 트랜스포머 구조로 knowledge distillation 방법으로 많은 양의 데이터를 필요하지 않고 훈련을 시킨 모델이다.

DeiT에서는 기존의 데이터 증강 및 정규화 전략을 시작점으로 삼았으며, distillation 토큰 이외의 중요한 아키텍처를 도입하진 않았다. 따라서 이미지 트랜스포머가 이미 컨볼루션 신경망과 비슷한 성능을 보이고 있음을 고려할 때, 저자는 트랜스포머에 더 맞는 데이터 증강에 관한 발전된 연구들을 기대할 수 있을 것이라 남겼다.

 

Reference

https://velog.io/@heomollang/DeiT-관련-논문-리뷰-04-Training-data-efficient-image-transformers-distillation-through-attentionDeiT

https://hyoseok-personality.tistory.com/entry/Paper-Review-DeiT-Training-data-efficient-image-transformers-distillation-through-attention

728x90
반응형
BIG
Comments