
Let's give KD a little twist!! 😜 - "Knowledge Distillation of Large Language Models" Paper Review

Cartinoe 2023. 6. 22. 12:46

The overview of this paper

 ์ด์ „์˜ KD๋Š” ์ฃผ๋กœ black-box model API๋ฅผ ๋ชจ๋ฐฉํ•˜๊ธฐ ์œ„ํ•ด white-box ๋ถ„๋ฅ˜ ๋ชจ๋ธ ๋˜๋Š” small model์„ ํ•™์Šต์‹œํ‚ค๋Š”๋ฐ ์ ์šฉ๋œ๋‹ค. white-box ์ƒ์„ฑ LLM์œผ๋กœ๋ถ€ํ„ฐ ์–ด๋–ป๊ฒŒ ํšจ๊ณผ์ ์œผ๋กœ distill ํ•˜๋Š”์ง€๋Š” ์•„์ง under-explore ๋˜์–ด ์žˆ๋‹ค.

 

 ์ด ๋…ผ๋ฌธ์—์„œ๋Š” forward KLD๋ฅผ reverse KLD๋กœ ๋Œ€์ฒดํ•จ์œผ๋กœ์จ ์ƒ์„ฑ์  larger LM์œผ๋กœ๋ถ€ํ„ฐ smaller LM์„ distill ํ•˜๋Š” MiniLLM์„ ์†Œ๊ฐœํ•˜์˜€๋‹ค. ์ด๊ฒƒ์€ student model์ด teacher ๋ถ„ํฌ์˜ low-probability ์˜์—ญ์„ ๊ณผ๋„ํ•˜๊ฒŒ ํ‰๊ฐ€ํ•˜๋Š” ๊ฒƒ์œผ๋กœ๋ถ€ํ„ฐ ๋ชจ๋ธ์„ ๋ณดํ˜ธํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ƒ์„ฑ์  LM์— ๋”์šฑ ์ ํ•ฉํ•œ LM์ด๋‹ค. MiniLLM์€ ์ „๋ฐ˜์ ์œผ๋กœ ๋†’์€ ํ€„๋ฆฌํ‹ฐ, ๋‚ฎ์€ bias ๋…ธ์ถœ, ๋” ๋‚ฎ์€ calibration, ๋” ์ข‹์€ long-text ์ƒ์„ฑ ์„ฑ๋Šฅ์„ ๊ฐ€์ง€๋Š” ๋”์šฑ ์ •ํ™•ํ•œ ์‘๋‹ต์„ ์ƒ์„ฑํ•ด ๋‚ธ๋‹ค.

 

 

Table of Contents

1. Introduction

2. Methods

3. Experiments

 

 

1. Introduction

 KD๋Š” ๋งŽ์€ ์–‘์˜ computing ์ž์›์— ๋Œ€ํ•œ ํ•„์š”๋ฅผ ๊ฐ์†Œ์‹œํ‚จ๋‹ค. ์ด KD์—๋Š” ๋‹ค์Œ์˜ 2๊ฐ€์ง€ ์ข…๋ฅ˜๊ฐ€ ์žˆ๋‹ค.

 

  1. black-box KD: only the teacher's predictions are accessible
  2. white-box KD: the teacher's parameters (and thus its internal distributions) are available to use

 

 student model์€ white-box KD๋กœ๋ถ€ํ„ฐ ๋” ๋งŽ์€ ๊ฒƒ์„ ๋ฐฐ์šธ ์ˆ˜ ์žˆ์ง€๋งŒ, ์ง€๊ธˆ๊นŒ์ง€๋Š” white-box KD w/ ์ƒ์„ฑ LLM์ด ๋ณ„๋กœ ํƒ๊ตฌ๋˜์ง€ ์•Š์•˜๋‹ค.

 

 ์ด ๋…ผ๋ฌธ์—์„œ๋Š” LLM์˜ white-box KD๋ฅผ ์กฐ์‚ฌํ•˜์˜€๋‹ค. ๋…ผ๋ฌธ์—์„œ๋Š” ๊ธฐ์กด์˜ KLD objective๋Š” forward KLD๊ฐ€ ์‚ฌ์šฉ๋˜์—ˆ์ง€๋งŒ, ์ด๋Š” ๋ถ„๋ฅ˜ task์—์„œ๋Š” $p(y|x)$์™€ $q_{\theta}(y|x)$๊ฐ€ ์ ์€ ์ˆ˜์˜ ๋ชจ๋“œ๋ฅผ ๊ฐ€์ง€๊ธฐ ๋•Œ๋ฌธ์— ๋ฌธ์ œ๊ฐ€ ์—†์—ˆ๋‹ค. ํ•˜์ง€๋งŒ open text generation task๋Š” output space๊ฐ€ ๋”์šฑ ๋ณต์žกํ•˜๊ณ  $p(y|x)$๋Š” $q_{\theta}(y|x)$๊ฐ€ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋Š” ๋ชจ๋“œ๋ณด๋‹ค ๋” ๋งŽ์€ ๋ชจ๋“œ๋ฅผ ๊ฐ€์ง„๋‹ค. ๋”ฐ๋ผ์„œ forward KLD๋ฅผ ์ตœ์†Œํ™”ํ•˜๋Š” ๊ฒƒ์€ ํ•ฉ๋‹นํ•˜์ง€ ์•Š๋‹ค.

 

To mitigate this problem, the paper proposes minimizing the reverse KLD instead. Minimizing $KL[q_{\theta}||p]$ makes $q_{\theta}$ seek the major modes of $p$ and assign low probability to $p$'s void regions. To optimize $min_{\theta} KL[q_{\theta}||p]$, the gradient of the objective is derived with the Policy Gradient theorem. The authors found that training then suffers from high variance, reward hacking, and a bias toward short generations, and addressed these as follows:

 

  1. ๋ถ„์‚ฐ์„ ์ค„์ด๊ธฐ ์œ„ํ•ด sigle-step ์ •๊ทœํ™” ์ง„ํ–‰
  2. reward ํ•ดํ‚น์„ ์™„ํ™”ํ•˜๊ธฐ ์œ„ํ•ด teacher-mixed ์ƒ˜ํ”Œ๋ง ์ง„ํ–‰
  3. ๊ธธ์ด bias๋ฅผ ์ œ๊ฑฐํ•˜๊ธฐ ์œ„ํ•ด length ์ •๊ทœํ™” ์ง„ํ–‰

 

Figure 1. Comparison of seqKD and MiniLLM in terms of the average GPT-4 feedback score on the evaluation sets

 

2. Methods

 ๋…ผ๋ฌธ์—์„œ๋Š” KD๋ฅผ teacher ๋ชจ๋ธ ๋ถ„ํฌ์™€ student model ๋ถ„ํฌ ๊ฐ„์˜ ์ฐจ์ด๋ฅผ ์ตœ์†Œํ™”ํ•˜๋Š” ์ตœ์ ํ™” ๋ฌธ์ œ๋กœ ๊ณ ์•ˆํ•˜์˜€๋‹ค. ๊ธฐ์กด์— ์‚ฌ์šฉ๋˜๋˜ forward KLD $KL[p||q_{\theta}]$๋Š” $q_{\theta}$๊ฐ€ $p$์˜ ๋ชจ๋“  ๋ชจ๋“œ๋ฅผ ์ปค๋ฒ„ํ•  ์ •๋„๋กœ ์ถฉ๋ถ„ํžˆ ํ‘œํ˜„๋ ฅ์ด ์žˆ์„ ๋•Œ language generation task์—์„œ $p$์˜ ๋น„์–ด ์žˆ๋Š” ์˜์—ญ์„ ๊ณผ๋Œ€ ํ‰๊ฐ€ํ•˜๋Š” ๋ชจ์Šต์„ ๋ณด์˜€๋‹ค.

 

2-1. MiniLLM: Knowledge Distillation with Reverse KLD

 

 ๋…ผ๋ฌธ์—์„œ๋Š” student & teacher ๋ถ„ํฌ ๊ฐ„์˜ reverse KLD๋ฅผ ์ตœ์†Œํ™”ํ•˜๋Š” ๊ฒƒ์„ MiniLLM์˜ learning objective๋กœ ์‚ผ์•˜๋‹ค.

 

 

$q_{\theta}$ assigns high probability to the large modes of $p$ and ignores the small ones. As Figure 2 shows, minimizing the forward KLD assigns large probability to the zero-probability regions of $p$, while the reverse KLD concentrates on $p$'s major modes.

 

Figure 2. Fitting a Gaussian distribution with the forward KLD and the reverse KLD

 

 ๋…ผ๋ฌธ์—์„œ๋Š” LLM์„ reverse KLD๋ฅผ ์ตœ์†Œํ™”์‹œํ‚ค๋Š” KD method๋ฅผ ๊ทธ๋ฆผ 3์ฒ˜๋Ÿผ MiniLLM์œผ๋กœ ์ด๋ฆ„ ์ง€์—ˆ๋‹ค. seqKD์™€ ๋‹ฌ๋ฆฌ MiniLLM์€ teacher ๋ถ„ํฌ $p$๋กœ๋ถ€ํ„ฐ ์ƒ˜ํ”Œ๋ง๋œ ๋ชจ๋“  $y$์— ๋Œ€ํ•ด $q_{\theta}$๋ฅผ ๋งž์ถ”๋„๋ก ๊ฐ•์š”ํ•˜์ง€ ์•Š๋Š”๋‹ค. ๊ทธ ๋Œ€์‹ ์—, student model์ด ์ž์‹ ์˜ ๋Šฅ๋ ฅ ๋‚ด์—์„œ teacher๊ฐ€ ์„ ํ˜ธํ•˜๋Š” ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•˜๋„๋ก ๊ถŒ์žฅํ•˜๋ฉฐ, ์ด๋Š” ์„ฑ์ทจ ๊ฐ€๋Šฅ์„ฑ์ด ๋” ๋†’๋‹ค.

 

Figure 3. Comparison of seqKD (left) and MiniLLM (right). seqKD forces the student to memorize every teacher-generated sample, whereas MiniLLM lets the student improve using its own generations together with the teacher's feedback.

 

2-2. Optimization with Policy Gradient

 

Gradient Derivation.  The paper shows that the gradient of the objective function can be obtained with the Policy Gradient theorem as follows.
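In the blog's notation, and up to an additive baseline term that leaves the expectation unchanged, this gradient is:

$$\nabla \mathbf{J}(\theta) = -\,\mathbb{E}_{x \sim p_x,\; \mathbf{y} \sim q_{\theta}(\cdot|x)} \sum_{t=1}^{T} R_t \, \nabla \log q_{\theta}(y_t|\mathbf{y}_{<t}, x)$$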

 

 

Here $T = |\mathbf{y}|$ and $R_{t} = \sum_{t'=t}^{T} \log \frac{p(y_{t'}|\mathbf{y}_{<t'}, x)}{q_{\theta}(y_{t'}|\mathbf{y}_{<t'}, x)}$ is the accumulation of $r_{t'} = \log \frac{p(y_{t'}|\mathbf{y}_{<t'}, x)}{q_{\theta}(y_{t'}|\mathbf{y}_{<t'}, x)}$, which measures the quality of each generation step. Intuitively, we want $p(y_{t'}|\mathbf{y}_{<t'}, x)$ to increase so that the generation has high probability under the teacher distribution, while at the same time keeping $q_{\theta}(y_{t'}|\mathbf{y}_{<t'}, x)$ low to preserve diversity. This training still has several problems, however, and the paper presents three strategies to mitigate them.
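As a concrete sketch of these definitions, the per-step rewards $r_t$ and tail sums $R_t$ can be computed from the two models' token log-probabilities as below (tensor shapes and names are my own illustration, not the paper's code):

```python
import torch

def stepwise_rewards(teacher_logits, student_logits, tokens):
    """r_t = log p(y_t|y_<t,x) - log q_theta(y_t|y_<t,x) for the sampled
    tokens, and R_t = sum_{t'>=t} r_t' (reverse cumulative sum).
    Illustrative shapes: logits are [T, vocab], tokens is [T]."""
    log_p = torch.log_softmax(teacher_logits, dim=-1)
    log_q = torch.log_softmax(student_logits, dim=-1)
    idx = tokens.unsqueeze(-1)
    r = (log_p.gather(-1, idx) - log_q.gather(-1, idx)).squeeze(-1)  # [T]
    # Reverse cumulative sum: R_t accumulates rewards from step t to T.
    R = torch.flip(torch.cumsum(torch.flip(r, dims=[0]), dim=0), dims=[0])
    return r, R
```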

 

Single-Step Regularization.  The single-step generation quality $r_{t}$ strongly affects the training variance. To attend to $r_{t}$ more directly, the paper rewrites $\nabla \mathbf{J}(\theta)$ so that $r_{t}$ is split out of $R_{t}$, and computes the gradient of $\mathbb{E}_{y_{t} \sim q_{\theta}(\cdot|\mathbf{y}_{<t}, x)}[r_{t}]$ exactly, as a regularization term.
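In this notation, the split (the $(\nabla \mathbf{J})_{Main}$ and $(\nabla \mathbf{J})_{Reg}$ terms referred to below) takes roughly the form:

$$\nabla \mathbf{J}(\theta) = (\nabla \mathbf{J})_{Main} + (\nabla \mathbf{J})_{Reg}$$

$$(\nabla \mathbf{J})_{Main} = -\,\mathbb{E}_{\mathbf{y} \sim q_{\theta}} \sum_{t=1}^{T} R_{t+1}\, \nabla \log q_{\theta}(y_t|\mathbf{y}_{<t}, x), \qquad (\nabla \mathbf{J})_{Reg} = -\,\mathbb{E}_{\mathbf{y} \sim q_{\theta}} \sum_{t=1}^{T} \nabla\, \mathbb{E}_{y_t \sim q_{\theta}(\cdot|\mathbf{y}_{<t}, x)}[\,r_t\,]$$

The Reg term can be evaluated exactly, because the inner expectation is a sum over the vocabulary for which both models provide full next-token distributions.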

 

 

This regularization gives a more accurate and efficient estimate of single-step generation quality, which reduces variance during training and accelerates convergence.

 

Teacher-Mixed Sampling.  To reduce reward hacking, the paper mixes the teacher and student distributions at each time step.
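The mixed sampling distribution interpolates the two models at every decoding step:

$$\tilde{p}(y_t|\mathbf{y}_{<t}, x) = \alpha \cdot p(y_t|\mathbf{y}_{<t}, x) + (1 - \alpha) \cdot q_{\theta}(y_t|\mathbf{y}_{<t}, x)$$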

 

 

Here $\alpha$ is a hyperparameter that controls the strength of the teacher mix-in. Sampling from $\tilde{p}$ suppresses low-quality generations with the teacher's help and alleviates reward hacking. To obtain an unbiased estimator of the gradient, the paper rewrites $(\nabla \mathbf{J})_{Main}$ and $(\nabla \mathbf{J})_{Reg}$ with importance sampling.
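Schematically, the rewritten Main term samples responses from $\tilde{p}$ and corrects each step with an importance weight $w_t$ (the Reg term is rewritten analogously):

$$(\nabla \mathbf{J})_{Main} = -\,\mathbb{E}_{\mathbf{y} \sim \tilde{p}(\cdot|x)} \sum_{t=1}^{T} w_t\, R_{t+1}\, \nabla \log q_{\theta}(y_t|\mathbf{y}_{<t}, x)$$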

 

 

Here $w_{t}$ is the importance weight. The exact $w_{t}$ is a product of per-step ratios over all preceding positions, which is very sensitive to the hyperparameters and converges slowly. To lower the variance of the estimator, the paper approximates it with the single-step ratio $w_{t} \approx \frac{q_{\theta}(y_{t}|\mathbf{y}_{<t}, x)}{\tilde{p}(y_{t}|\mathbf{y}_{<t}, x)}$.
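A minimal sketch of one decoding step with teacher mix-in, assuming `teacher_logits` and `student_logits` are the two models' next-token logits (names are illustrative, not the paper's code):

```python
import torch

def sample_teacher_mixed(teacher_logits, student_logits, alpha=0.2):
    """Draw the next token from p_tilde = alpha * p + (1 - alpha) * q_theta,
    and return the single-step importance weight q_theta / p_tilde used to
    keep the gradient estimate (approximately) unbiased."""
    p = torch.softmax(teacher_logits, dim=-1)   # teacher next-token dist
    q = torch.softmax(student_logits, dim=-1)   # student next-token dist
    p_tilde = alpha * p + (1.0 - alpha) * q     # mixed sampling dist
    token = torch.multinomial(p_tilde, num_samples=1)
    w = (q.gather(-1, token) / p_tilde.gather(-1, token)).squeeze(-1)
    return token.squeeze(-1), w
```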

 

Length Normalization.  Long sequences tend to have small $R_{t+1}$, which pushes the model toward generating short responses. The paper therefore adds length normalization to $R_{t+1}$.
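A length-normalized long-term reward along these lines divides the tail sum by the number of remaining steps, so long and short continuations are scored on the same scale:

$$R_{t+1}^{Norm} = \frac{1}{T - t} \sum_{t'=t+1}^{T} \log \frac{p(y_{t'}|\mathbf{y}_{<t'}, x)}{q_{\theta}(y_{t'}|\mathbf{y}_{<t'}, x)}$$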

 

 

In Summary.  Combining the strategies above, that is, the single-step split, the teacher-mixed importance weights $w_{t}$, and the length-normalized $R_{t+1}^{Norm}$, yields the final optimization gradient.

 

 

2-3. Training Algorithm

 

 MiniLLM์˜ training algorithm์€ ์•Œ๊ณ ๋ฆฌ์ฆ˜ 1์— ๋‚˜ํƒ€๋‚˜ ์žˆ๋‹ค. ๋…ผ๋ฌธ์—์„œ๋Š” ๊ฐ€์žฅ ๋‚ฎ์€ validation loss๋ฅผ ๊ฐ€์ง€๋Š” training data์—์„œ fine-tune ๋œ checkpoint๋กœ๋ถ€ํ„ฐ student model์„ ์ดˆ๊ธฐํ™”ํ•˜์˜€๋‹ค. ๊ทธ๋ฆฌ๊ณ  $(\bigtriangledown \mathbf{J})_{Main}$์— PPO clipping ์ „๋žต์„ ์ถ”๊ฐ€ํ•ด์„œ ํ•™์Šต ์•ˆ์ „์„ฑ์„ ๊ฐœ์„ ์‹œ์ผฐ๋‹ค. ๋˜ํ•œ ํ•™์Šต ํšจ์œจ์„ฑ์„ ๊ฐœ์„ ์‹œํ‚ค๊ธฐ ์œ„ํ•ด PPO์—์„œ value network์™€ KL ์ •๊ทœํ™”๋ฅผ ์‚ฌ์šฉํ•˜์ง€๋Š” ์•Š์•˜๋‹ค. ๊ทธ๋ฆฌ๊ณ  pre-training corpus์— language modeling loss $L_{PT}$๋ฅผ ์ถ”๊ฐ€ํ•˜์˜€๋‹ค. 

 

Algorithm 1. MiniLLM: Knowledge Distillation of LLMs

 

3. Experiments

3-1. Experimental Setup

 

The teacher $p$ is built by fine-tuning a large model on an instruction-response dataset $D$. Different KD methods are then compared for distilling smaller student models on $D$ under the teacher's supervision.

 

Base Models.  Three families of models of various sizes are distilled.

 

  • OPT(1.3B, 2.7B, 6.7B) - teacher model: OPT-13B
  • GPT-2(120M, 340M, 760M) - teacher model: GPT-2-1.5B
  • LLaMA(7B) - teacher model: LLaMA-13B

 

Training.  databricks-dolly-15k is used as the training data. For $D_{PT}$, OpenWebText is used for GPT-2 and the RoBERTa training corpus for the other models.

 

Evaluation.  The trained models are evaluated on five instruction-following datasets.

 

  • DollyEval
  • SelfInst
  • VicunaEval
  • S-NI
  • UnNI

 

Two metrics are also used to evaluate the model-generated responses; a minimal ROUGE-L example follows the list.

 

  • ROUGE-L
  • GPT-4
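ROUGE-L measures longest-common-subsequence overlap with the ground-truth response. For instance, with Google's rouge-score package (the example strings here are made up):

```python
from rouge_score import rouge_scorer

# ROUGE-L scores the longest common subsequence between a reference
# (ground-truth) response and the model's prediction.
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
reference = "Knowledge distillation transfers knowledge from a teacher to a student."
prediction = "Distillation transfers a teacher's knowledge to a smaller student."
score = scorer.score(reference, prediction)["rougeL"]
print(f"P={score.precision:.3f} R={score.recall:.3f} F={score.fmeasure:.3f}")
```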

 

Baselines.  Three baselines are considered in the main experiments.

 

  • SFT w/o KD: directly fine-tune the student model on $D$, supervised by the golden responses
  • KD: fine-tune the student model on $D$, using the teacher's token-level distribution as supervision at each step (word-level KD)
  • SeqKD: fine-tune the student model on teacher-generated data

 

3-2. Results

 

The evaluation results are shown in Table 1.

 

Table 1. Evaluation results. Best scores are in boldface; student models that outperform their teacher are marked with *.

 

The paper's observations from these results are as follows.

 

  1. Comparing SFT against KD and seqKD, which minimize the forward KLD, confirms that the conventional KD methods successfully distill knowledge from the teacher model in most cases.
  2. Comparing the GPT-4 feedback scores of MiniLLM against the baselines, the models distilled with the paper's method outperform the baselines in most cases, showing that MiniLLM performs better overall. The gains hold not only on DollyEval but also on the other, out-of-distribution test sets, which indicates good OOD generalization.
  3. The ROUGE-L scores show that MiniLLM generates the most precise responses, with the largest overlap with the ground-truth responses.
  4. MiniLLM's improvements are consistent as the base model size varies from 120M to 13B. This trend, shown in Figure 1, demonstrates excellent scalability and generalization.

 

3-3. Analysis

 

Exposure Bias.  The forward KLD suffers from exposure bias: the model is trained on ground-truth contexts but must condition on its own outputs at inference time. MiniLLM collects samples from the student model itself during training, which mitigates this mismatch between training and evaluation.

 

Calibration.  RL-trained models often show poor calibration, so the paper tested the calibration of MiniLLM and the KD baselines. Models trained with KD and seqKD show worse calibration than the teacher model: minimizing the forward KLD leads to a substantial distribution gap between student and teacher. MiniLLM, in contrast, focuses on accurately learning the major part of the target distribution, which narrows the ECE-score gap between student and teacher.
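For reference, the ECE (expected calibration error) used here bins predictions by confidence and averages the confidence-accuracy gap; a minimal sketch, not the paper's evaluation code:

```python
import numpy as np

def ece(confidences, correct, n_bins=10):
    """Expected calibration error: bin predictions by confidence and take the
    sample-weighted average of |accuracy - mean confidence| over the bins."""
    confidences, correct = np.asarray(confidences), np.asarray(correct)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    total = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            gap = abs(correct[mask].mean() - confidences[mask].mean())
            total += mask.mean() * gap  # weight by fraction of samples in bin
    return total

print(ece([0.9, 0.8, 0.6, 0.95], [1, 1, 0, 1]))  # toy example
```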

 

Performance on Different Response Lengths.  The paper studies model performance for different ranges of golden response length. Figure 4 shows the ROUGE-L scores of the various KD models relative to the SFT model on three S-NI subsets split by ground-truth response length. On prompts whose expected responses are short ($\leq 5$ tokens), the models achieve low scores, because most of the training set consists of longer sentences; also, since the output space there is relatively small, the reverse and forward KLD perform similarly. On prompts with longer responses ($\geq 6$ tokens), MiniLLM has an advantage over the conventional KD approaches.

 

Figure 4. ROUGE-L scores of the distilled models

 

3-4. Ablations

 

Effect of Optimization Strategies.  An ablation is run on the three strategies proposed to stabilize and accelerate optimization. Teacher-mixed sampling and length normalization turn out to be essential for stabilizing training: although the reverse KLD still decreases without them, the model then suffers from reward hacking and ends up with poor generation performance, as shown in Table 2. Figure 5 shows that single-step regularization effectively reduces the variance of the training process.

 

Table 2. Performance on the validation and test sets when different MiniLLM strategies are applied

 

Figure 5. Forward KLD between the student and the teacher during MiniLLM training under different optimization strategies

 

Effect of Teacher-Mix-in Strength $\alpha$.  Figure 6 compares MiniLLM's performance across different teacher-mix-in strengths $\alpha$; $\alpha = 0.2$ turned out to be the most suitable.

 

Figure 6. ROUGE-L score as a function of the mix-in strength $\alpha$

 

Effect of Adding Pre-training Loss.  Table 3 studies the effect of adding the pre-training loss. Adding the PT loss helps preserve capability on standard NLP tasks while leaving instruction-following performance nearly unchanged.

 

Table 3. Effect of adding the pre-training loss