일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
- 빠르게 실패하기
- 정밀도
- 평가 지표
- Normalization
- 데이터 전처리
- 오차 행렬
- NVL
- recall
- LAG
- NULLIF
- layer normalization
- ifnull
- 결정트리
- SQL
- 재현율
- Batch Normalization
- 비지도학습
- 웹서비스 기획
- CASE WHEN
- five lines challenge
- 백엔드
- nvl2
- 지도학습
- 데이터 분석
- beautifulsoup
- 감정은 습관이다
- 데이터 프로젝트
- DecisionTree
- 강화학습
- sorted
- Today
- Total
Day to_day
Knowledge Distillation 이해하기 본문
들어가며
cross entropy의 개념에 대해 이미 알고 있고, DeiT 논문을 읽기 시작하면서 Knowledge Distillation의 개념에 대해 들어는 봤지만 정확하게 정리가 되지 않은 상태였다. 그래서 이번 기회에 Knowledge Distillation의 개념과 특징에 대해 정리하고자 한다.
Knowledge Distillation (증류)
딥러닝에서 distillation 증류?라고 생각할 수 있다. 이 단어의 뜻을 찾아보고선 나도 똑같은 반응이었으니깐.
큰 모델(Teacher Network)로부터 증류한 지식을 작은 모델(Student Network)로 transfer 하는 방법이라고 할 수 있다.
핵심은 큰 모델이 학습한 일반화된 능력을 작은 모델에 전달해주는 것을 말한다.
Soft Label과 Hard Label
이미 학습된 딥러닝 모델(Teacher model)로 어떤 task에서 예측을 하면 각 클래스마다 확률값을 출력할 것이다.
예를 들어 확률 [0.000001, 0.0000003, 0.9, 0.000002, 0.000001] 값을 출력했다고 하면 정답 확률은 soft label은 정답 확률 이외의 다른 확률들도 하나의 지식으로 본다는 것이 특징이다.
그래서 [0.01, 0.03, 0.89, 0.06, 0.1]과 같이 정답 확률과 너무 큰 차이가 나지 않도록 값을 조금 변경한다. 이것은 정답 확률 이외에 더 많은 정보를 갖고 있다고 할 수 있고, 이 soft label을 사용하여 student model을 학습한다. 수식은 아래와 같다.
논문에서는 t를 Temperature라고 표현하고 있고, 증류라는 개념이 들어갔기 때문에 온도로 표현하면서 이 값에 따라 값이 높아지면 더 soft 하게, 낮아지면 hard 하게 만드는 것이다. softmax 함수에 zi(x) 대신 zi(x)/t를 사용해서 softmax 함수에서 입력값이 큰 것은 아주 크게, 입력값이 작은 것은 아주 작게 만드는 성질을 완화시켜 준다.
아래의 예시를 들어보면 t를 3으로 두었을 때가 값의 차이가 좀 더 soft 해졌음을 직관적으로 볼 수 있다.
그러면 hard label은 soft label과 반대로 [0, 0, 1, 0, 0]과 같이 극단적인 값을 갖는 것을 말한다.
정답 이외에 다른 정보는 포함시키지 않는 것이다.
Knowledge Distillation loss
Knowlege Distillation에서 loss를 계산하는 방식이 다른 모델들과 조금 다른데 아래의 그림과 식을 가지고 이해를 해보자. 두 가지의 loss를 더해서 loss를 계산한다.
첫 번째 항은 Teacher model에서 Soft label을 계산하고, 이 Soft label과 동일한 결괏값을 출력하도록 student model을 학습시킨다. Teacher model에서의 예측을 정보 손실 없이 Student model이 모방하도록 학습시키는 방법이다.
두 번째 항은 student model의 출력값과 hard label(실제 ground truth) 사이의 cross entropy loss를 계산한다. 그리고 α는 두 항 사이의 비율을 조절하는 역할이다.
'Deep Learning' 카테고리의 다른 글
CBOW & Skip gram 개념 완벽 이해하기!! (0) | 2024.03.26 |
---|---|
Batch Normalization, Layer Normalization 비교 (5) | 2024.03.16 |
1 x 1 convolution이란? 직관적으로 이해해보기 (0) | 2024.02.25 |
Global Average Pooling이 뭐길래? (3) | 2024.02.24 |
Cross Entropy 개념 / KL divergence 정리 (추가) (1) | 2024.01.02 |