The overview of this paper
BERT์ ๊ฐ์ LM pre-training์ ์ฌ๋ฌ NLP task์ ๋ํด ์๋นํ ์ฑ๋ฅ์ ํฅ์์์ผฐ๋ค. ํ์ง๋ง, PLM์ ๋ณดํต ๊ณ์ฐ์ ๋น์ฉ์ด ๋งค์ฐ ๋น์ธ๊ณ , ๊ทธ์ ๋ฐ๋ผ์ ์์์ด ์ ํ๋ ํ๊ฒฝ์์ ์คํํ๋๋ฐ ์ด๋ ค์์ด ์๋ค. ๋
ผ๋ฌธ์์๋ Transformer distillation method๋ฅผ ์ ์ํด์ ์ถ๋ก ์๋๋ฅผ ๋น ๋ฅด๊ฒ ํ๊ณ , ๋ชจ๋ธ ํฌ๊ธฐ๋ ์ค์ด๋ค๊ฒ ํ๊ณ , ๊ทธ ๋์ ์ ์ ํ๋๋ ์ ์ง์์ผฐ๋ค. ์ด Transformer distillation method๋ Transformer ๊ธฐ๋ฐ ๋ชจ๋ธ์ ๋ํด knowledge distillation$($KD$)$์ ์ ์ฉ์์ผฐ๋ค. ์ด๋ฅผ ์ํด ํ๋ถํ ์ง์์ ๊ฐ์ง๊ณ ์๋ ํฐ 'teacher' BERT์์ ์์ 'student' TinyBERT๋ก ์ง์์ ์ ๋ฌํ๋ค. ๊ทธ๋ฆฌ๊ณ ๋
ผ๋ฌธ์์๋ TinyBERT๋ฅผ ์ํ two-stage learning ํ๋ ์์ํฌ๋ฅผ ์ ์ํ์๋๋ฐ, ์ด๊ฒ์ pre-training๊ณผ task-specific learning stage์ Transformer disitllation์ ์ํํ๋ค. ์ด ํ๋ ์์ํฌ๋ TinyBERT๊ฐ BERT์ general-domain๋ฟ๋ง ์๋๋ผ task-specific ์ง์ ๋ํ ์บก์ฒํ ์ ์๋๋ก ๋ณด์ฅํ๋ค.
TinyBERT with 4layers๋ ํจ๊ณผ์ ์ด์๊ณ , ์ฑ๋ฅ์ ์ผ๋ก teacher model๋ณด๋ค ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์คฌ๋ค. ๋ฌด๋ ค 7.5๋ฐฐ ์๊ณ , 9.4๋ฐฐ ๋น ๋ฅธ ์ถ๋ก ์๋๋ฅผ ๋ณด์ฌ์คฌ๋ค.
Table of Contents
1. Introduction
2. Method
3. Experiment Results
4. Ablation Studies
4-1. Effects of Learning Procedure
4-2. Effects of Distillation Objective
5. Effects of Mapping Function
1. Introduction
LM์ pre-trainingํ๊ณ , downstream task์ ๋ํด fine-tuning์ ํ๋ ๊ฒ์ด NLP์์ ํ๋์ ์๋ก์ด ํจ๋ฌ๋ค์์ผ๋ก ๋ถ์ํ๊ณ ์๋ค. ๊ทธ๋ฆฌ๊ณ ์ค์ ๋ก ์ด ๋ฐฉ๋ฒ๋ค์ด NLP ๋ถ์ผ์์ ์ข์ ์ฑ๋ฅ๋ค์ ๋ณด์ฌ์ฃผ๊ณ ์๋ค. ํ์ง๋ง, PLM์ ๋ณดํต ๋ง์ ์์ ํ๋ผ๋ฏธํฐ์ ๊ธด ์ถ๋ก ์๊ฐ์ ๊ฐ์ง๊ณ ์์ด์, ์ด๋ ํธ๋ํฐ๊ณผ ๊ฐ์ edge device์ ์ ์ฉํ๊ธฐ ์ด๋ ต๋ค. ์ต๊ทผ์ ์ฐ๊ตฌ๋ค์ PLM์ ์ฅํฉ์ฑ์ด ์๋ค๋ ์ฌ์ค์ ๋ฐํ๋๋ค. ๋ฐ๋ผ์ ์ฑ๋ฅ์ ์ ์ง์ํค๋ฉด์ PLM์์ ๊ณ์ฐ overhead์ ๋ชจ๋ธ์ ์ ์ฅ ์ฉ๋์ ์ค์ด๋ ๊ฒ์ด ์ค์ํ๋ค.
์ฌ๊ธฐ์๋ ์ ํ๋๋ ์ ์ง์ํค๋ฉด์ ๋ชจ๋ธ์ ํฌ๊ธฐ๋ ์ค์ด๊ณ , ๋ชจ๋ธ์ ์ถ๋ก ์๋๋ฅผ ๊ฐ์ํํ๋ ๋ง์ ๋ชจ๋ธ ์์ถ ๊ธฐ์ ๋ค์ด ์๋ค. ์ด ๋
ผ๋ฌธ์์๋ ๊ทธ ์ค์์๋ teacher-student ํ๋ ์์ํฌ์ knowledge distillation$($KD$)$์ ์ง์คํ์๋ค. KD๋ ๊ฑฐ๋ํ teacher network์์ ์์ student network๋ก ์ง์์ ์ ๋ฌํด์ teacher network์ ํ๋์ ๋ชจ์ฌํ๋ค. ์ด ํ๋ ์์ํฌ์ ๊ธฐ๋ฐํด์, ๋
ผ๋ฌธ์์๋ Transformer ๊ธฐ๋ฐ์ ์๋ก์ด distillation method๋ฅผ ์ ์ํ์๋ค. ๊ทธ๋ฆฌ๊ณ large-scale PLM์ ์ํ method๋ฅผ ์กฐ์ฌํ๊ธฐ ์ํ ์์๋ก BERT๋ฅผ ์ฌ์ฉํ์๋ค.
KD๋ pre-trained LM ๋ฟ๋ง ์๋๋ผ NLP ๋ถ์ผ์์ ๊ด๋ฒ์ํ๊ฒ ํ์ตํ์๋ค. pre-training-then-fine-tuning ํจ๋ฌ๋ค์์ large-scale์ ๋น์ง๋ํ์ต text corpus์์ BERT๋ฅผ ํ์ต์ํค๊ณ , task-specific dataset์์ fine-tuning ํ์๋ค. ์ด๋ BERT distillation์ ์ด๋ ค์์ ๊ต์ฅํ ์ฆ๊ฐ์์ผฐ๋ค. ๋ฐ๋ผ์, ๋ ๊ฐ์ training stage๋ฅผ ์ํด ํจ๊ณผ์ ์ธ KD ์ ๋ต์ด ํ์ํ๋ค.
๊ฒฝ์๋ ฅ ์๋ TinyBERT๋ฅผ ๋ง๋ค๊ธฐ ์ํด, ๋
ผ๋ฌธ์์๋ ์๋ก์ด Transformer Distillation method์ ์ ์ํด์ teacher BERT์์ embedding๋ ์ง์์ ์ฆ๋ฅํ์๋ค. ํน๋ณํ, ๋
ผ๋ฌธ์์๋ BERT layer์ ์๋ก ๋ค๋ฅธ representation์ ๋ํ 3๊ฐ์ loss function์ ์ ์ํ์๋ค.
- embedding layer์ output
- Transformer layer๋ก๋ถํฐ ๋์จ hidden state์ attention ํ๋ ฌ
- prediction layer์ logit output
attention ๊ธฐ๋ฐ์ fitting์ ์ต๊ทผ์ ๋ฐ๊ฒฌ์ ์ํด ์๊ฐ์ ๋ฐ์๋ค. ์ต๊ทผ์ ์ฐ๊ตฌ๋ค์ ํตํด BERT์ ์ํด ํ์ต๋ attention ๊ฐ์ค์น๋ ์๋นํ ์ธ์ด์ ์ ๋ณด๋ฅผ ์บก์ฒํ ์ ์๊ณ , ๋ฐ๋ผ์ teacher BERT๋ก๋ถํฐ student TinyBERT๊น์ง ์ธ์ด์ ์ ๋ณด๊ฐ ์ ์ ๋ฌ๋ ์ ์๋๋ก ์ฅ๋ คํ ์ ์๋ค. ๊ทธ๋์, ๋
ผ๋ฌธ์์๋ ๋ค์์ ๊ทธ๋ฆผ 1๊ณผ ๊ฐ์ด general distillation๊ณผ task-specific distillation์ ํฌํจํ๋ ์๋ก์ด two-stage learning ํ๋ ์์ํฌ๋ฅผ ์ ์ํ์๋ค.
general distillation์ ๋จ๊ณ์์, fine-tuning์ด ๋์ง ์์ ๊ธฐ์กด์ BERT๋ teacher model์ฒ๋ผ ์ฌ์ฉ๋๋ค. student TinyBERT๋ general-domain corpus์ ๋ํด ์ ์๋ Transformer distillation์ ์ฌ์ฉํด์ teacher model์ ํน์ฑ์ ํ๋ด๋ธ๋ค. ๊ทธ ํ์, ์ถ๊ฐ์ ์ธ distillation์ ์ํ student model์ ์ด๊ธฐํ ์ฒ๋ผ ์ฌ์ฉ๋๋ general TinyBERT๋ฅผ ์ป์ ์ ์๋ค. task-specific distillation ๋จ๊ณ์์๋, ๋จผ์ data augmentation์ ์ํํ๊ณ , fine-tuned BERT๋ฅผ teacher model์ฒ๋ผ ์ฌ์ฉํ์ฌ augmented dataset์ ๋ํด distillation์ ์ํํ๋ค. ์ด๋ two stage๊ฐ TinyBERT์ ์ฑ๋ฅ๊ณผ ์ ๊ทํ๋ฅผ ํฅ์์ํค๋๋ฐ ํ์์ ์ด๋ผ๋ ๊ฒ์ ๋ณด์ฌ์ค๋ค.
์ด ๋
ผ๋ฌธ์ ์ฃผ๋ contribution์ ๋ค์๊ณผ ๊ฐ๋ค.
- ์๋ก์ด Transformer distillation method ์ ์. teacher BERT์ ์ธ์ด์ ์ ๋ณด๊ฐ ์ ์ ํ๊ฒ TinyBERT๋ก ์ ๋ฌ๋จ.
- two-stage learning ํ๋ ์์ํฌ ์ ์. pre-training & fine-tuning
- ↓ parameter & ↓ ์ถ๋ก ์๊ฐ์๋ ๋ถ๊ตฌ ์ค์ํ ์ฑ๋ฅ์ ๋ณด์ฌ์ค.
2. Method
์ด ์น์
์์, ๋
ผ๋ฌธ์์๋ Transformer model์ ์ํ ์๋ก์ด distillation method๋ฅผ ์ ์ํ๊ณ , BERT๋ก๋ถํฐ ์ฆ๋ฅ๋ ๋
ผ๋ฌธ์ ๋ชจ๋ธ์ ๋ํด two-stage learning ํ๋ ์์ํฌ๋ฅผ ์ ์ฉํ TinyBERT๋ฅผ ๋ณด์ฌ์คฌ๋ค.
2-1. Transformer Distillation
์ ์๋ Transformer Distillation์ Transformer network๋ฅผ ์ํด ํน๋ณํ ๋์์ธ๋ KD method์ด๋ค. ์ด ์์
์์ student์ teacher network ๋ชจ๋๋ Transformer layer๋ก ๋ง๋ค์ด์ก๋ค. ๋ช
ํํ ์ค๋ช
์ ์ํด ๋
ผ๋ฌธ์์๋ method๋ฅผ ์ค๋ช
ํ๊ธฐ ์ ์ ๋ฌธ์ ๋ถํฐ ๊ณ ์ํ์๋ค.
Problem Formulation $N$๊ฐ์ layer๋ฅผ ๊ฐ์ง teacher ๋ชจ๋ธ๊ณผ $M$๊ฐ์ layer๋ฅผ ๊ฐ์ง student ๋ชจ๋ธ์ด ์๋ค๊ณ ๊ฐ์ ํด๋ณด์. Transformer distillation์ ์ํด teacher model์์ N๊ฐ ์ค M๊ฐ์ Transformer layer๋ฅผ ์ ํํ์. ๊ทธ ๋ค์์ student layer๊ณผ teacher layer ๊ฐ์ ์ธ๋ฑ์ค๋ฅผ ๋งคํํ๋ ํจ์์ธ $n = g(m)$์ ์ ์ํ์๋ค. ์ด๊ฒ์ student model์ $m$๋ฒ์งธ layer๋ teacher model์ $g(m)$๋ฒ์งธ layer๋ก๋ถํฐ ์ ๋ณด๋ฅผ ๋ฐ์์ ํ์ตํ๋ค๋ ์๋ฏธ์ด๋ค. ์ฌ๊ธฐ์ $m = 0$์ index embedding์ ์๋ฏธํ๊ณ , $M + 1$์ prediction layer๋ฅผ ์๋ฏธํ๋ค. ๊ฐ๊ฐ์ $0 = g(0)$๊ณผ $N + 1 = g(M + 1)$์ ์๋ฏธํ๋ค. ์ด๋ฅผ ๊ณต์ํํ๋ฉด student model์ teacher model๋ก๋ถํฐ ๋ค์์ ์์ ์ต์ํํจ์ผ๋ก์จ ์ง์์ ์ป์ ์ ์๋ค.
์ฌ๊ธฐ์ $\mathfrak{L}_{layer}$๋ model layer๊ฐ ์ฃผ์ด์ง ๋์ loss function์ ๋ปํ๊ณ , $f_{m}(x)$๋ behavior function์ผ๋ก $m$๋ฒ์งธ layer๋ก๋ถํฐ ์ ๋๋์๊ณ , $\lambda_{m}$์ $m$ ๋ฒ์งธ layer์ distillation์ ์ค์๋๋ฅผ ๋ํ๋ด๋ ํ๋ผ๋ฏธํฐ์ด๋ค.
Transformer-layer Distillation ๋
ผ๋ฌธ์์ ์ ์๋ Transformer-layer distillation์ attention based distillation๊ณผ hidden states based distillation๋ฅผ ํฌํจํ๊ณ ์๋๋ฐ, ์ด๊ฒ ๊ทธ๋ฆผ 2์ ๋ํ๋์๋ค. attention based distillation์ BERT์ ์ํด ํ์ต๋ attention ๊ฐ์ค์น๊ฐ ์ธ์ด์ ์ ๋ณด๋ฅผ ์บก์ฒํ ์ ์๋ค๋ ์ฌ์ค์ ์๊ฐ์ ๋ฐ์๋ค. ์ด๋ฌํ ์ ํ์ ์ธ์ด์ ์ ๋ณด๋ syntax์ coreference ์ ๋ณด๋ฅผ ํฌํจํ๊ณ ์๊ณ , ์ด๋ ์์ฐ์ด ์ดํด์ ํ์์ ์ด๋ค. ๋ฐ๋ผ์ ๋
ผ๋ฌธ์์๋ ์ธ์ด์ ์ ๋ณด๊ฐ teacher๋ก๋ถํฐ student๋ก ์ ๋ฌ๋๋ attention based distillation์ ์ ์ํ์๋ค. student๋ teacher network์ multi-head attention ํ๋ ฌ์ ๋ง์ถ๊ธฐ ์ํด ํ์ต๋๊ณ , objective๋ ๋ค์๊ณผ ๊ฐ์ด ์ ์๋๋ค.
์ฌ๊ธฐ์ $h$๋ attention head์ ์๋ฅผ ์๋ฏธํ๊ณ , $\textbf{A}_{i | \in \mathbb{R}^{l \times l}}$๋ teacher ๋๋ student์ $i$๋ฒ์งธ head์ ํด๋นํ๋ attention ํ๋ ฌ์ด๋ค. $l$๋ input text์ ๊ธธ์ด์ด๊ณ , $MSE()$๋ mean squared error ์์คํจ์์ด๋ค. ์ด ์์
์์ attention ํ๋ ฌ $\textbf{A}_{i}$๋ softmax ์ถ๋ ฅ์ธ $softmax(\textbf{A}_{i})$ ๋์ ์ fitting target์ผ๋ก ์ฌ์ฉ๋๋ค. ์๋ํ๋ฉด, ๋
ผ๋ฌธ์ ์คํ์์ ์ด๋ฌํ ์ธํ
์ ๋ ๋น ๋ฅด๊ฒ ์๋ ดํ๊ณ ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋ค๊ณ ๋ณด์ฌ์คฌ๊ธฐ ๋๋ฌธ์ด๋ค.
๊ทธ๋ฆฌ๊ณ attention based distillation์ ๋ํด ์ถ๊ฐ์ ์ผ๋ก Transformer layer์ ์ถ๋ ฅ์ผ๋ก๋ถํฐ ๋์จ ์ง์์ ์ฆ๋ฅํ์๊ณ , objective๋ ๋ค์๊ณผ ๊ฐ๋ค.
์ฌ๊ธฐ์ ํ๋ ฌ $\textbf{H}^{S} \in \mathbb{l \times d^{'}}$์ $\textbf{H}^{T} \in \mathbb{l \times d}$๋ ๊ฐ๊ฐ student์ teacher์ hidden state๋ฅผ ์๋ฏธํ๊ณ , ์ด๋ Transformer์ Position-wise Feed-Forward Network์ ์ํด ๊ณ์ฐ๋๋ค. ์ค์นผ๋ผ ๊ฐ์ธ $d$์ $d^{'}$๋ teacher์ student์ hidden size๋ฅผ ์๋ฏธํ๊ณ $d^{'}$๋ ๋ณดํต $d$๋ณด๋ค ์์ ๊ฐ์ ๊ฐ์ง๋๋ฐ, ์ด๋ ๋ ์์ student network๋ฅผ ๊ฐ๊ธฐ ์ํจ์ด๋ค. ํ๋ ฌ $\textbf{W}_{h} \in \mathbb{R}^{d^{'} \times d}$๋ ํ์ต ๊ฐ๋ฅํ ์ ํ ๋ณํ ํจ์์ด๊ณ , ์ด๋ student network์ hidden state๋ฅผ teacher network์ state๋ก ๋๊ฐ์ด ๋ณํํ๋ค.
Embedding-layer Distillation hidden states based distillation๊ณผ ๋น์ทํ๊ฒ, ๋
ผ๋ฌธ์์๋ embedding-layer distillation์ ์ํํ์๊ณ , objective๋ ๋ค์๊ณผ ๊ฐ๋ค.
์ฌ๊ธฐ์ ํ๋ ฌ $\textbf{E}^{S}$์ $\textbf{H}^{T}$๋ ๊ฐ๊ฐ student์ teacher์ embedding์ด๋ค. ๋
ผ๋ฌธ์์๋ ์ด๋ค์ด hidden state ํ๋ ฌ์ฒ๋ผ ๋๊ฐ์ ๋ชจ์์ ๊ฐ์ง๊ณ ์๋ค. ํ๋ ฌ $\textbf{W}_{e}$๋ $\textbf{W}_{h}$์ ๋น์ทํ ์ญํ ์ ์ ํ ๋ณํ์ด๋ค.
Prediction-layer Distillation ์ค๊ฐ ๋ ์ด์ด์ ํน์ฑ์ ๋ชจ๋ฐฉํ๊ธฐ ์ํด, ๋
ผ๋ฌธ์์๋ teacher model์ prediction์ KD๋ฅผ ์ ์ฉํ์๋ค. ๋
ผ๋ฌธ์์๋ student logit๊ณผ teacher logit ๊ฐ์ soft cross-entropy loss๋ฅผ ์ ์ฉํ์๋ค.
์ฌ๊ธฐ์ $z^{S}$์ $z^{T}$๋ ๊ฐ๊ฐ student์ teacher์ ์ํด ์์ธก๋๋ logit vector์ด๋ค. CE๋ cross entropy loss๋ฅผ ์๋ฏธํ๊ณ $t$๋ temperature ๊ฐ์ ์๋ฏธํ๋ค. ์ด $t$๊ฐ์ $t = 1$์ด ๊ฐ์ฅ ์ ๋นํ๋ค.
์์ distillation objective๋ค์ ์ด์ฉํด์, ๋
ผ๋ฌธ์์๋ ๊ฐ teacher์ student ๊ฐ์ ๋ ์ด์ด์ ํด๋นํ๋ distillation loss๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ํตํฉํ์๋ค.
2-2. TinyBERT Learning
BERT๋ ๋ณดํต ๋ ๊ฐ์ learning stage๋ก ๊ตฌ์ฑ๋์ด ์๋ค: pre-training & fine-tuning. pre-training ๋จ๊ณ์์ BERT์ ์ํด ํ์ต๋ ํ๋ถํ ์ง์์ ๋งค์ฐ ์ค์ํ๊ณ ์์ถ๋ ๋ชจ๋ธ๋ก ์ ๋ฌ๋์ด์ผ๋ง ํ๋ค. ๊ทธ๋์, ๋
ผ๋ฌธ์์๋ general distillation๊ณผ task-specific distillation์ ํฌํจํ๋ ์๋ก์ด two-stage learning ํ๋ ์์ํฌ๋ฅผ ๊ทธ๋ฆผ 1๊ณผ ๊ฐ์ด ์ ์ํ์๋ค. general distillation์ pre-trained BERT์ ํ๋ถํ embedded ์ง์์ ํ์ตํ๊ฒ ํ๊ณ , ์ด๋ TinyBERT์ ์ ๊ทํ ๋ฅ๋ ฅ์ ํฅ์์ํค๋๋ฐ ์ค์ํ ์ญํ ์ ํ๋ค. task-specific distillation์ fine-tuned TinyBERT๋ก๋ถํฐ ์ถ๊ฐ์ ์ผ๋ก TinyBERT๋ฅผ ํ์ต์ํจ๋ค. two-step distillation๊ณผ ํจ๊ป, ๋
ผ๋ฌธ์์๋ teacher๊ณผ student์ ๊ฐญ์ ์๋นํ ์ค์๋ค.
General Distillation ๋
ผ๋ฌธ์์๋ fine-tuning์ด ๋์ง ์์ BERT๋ฅผ teacher๋ก ์ฌ์ฉํ๊ณ , large-scale text corpus๋ฅผ ํ์ต ๋ฐ์ดํฐ๋ก ์ฌ์ฉํ์๋ค. general domain์ text์ ๋ํด Transformer distillation์ ์ํํจ์ผ๋ก์จ, downstream task์ ๋ํด fine-tuneํ ์ ์๋ TinyBERT๋ฅผ ์ป์๋ค. ํ์ง๋ง, hidden/embedding ํฌ๊ธฐ์ layer ์์ ์๋นํ ๊ฐ์๋ก ์ธํด general TinyBERT๋ BERT์ ๋นํด ์ข์ง ์์ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋ค.
Task-specific Distillation ์ด์ ์ ์ฐ๊ตฌ๋ค์์ fine-tuned BERT ๊ฐ์ด ๋ณต์กํ ๋ชจ๋ธ๋ค์ domain-specific task์ ๋ํด์ over-parameterization์ ๊ฒช๊ณ ์์์ ๋ฐํ๋๋ค. ๋ฐ๋ผ์, BERT๋ณด๋ค ์์ ๋ชจ๋ธ๋ BERT์ ํ์ ํ๋ ์ฑ๋ฅ์ ๋ผ ์ ์๋ค๋ ๊ฐ๋ฅ์ฑ์ ๋ณด์ฌ์ค๋ค. ์ด๋ฅผ ์ํด, ๋
ผ๋ฌธ์์๋ task-specific distillation์ ํตํด ๊ฒฝ์๋ ฅ ์๋ fine-tuned TinyBERT๋ฅผ ์ ์ํ์๋ค. task-specific distillation์์, ์ ์๋ augmented task-specific dataset์ ๋ํด์ Transformer distillation์ ์ฌ์ํํ์๋ค. fine-tuned BERT๊ฐ teacher๋ก ์ฌ์ฉ๋๊ณ , data augmentation method๋ task-specific ํ์ต ์
์ ๋๋ฆฌ๊ธฐ ์ํด ์ ์๋์๋ค. ๋ ๋ง์ task๊ด๋ จ example๊ณผ ํจ๊ป ํ์ตํ๋ฉด, student model์ ์ ๊ทํ ๋ฅ๋ ฅ์ ์ถ๊ฐ์ ์ผ๋ก ํฅ์์ํฌ ์ ์๋ค.
Data Augmentation ๋
ผ๋ฌธ์์๋ pre-trained LM์ธ BERT์ GloVe word embedding์ ํฉ์ณ์ data augmentation์ ์ํ ๋์ฒด๋ฅผ ํ์๋ค. ๊ทธ๋ฆฌ๊ณ LM์ ์ฌ์ฉํ์ฌ single-piece ๋จ์ด์ ๋จ์ด ๋์ฒด๋ฅผ ์์ธกํ๊ณ , word embedding์ ์ฌ์ฉํด์ multiple-peices ๋จ์ด ๋์ฒด๋ฅผ ์ํด ๊ฐ์ฅ ๋น์ทํ ๋จ์ด๋ฅผ ๋์ฐพ๋๋ค. ๋ช๋ช ํ์ดํผ ํ๋ผ๋ฏธํฐ๋ค์ ๋ฌธ์ฅ์ ๋์ฒด ๋น์จ๊ณผ augmented dataset์ ์์ ์กฐ์ ํ๊ธฐ ์ดํด ์ ์๋๋ค. data augmentation ํ๋ก์์ ์ ์ถ๊ฐ์ ๋ํ
์ผ์ ์๊ณ ๋ฆฌ์ฆ 1์ ๋์ ์๋ค. ์ฌ๊ธฐ์ $p_{t}=0.4, N_{a}=20, K=15$์ด๋ค.
์์ ๋ learning stage๋ ์๋ก ์ํธ๋ณด์์ ์ด๋ค. general distillation์ task-specific distillation์ ์ํ ์ข์ ์ด๊ธฐํ๋ฅผ ์ ๊ณตํ๋ ๋ฐ๋ฉด, augmented data์์์ task-specific distillation์ task-specificํ ์ง์์ ํ์ตํ๋ ๊ฒ์ ์ง์คํจ์ผ๋ก์จ TinyBERT๋ฅผ ์ถ๊ฐ์ ์ผ๋ก ํฅ์์์ผฐ๋ค. model size์ ์๋นํ ๊ฐ์์๋ ๋ถ๊ตฌํ๊ณ , data augmentation๊ณผ ํจ๊ป ์ ์๋ Transformer distillation์ pre-training๊ณผ fine-tuning stage์ ์ํํจ์ผ๋ก์จ, TinyBERT๋ ๋ค์ํ NLP task์ ๋ํด์ ๊ฒฝ์๋ ฅ ์๋ ์ฑ๋ฅ์ ๋ฌ์ฑํ์๋ค.
3. Experiment Results
๋ ผ๋ฌธ์์๋ ๋ ผ๋ฌธ์ model์ GLUE ๋ฐ์ดํฐ์ ์ test set์ ๋ํด ์คํ์ ์งํํ์๋ค. ๊ทธ ๊ฒฐ๊ณผ๊ฐ ๋ค์์ ํ 1์ ๋ํ๋ ์๋ค.
4-layer student model์ ๋ํ ์คํ์ ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ๋ค.
- $BERT_{tiny}$์ $BERT_{base}$ ๊ฐ์ ํฐ ์ฑ๋ฅ์ ๊ฐญ์ด ์์์. ์ด๋ ๊ทน์ ์ธ model size์ ๊ฐ์ ๋๋ฌธ
- TinyBERT๊ฐ $BERT_{tiny}$๋ณด๋ค ํจ์ฌ ๋์. ์ ์๋ KD๊ฐ ์ฑ๋ฅ ํฅ์์ ํจ๊ณผ์ ์ด์์.
- TinyBERT๊ฐ KD baseline๋ณด๋ค ์ฑ๋ฅ ๋ฉด์์ ์ฐ์
- teacher $BERT_{base}$์ ๋น๊ตํด์ TinyBERT๊ฐ ๋ ์๊ณ , ๋ ๋นจ๋์
- CoLA ๋ฐ์ดํฐ์ ์ ๋ํด 4-layer-distilled model์ teacher model์ ๋นํด ํฐ ์ฑ๋ฅ ๊ฐญ์ ๋ณด์ฌ์คฌ์. TinyBERT๋ ์๋นํ ์ฑ๋ฅ ํฅ์์ ๋ณด์ฌ์คฌ์.
- ์ ์ layer ์์๋ ๋ถ๊ตฌํ๊ณ TinyBERT๋ Mobile BERT-24layer์ ๋น์ทํ score๋ฅผ ๋ฐ์
- capacity๋ฅผ ๋๋ฆฌ๋ฉด teacher์ ํ์ ํ๋ ์ฑ๋ฅ์ ์ป์ ์ ์์ ์ ๋๋ก ์ฑ๋ฅ์ด ํฅ์ํจ
- TinyBERT์ general distillation ํ์ task-specific stage๋ฅผ ๊ฒช์. ์ด๋ ๋ค๋ฅธ KD ๋ชจ๋ธ๋ค๊ณผ ๋ฐ๋๋๋ ์์
๋
ผ๋ฌธ์ two-stage distillation ํ๋ ์์ํฌ์์ TinyBERT๋ general distillation์ ํตํด ์ด๊ธฐํ๋๋ฉด์, ๋ชจ๋ธ์ ๊ตฌ์กฐ๋ฅผ ์ ํํ๋๋ฐ ๋์ฑ ์ ๋์ ์ด๊ฒ ๋ง๋ค์ด์ฃผ์๋ค.
4. Ablation Studies
์ด ์น์
์์๋ ๋ค์์ ๋ ๊ฐ์ง contribution์ ๋ํด abalation study๋ฅผ ์งํํ์๋ค.
- two-stage TinyBERT ํ์ต ํ๋ ์์ํฌ์ ์๋ก ๋ค๋ฅธ ํ๋ก์์
- ์๋ก ๋ค๋ฅธ distillation objective
4-1. Effects of Learning Procedure
two-stage TinyBERT ํ์ต ํ๋ ์์ํฌ๋ 3๊ฐ์ ์ค์ํ ํ๋ก์์ ๋ก ๊ตฌ์ฑ๋์ด ์๋ค: GD$($general distillation$)$, TD$($task-specific distillation$)$, DA$($data augmentation$)$. ๊ฐ๊ฐ์ ํ์ต ํ๋ก์์ ๋ฅผ ์ ๊ฑฐํ์ ๋ ์ฑ๋ฅ์ ์ฐจ์ด๋ฅผ ๋ถ์ํ๋๋ฐ, ์ด ๊ฒฐ๊ณผ๊ฐ ๋ค์์ ํ 2์ ๋ํ๋ ์๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก 3๊ฐ์ ํ์ต ํ๋ก์์ ๋ชจ๋ ์ค์ํ์ง๋ง, TD์ DA๋ ๋น์ทํ ์ฑ๋ฅ ํฅ์์ ๋ณด์ฌ์ฃผ๋ ๋ฐ๋ฉด์, GD๊ฐ ์ฑ๋ฅ์ ๊ฐ์ฅ ํฐ ์ํฅ์ ์ฃผ๊ณ ์๋ค.
4-2. Effects of Distillation Objective
TinyBERT ํ์ต์ ๋ํ distillation objective์ ํจ๊ณผ๋ฅผ ์กฐ์ฌํ์๋ค. ๋ค์ํ baseline๋ค์ ๋ํด Transformer-layer distillation์ด ์๋ ํ์ต๊ณผ embedding-layer distillation์ด ์๋ ํ์ต, prediction-layer distillation์ด ์๋ ํ์ต์ ๊ฐ๊ฐ ์กฐ์ฌํ์๋ค. ์ด์ ๋๋ ๊ฒฐ๊ณผ๊ฐ ๋ค์์ ํ 3์ ๋ํ๋ ์๋ค. ๊ฒฐ๊ณผ๋ฅผ ์ดํด๋ณด๋ฉด, Transformer-layer distillation์ด ์๋ ๊ฒฝ์ฐ ์๋นํ ์ฑ๋ฅ ๊ฐ์๋ฅผ ๋ณด์ฌ์ค๋ค๋ ๊ฒ์ ์ ์ ์๋ค. ์ด๋ student model์ ์ด๊ธฐํ์ ์ฐ๊ณผ๋์ด ์๊ธฐ ๋๋ฌธ์ด๋ค. ์ฌ๊ธฐ์ attention based distillation์ด hidden states based distillation์ด ๋ ๊ฐํ ์ํฅ๋ ฅ์ ๊ฐ์ง๊ณ ์์์ ์ ์ ์์๋ค.
5. Effects of Mapping Function
๋
ผ๋ฌธ์์๋ TinyBERT ํ์ต์ ๋ํด ์๋ก ๋ค๋ฅธ mapping function $n = g(m)$์ ํจ๊ณผ์ ๋ํด์๋ ์กฐ์ฌํ์๋ค. ๊ธฐ์กด TinyBERT์ uniform ์ ๋ต๊ณผ ๋ ๊ฐ์ ์ ํ์ ์ธ baseline์ธ top strategy$(g(m) = m + N - M; 0 < m \leq M)$์ bottom-strategy$(g(m) =m; 0 < m \leq M)$์ ๋น๊ตํ์๋ค.
๋น๊ต ๊ฒฐ๊ณผ๋ ๋ค์์ ํ 4์ ๋ํ๋ ์๋ค. ๊ทธ ๊ฒฐ๊ณผ top-strategy๊ฐ bottom-strategy๋ณด๋ค MNLI์ ๋ํด ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์คฌ์ง๋ง, MRPC์ CoLA์ ๋ํด์๋ ๋ ์ ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์คฌ๋ค. ์ด๋ ์๋ก ๋ค๋ฅธ BERT layer๋ก๋ถํฐ์ ์ง์์ ๋ํด ์์กดํ๋ ์๋ก ๋ค๋ฅธ task์ ๋ํ ๊ด์ฐฐ์ ์
์ฆํ์๋ค. ๊ทธ๋ฆฌ๊ณ uniform strategy๊ฐ ๋ค๋ฅธ ๋ baseline๋ณด๋ค ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์คฌ๋ค.
์ถ์ฒ
https://arxiv.org/abs/1909.10351