๋ณธ ํฌ์คํธ๋ฅผ ์ฝ๊ธฐ ์ ์ DistilBERT์ ์ฌ์ฉ๋ ๋ฉ์ธ ํ ํฌ๋์ธ Knowledge Distillation์ ๋ํด์ ๋จผ์ ํ์ตํ์๊ธธ ๋ฐ๋๋๋ค. ๋ค์์ ํฌ์คํธ๋ฅผ ์ฐธ๊ณ ํ์์ค.
The overview of this paper
NLP์์ large-scale์ pre-trained model์ ํ์ฉํ์ฌ transfer learning์ ์ฒ๋ฆฌํ๋ ์ผ์ด ํํด์ง๋ฉด์, ์ด ๊ฑฐ๋ํ ๊ท๋ชจ์ ๋ชจ๋ธ์ ํ์ ๋ ์์์ผ๋ก ์ด๋ป๊ฒ ๊ตฌ๋ํ ์ง๋ ์์ง๋ ์ด๋ ค์ด ๋ฌธ์ ๋ก ๋จ์์๋ค. ๊ทธ๋์ ์ด ๋ ผ๋ฌธ์์๋ ์์ ๊ท๋ชจ์ general purpose language representation model์์๋ ๋ถ๊ตฌํ๊ณ , ๋ค์ํ ๋ถ์ผ์ task์ ๋ํด ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ DistilBERT๋ฅผ ์ ์ํ์๋ค. ์ด DistilBERT๋ BERT์ ๋นํด 40% ์์ ํฌ๊ธฐ์ 60% ๋ ๋น ๋ฅธ ์๋๋ฅผ ๋ณด์ฌ์ฃผ๋ฉด์, ๊ธฐ์กด BERT์ 97%์ ๋ฌํ๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ์๋ค. larger model์ด pre-training ๋์ค์ ํ์ตํ๋ inductive bias๋ฅผ leverageํ๊ธฐ ์ํด์, triple loss๊ฐ ํฉ์ณ์ง LM์ ์๊ฐํ์๋ค. ๋ ผ๋ฌธ์์ ์๊ฐํ ์ด ์๊ณ ๋น ๋ฅธ ๋ชจ๋ธ์ on-device ํ๊ฒฝ์์๋ ์ถฉ๋ถํ ๊ตฌ๋ํ๋ค๋ ๊ฒ ๋ํ ์ฆ๋ช ๋์๋ค.
Table of Contents
1. Introduction
2. Knowledge distillation
3. DistilBERT: a distilled version of BERT
4. Ablation study
1. Introduction
์์ฆ์ ๊ฐ๋ฐ๋๋ pre-trained language model์ ์ดํด๋ณด๋ฉด ๋๋ถ๋ถ์ด ์๋ง์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง๊ณ ์๋ค. ๊ทธ๋ฆฌ๊ณ , ์์ฆ์ ๋ง์ ์ฐ๊ตฌ๋ค์ด ๋ ํฐ ๊ท๋ชจ์ LM์ ๋ ์ข์ ์ฑ๋ฅ ํฅ์์ ๊ฐ์ ธ์จ๋ค๋ ๊ฒ์ ์ฆ๋ช ํ๊ณ ์๋ค.
์ด๋ฌํ ํฐ ๊ท๋ชจ์ LM์ ๋ํ ๊ด์ฌ์ ๋ค์์ ๊ฑฑ์ ์ ๋ถ๋ฌ์ผ์ผํฌ ์ ์๋ค.
- ๋น์ฉ์ด ๋ง์ด ๋ฐ์ํ๊ฒ ๋๋ค. $($computationally & environmentally$)$
- LM์ ๋ฐ๋ฌ์ LM์ด on-device ํ๊ฒฝ์์ ๊ตฌ๋ํ ์ ์๋ ์ ์ฌ๋ ฅ์ ๋ณด์ฌ์ฃผ๊ณ ์์ง๋ง, ๋๋ฌด ํฐ ๊ท๋ชจ๋ก ์ธํด ๋์ ๋ฐฉ๋ฉด์ task์ ์ ์ฉ์ ํ ์ ์๋ค๋ ํฌ๋ํฐ ๋จ์ ์ ๋ณด์ ํ๊ณ ์๋ค.
์ด ๋ ผ๋ฌธ์์๋, knowledge distillation์ ํตํด pre-train๋ ์์ ํฌ๊ธฐ์ LM์ด larger model์ ๋ฒ๊ธ๊ฐ๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ค ์ ์๋ค๋ ๊ฐ๋ฅ์ฑ์ ๋ณด์ฌ์คฌ๋ค. ๊ทธ๋ฆฌ๊ณ ๋ ผ๋ฌธ์ general-purpose pre-trained model์ด larger model์ ๊ฐ์ฉ์ฑ์ ๋ณด์กดํ๋ฉด์, ๋ค์ํ task์ ๋ํด์ ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋๋ก fine-tune๋ ์ ์๋ค๋ ๊ฒ์ ๋ณด์ฌ์คฌ๋ค. ๋ํ, ์ด ๋ชจ๋ธ์ด ๋ชจ๋ฐ์ผ ๊ธฐ๊ธฐ์์๋ ์ถฉ๋ถํ ๋์๊ฐ ์ ์์ ๋งํผ ์๋ค๋ ๊ฒ์ ๋ณด์ฌ์คฌ๋ค.
triple loss๋ฅผ ์ฌ์ฉํจ์ผ๋ก์จ, 40% ์์ ํฌ๊ธฐ์ Transformer๊ฐ ๋ ํฐ Transformer๊ณผ ๋น์ทํ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ์๋ค. ablation study์ ๋ฐ๋ฅด๋ฉด triple loss์ ์ธ ๊ฐ์ง ์์๋ ์ข์ ์ฑ๋ฅ์ ๋ด๊ธฐ ์ํด ๋งค์ฐ ์ค์ํ๋ค๋ ๊ฒ์ ์ ์ ์์๋ค.
2. Knowledge distillation
Knowledge distillation์ ์์ถ ๊ธฐ์ ๋ก, compact model์ด 'ํ์'์ ์ ์ฅ์ผ๋ก, '๊ต์ฌ'์ธ larger model ๋๋ ensemble model๋ก๋ถํฐ ์ด๋ค์ ๋ฐฉ์์ ๋ฐ๋ผํ๋ฉด์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ด๋ค.
์ง๋ํ์ต์์๋, ๋ถ๋ฅ ๋ชจ๋ธ์ด ์ต์ ์ ๋ผ๋ฒจ์ ๋ํ ์ธก์ ํ๋ฅ ์ ์ต๋ํ์ํด์ผ๋ก์จ class๋ฅผ ์์ธกํ๋๋ก ํ์ต๋๋ค. ๋ฐ๋ผ์, ๊ธฐ์กด์ ํ๋ จ ๋ชฉ์ ์ ๋ชจ๋ธ์ ์์ธก ๋ถํฌ์ ํ๋ จ ๋ผ๋ฒจ์ ๋ถํฌ ๊ฐ์ cross-entropy๋ฅผ ์ต์ํ์ํค๋ ๊ฒ์ด๋ค. ๋ฐ๋ผ์ ๋ชจ๋ธ์ ํ๋ จ์ ์ ๋ํด์ ๋งค์ฐ ์ ์ํํ๊ฒ ๋๋ค. ๊ฒฐ๊ตญ์ ๋ชจ๋ธ์ ์ณ์ ํด๋์ค์ ๋ํด์๋ ๋์ ํ๋ฅ ๋ก ์์ธกํ๊ณ , ๊ทธ ๋ฐ๋์ ํด๋์ค์ ๋ํด์๋ ๊ฑฐ์ 0์ ๊ฐ๊น์ด ํ๋ฅ ๋ก ์์ธก์ ํ๋ค. ๊ทธ๋ฌ๋ ์ด๋ฌํ "0์ ๊ฐ๊น์ด" ํ๋ฅ ์ค ์ผ๋ถ๋ ๋ค๋ฅธ ํ๋ฅ ๋ณด๋ค ํฌ๋ฉฐ ๋ชจ๋ธ์ ์ผ๋ฐํ ๊ธฐ๋ฅ๊ณผ ํ ์คํธ ์ธํธ์์ ์ผ๋ง๋ ์ ์ํ๋ ๊ฒ์ธ์ง๋ฅผ ๋ถ๋ถ์ ์ผ๋ก ๋ฐ์ํ๋ค.
Training loss
ํ์์ ๊ต์ฌ์ soft target probabilities์ ๋ํด distillation probabilities๋ก ํ์ต๋๋ค: $L_{ce}=\sum_{i} t_{i}*log(s_i)$ ์ฌ๊ธฐ์ $t_i$๋ ๊ต์ฌ์ ์ํด ์ธก์ ๋ ํ๋ฅ ์ด๋ค. ์ด ์์ ์ ์ฒด ๊ต์ฌ์ ๋ถํฌ๋ฅผ ํ์ฉํ์ฌ ํ๋ถํ ํ์ต์ ์ ๊ณตํ๋ค. Hinton์ ๋ฐ๋ผ์ ์ด ๋ ผ๋ฌธ์์๋ softmax-temperature: $p_i=\frac {exp(z_i/T)}{\sum_{j} exp(z_j/T)}$์ ์ฌ์ฉํ์๋ค. ์ฌ๊ธฐ์ $T$๋ ์ถ๋ ฅ ๋ถํฌ์ ๋งค๋๋ฌ์์ ์กฐ์ ํ๊ณ , $z_i$๋ ํด๋์ค $i$์ ๋ํ ๋ชจ๋ธ์ ์ค์ฝ์ด์ด๋ค. ํ๋ จ ์ค์๋ ํ์๊ณผ ๊ต์ฌ์๊ฒ ๋๊ฐ์ temperature $T$๊ฐ ์ ์ฉ๋์ง๋ง, ์ถ๋ก ์ ํ ๋์ $T$๋ ๊ธฐ์กด์ softmax๋ฅผ ํ๋ณตํ๊ธฐ ์ํด์ 1๋ก ๊ณ ์ ๋๋ค.
๋ง์ง๋ง ๋จ๊ณ๋ distillation loss์ธ $L_{ce}$์ ์ด ์ผ์ด์ค์์๋ masked language model์ loss์ธ $l_{mlm}$์ธ ์ง๋ ํ์ต์ ์ค์ฐจ์ ์ ํ ๊ฒฐํฉ์ด๋ค. ๋ ผ๋ฌธ์์๋ ์ฌ๊ธฐ์ cosine embedding loss์ธ $L_{cos}$๋ฅผ ์ถ๊ฐํ๋ ๊ฒ์ด ์ด๋์ด๋ผ๊ณ ํ๋๋ฐ, ์ด๋ ์๋ํ๋ฉด ๊ต์ฌ์ ํ์์ hidden state vector์ ๋ฐฉํฅ์ ์กฐ์ ํด์ฃผ๋ ๊ฒฝํฅ์ด ์๊ธฐ ๋๋ฌธ์ด๋ค.
3. DistilBERT: a distilled version of BERT
Student architecture
์ด์ ์ ์์ ํ๋ ๊ฒ์ฒ๋ผ ํ์์ ์ ์ฅ์ธ DistilBERT๋ BERT์ ๋๊ฐ์ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์๋ค. ํ์ง๋ง, Token-type embedding๊ณผ pooler๊ฐ ๋ ๊ฐ์ ์์์ ์ํด layer์ ์๊ฐ ์ค์ด๋ฌ์ ๋ฐ๋ผ์ ์ ๊ฑฐ๋์๋ค. Transformer architecture์์ ์ฌ์ฉ๋๋ ์ฐ์ฐ์๋ค์ ํ๋์ ์ ํ๋์ํ์ ์๋ง๊ฒ ์กฐ์ ๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ์, ๋ณธ ์ฐ๊ตฌ์์๋ ๋ง์ง๋ง ์ฐจ์์ ํ ์์์ ๋ณ๋์ ์ฃผ๋ ๊ฒ์ ๊ณ์ฐ ๋น์ฉ์ ๋ณ๋ก ์ํฅ์ ๋ฏธ์น์ง ์๋๋ค๋ ๊ฒ์ ์์๋ด์๋ค. ๊ทธ๋ณด๋ค๋ layer์ ์์ ๊ฐ์ ์์์ ๋ณ๋์ ์ฃผ๋ ๊ฒ์ด ์ํฅ์ ๋ ๋ผ์น๊ธฐ ๋๋ฌธ์, layer์ ์๋ฅผ ์ค์ด๋ ๊ฒ์ ์ง์คํ์๋ค.
Student initialization
์์์ ์ค๋ช ํ ์ต์ ํ ๋ฐ ์ํคํ ์ฒ ์ ํ ์ธ์๋ ๊ต์ก ์ ์ฐจ์ ์ค์ํ ์์๋ ์๋ ดํ sub-network์ ๋ํ ์ฌ๋ฐ๋ฅธ ์ด๊ธฐํ๋ฅผ ์ฐพ๋ ๊ฒ์ด๋ค. ๊ต์ฌ์ ํ์ ๋คํธ์ํฌ ๊ฐ์ ์ผ๋ฐ์ ์ธ ์ฐจ์์ ์ด์ ์ ์ฑ๊ธฐ๊ธฐ ์ํด, ๋ ผ๋ฌธ์์๋ ๋ ๋ ์ด์ด ์ค ํ ๋ ์ด์ด๋ฅผ ์ทจํ์ฌ ๊ต์ฌ๋ก๋ถํฐ ํ์์ ์ด๊ธฐํํ์๋ค.
Distillation
๋ ผ๋ฌธ์์๋ ์ง๊ธ๊น์ง ๋์จ BERT์ค ๊ฐ์ฅ ์ต๊ณ ์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์๋ค. ๋ฐ๋ผ์ DistilBERT๋ dynamic masking์ ์ฌ์ฉํ๊ณ , next sentence prediction์ ๋ฐฐ์ ํจ์ผ๋ก์จ gradient accumulation์ ํตํด ๋งค์ฐ ํฐ ํฌ๊ธฐ์ ๋ฐฐ์น๋ก๋ถํฐ distill๋๋ค.
DistilBERT์ ๊ธฐ์กด์ BERT์ ์ฑ๋ฅ ๋น๊ต์ ๋ํ ํ๋ ๋ค์๊ณผ ๊ฐ๋ค.
4. Ablation study
์ด ์ ์์๋ triple loss์ ์ธ ๊ฐ์ ์์์ ์ํฅ๋ ฅ๊ณผ distilled model์ student initialization์์์ ์ฑ๋ฅ์ ๋ํด์ ์กฐ์ฌํ์๋ค. ์ด์ ๋ํ ๊ฒฐ๊ณผ๋ ๋ค์์ ํ์ ๋ํ๋์๋ค. ํ๋ฅผ ์ดํด๋ณด๋ฉด, Masked Language Model loss๋ฅผ ์ ๊ฑฐํ๋ ๊ฒ์ด ๋ค๋ฅธ ๋ loss๋ฅผ ์ ๊ฑฐํ๋ ๊ฒ๋ณด๋ค ๋ ์ํฅ์ ๋ผ์น๋ค๊ณ ๊ฒฐ๋ก ์ ๋ด๋ฆด ์ ์๋ค.
์ฐธ๊ณ ๋ฌธํ
https://arxiv.org/abs/1910.01108