์ ๋ฒ PaLM ๋ฆฌ๋ทฐ ํฌ์คํธ์ ์ด์ด์ ๋ ๋ฒ์งธ PaLM ๋ฆฌ๋ทฐ ํฌ์คํธ์ ๋๋ค~! PaLM ๋ฆฌ๋ทฐ ํฌ์คํธ$($1$)$์ ๋ณด๊ณ ์ถ์ผ์๋ค๋ฉด ์ฌ๊ธฐ๋ฅผ ์ฐธ๊ณ ํด์ฃผ์๊ธธ ๋ฐ๋๋๋ค!! ์ด๋ฒ ํฌ์คํธ์์๋ PaLM์ ๋ชจ๋ธ ๊ตฌ์กฐ์ ๋ํด ์์ธํ๊ฒ ๋ค๋ค๋ณด๋ ค๊ณ ํฉ๋๋ค. ์ด ํฌ์คํธ๋ PaLM์ ์๊ฐํ ๋ ผ๋ฌธ์ธ 'PaLM: Scaling Language Modeling with Pathways'๋ฅผ ์ฐธ๊ณ ํ์ฌ ์์ฑ๋์์ต๋๋ค.
PaLM Architecture
PaLM์ ๊ธฐ์กด์ Transformer architecture์์ ์ค์ง decoder๋ง์ ์ฌ์ฉํ์๋ค. $($๊ฐ ์์ ์์ ์ค์ง ์์ ๊ณผ ์ด์ ์ ์์ ๋ค๋ง ๋ณผ ์ ์์$)$ ๊ทธ๋ฆฌ๊ณ ์ด์ ์กฐ๊ธ์ ์์ ์ ํ์๋ค.
- SwiGLU ํ์ฑํ ํจ์: PaLM์๋ SwiGLU ํ์ฑํ ํจ์ $(Swish(xW) \cdot xV)$๊ฐ ์ฌ์ฉ๋์๋ค. ์๋ํ๋ฉด, ๊ธฐ์กด์ ReLU์ GeLU, Swish activation๋ณด๋ค ํจ์ฌ ํฅ์๋ ์ฑ๋ฅ์ ๋ณด์ฌ์คฌ๊ธฐ ๋๋ฌธ์ด๋ค. ์ด SwiGLU๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด์๋ MLP์์ 2๊ฐ๊ฐ ์๋ 3๊ฐ์ ํ๋ ฌ๊ณฑ์ด ํ์ํ์ง๋ง, ์ฐ๊ตฌ์ ๋ฐ๋ฅด๋ฉด ์ปดํจํ ๋ฑ๊ฐ ์คํ์์ ํ์ง ํฅ์์ ๋ณด์ฌ์ฃผ์๋ค.
- ๋ณ๋ ฌ ๋ ์ด์ด: PaLM์์๋ Transformer block์ ๋ํด ๊ธฐ์กด์ '์ง๋ ฌ' ๋ฐฉ์์ด ์๋, '๋ณ๋ ฌ' ๋ฐฉ์์ ์ฌ์ฉํ์๋ค. ๊ธฐ์กด์ ์ง๋ ฌ ๋ฐฉ์์ ๋ค์๊ณผ ๊ฐ์ด ์ธ ์ ์๋ค.
$y = x + MLP(LayerNorm(x + Attention(LayerNorm(x))))$
์ฌ๊ธฐ์ ๋ณ๋ ฌ๋ก ๋ณํํ๋ฉด ๋ค์๊ณผ ๊ฐ์์ง๋ค.
$y = x + MLP(LayerNorm(x)) + Attention(LayerNorm(x))$
๋ณ๋ ฌ ๋ฐฉ์์ ์ฌ์ฉํ๊ฒ ๋๋ฉด, 15% ์ ๋ ๋น ๋ฅธ ์๋๋ก ํ์ต์ ์งํํ ์ ์๋ค. ์๋ํ๋ฉด, MLP์ Attention์ ์ ๋ ฅ ํ๋ ฌ๊ณฑ์ด ํ ๋ฒ์ ์งํ๋ ์ ์๊ธฐ ๋๋ฌธ์ด๋ค.
- Multi-Query Attention: ๊ธฐ์กด์ Transformer ๋ฐฉ์์ $k$ attention head๋ฅผ ์ฌ์ฉํ๊ณ , ๊ฐ๊ฐ์ attention head์์๋ ๊ฐ ์์ ์ ๋ํ ์ ๋ ฅ ๋ฒกํฐ๋ ๋ชจ์ $[k, h]$์ "query", "key" ๋ฐ "value" ํ ์๋ก linear projection๋๋ค. ์ฌ๊ธฐ์ $h$๋ attention head์ ํฌ๊ธฐ์ด๊ณ , $key/value$ projection์ ๊ฐ head์ ๋ํด ๊ณต์ ๋๋ค. ์ด๋ฌํ ๋ฐฉ์์ ๋ชจ๋ธ์ ํ๋ฆฌํฐ์ ํ์ต ์๋์ ์ค๋ฆฝ์ ์ ์งํ๋ ํจ๊ณผ๋ฅผ ๊ฐ์ง๊ณ ์์ง๋ง, autoregressiveํ decoding time์์ ์๋นํ ๋น์ฉ ์ ์ฝ ํจ๊ณผ๋ฅผ ๊ฐ์ง๊ณ ์๋ค. ์ด๋ ์๋ํ๋ฉด, ๊ธฐ์กด์ multi-headed attention์ ๊ฐ์๊ธฐ ํ๋์จ์ด ์์์ autoregressive decoding ์ค์ ๋ฎ์ ํจ์จ์ฑ์ ๊ฐ์ง๊ธฐ ๋๋ฌธ์ด๋ค. ์๋ํ๋ฉด, $key/value$ tensor๊ฐ ์์๋ค ์ฌ์ด์์ ๊ณต์ ๋์ง ์๊ณ , ํ ๋ฒ์ ์ค์ง ํ๋์ ํ ํฐ๋ง์ด ๋์ฝ๋ฉ๋๊ธฐ ๋๋ฌธ์ด๋ค.
- RoPE ์๋ฒ ๋ฉ: PaLM์๋ absolute ๋๋ relative position embedding์ด ์ฌ์ฉ๋์ง ์๊ณ , RoPE embedding์ด ์ฌ์ฉ๋์๋ค. ์๋ํ๋ฉด, RoPE embedding์ด ๊ธด ๋ฌธ์ฅ ๊ธธ์ด์ ๋ํด ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์คฌ๊ธฐ ๋๋ฌธ์ด๋ค.
- Shared Input-Output Embeddings: ์ด์ ์์ ์์ ์์ฃผ$($๋ณดํธ์ ์ผ๋ก๋ ์๋์ง๋ง$)$ ์ํ๋๋ ์ ๋ ฅ ๋ฐ ์ถ๋ ฅ ์๋ฒ ๋ฉ ํ๋ ฌ์ ๊ณต์ ํ์๋ค.
- No Biases: dense kernel์ด๋ layer norm์ ์ด๋ ํ ํธํฅ๋ ์ฌ์ฉํ์ง ์์๋ค. ๊ฑฐ๋ํ ๋ชจ๋ธ์ ๋ํด์๋ ์ด๋ ๊ฒ ํ๋ ๊ฒ์ด ํ์ต ์์ ์ฑ์ ํฅ์์์ผฐ๊ธฐ ๋๋ฌธ์ด๋ค.
- Vocabulary: PaLM์ 256k๊ฐ์ ํ ํฐ์ ๊ฐ์ง๋ SentencePiece vocabulary๋ฅผ ์ฌ์ฉํ์๋ค. ์ฌ๊ธฐ์ 256k๊ฐ์ ํ ํฐ์ ๊ณผ๋ํ ํ ํฐํ๋ฅผ ํ์ง ์๊ณ , ํ์ต ์ฝํผ์ค์ ๋ง์ ์์ ์ธ์ด๋ฅผ ์ง์ํ๊ธฐ ์ํด ์ ํ๋์๋ค. vocabulary๋ ํ์ต ๋ฐ์ดํฐ๋ก๋ถํฐ ์์ฑ๋์๊ณ , ์ด๊ฒ์ด ํ์ต ํจ์จ์ฑ์ ํฅ์์ํจ๋ค๋ ๊ฒ์ ์ ์ ์์๋ค. vocabulary๋ losslessํ๊ณ reversibleํ๋ฐ, ์ด๋ vocabulary์ white space๊ฐ ์๋ฒฝํ๊ฒ ๋ณด์กด๋๋ค๋ ์๋ฏธ์ด๋ค. ๊ทธ๋ฆฌ๊ณ , vocabulary์ ์๋ Unicode character๋ ๊ฐ ๋ฐ์ดํธ์ ๋ํ ์ดํ ํ ํฐ๊ณผ ํจ๊ป UTF-8 byte๋ก ๋ถํ ๋์๋ค. ์ซ์๋ ํญ์ ๊ฐ๊ฐ์ digit token์ผ๋ก ๋ถํ ๋์๋ค. $($123.5 -> 1 2 3 . 5$)$
Model Scale Hyperparameter
๋ชจ๋ธ์ ํฌ๊ธฐ๋ฅผ ์ฌ๋ฌ๊ฐ์ง๋ก ํ์ฌ ์คํ์ ์งํํ์๋๋ฐ, ๊ฐ๊ฐ 540B, 62B, 8B๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง๋ ๋ชจ๋ธ๋ค์ด๋ค. ๊ฐ ๋ชจ๋ธ๋ค์ ๊ตฌ์ฒด์ ์ธ ์ค๋ช ์ ๋ค์์ ํ๋ฅผ ์ฐธ๊ณ ํ์ฌ๋ผ.
๋ง๋ฌด๋ฆฌ
์ด ํฌ์คํธ์์๋ ๋ชจ๋ธ์ ๊ตฌ์กฐ์ ๋ํด ๊ฐ๋ตํ๊ฒ ์ดํด๋ณด์๋ค. ๋์ฑ ์์ธํ ๋ด์ฉ๋ค๊น์ง ๋ค๋ฃจ๊ธฐ์๋ ์์ด ๋๋ฌด ๋ฐฉ๋ํ์ฌ ์ด๋ ๊ฒ๊น์ง๋ง ์์ฑํ ์ ์ํด ๋ถํ๋๋ฆฝ๋๋ค,, ๋์ฑ ์์ธํ ๋ด์ฉ์ ๋ณด๊ณ ์ถ์ผ์๋ฉด ์๋ ์ฒจ๋ถ๋์ด ์๋ ๋ ผ๋ฌธ์ ํ ๋ฒ ์ฝ์ด๋ณด์๊ธธ ๋ฐ๋๋๋ค!! ๊ทธ๋ผ ์ด๋ง ํฌ์คํธ๋ฅผ ๋ง์ณ๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค~
์ถ์ฒ
https://arxiv.org/abs/2204.02311