The overview of this paper
Transformer๋ ๋งค์ฐ ๊ฐ๋ ฅํ sequence model์ด์ง๋ง, sequence์ ๊ธธ์ด์ ๋ฐ๋ผ์ ์๊ฐ๊ณผ ๋ฉ๋ชจ๋ฆฌ๊ฐ ๊ณฑ์ ๋ก ํ์ํ๋ค๋ ๋จ์ ์ด ์๋ค. ์ด ๋ ผ๋ฌธ์์๋ attention ํ๋ ฌ์ sparse factorization์ ์๊ฐํ์๋๋ฐ, ์ด๋ Transformer์ ์๊ฐ ๋ณต์ก๋๋ฅผ $O(n \sqrt{n})$์ผ๋ก ์ค์๋ค. ๋ํ ๋ ผ๋ฌธ์์๋ ๋ค์์ ๋ด์ฉ๋ค์ ์๊ฐํ์๋ค.
- ๋์ฑ ๊น์ ๋คํธ์ํฌ๋ฅผ ํ์ต์ํค๊ธฐ ์ํด ๋ชจ๋ธ์ ๊ตฌ์กฐ์ ์ด๊ธฐํ์ ๋ณ๋์ ์ฃผ์์.
- attention ํ๋ ฌ์ ์ฌ๊ณ์ฐ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๋.
- ํ์ต์ ์ํด fast attention์ ์ฌ์ฉํจ.
์ด๋ฌํ ๋ณํ๋ฅผ ์ค ๋ชจ๋ธ์ Sparse Transformer๋ผ๊ณ ๋ถ๋ฅด๊ธฐ๋ก ํ๋ค. ์ด ๋ชจ๋ธ์ ์๋ฐฑ๊ฐ์ ๋ ์ด์ด๋ฅผ ์ฌ์ฉํด์ ์๋ง ๊ฐ์ ์ํ์ค๋ฅผ ๋ชจ๋ธ๋งํ ์ ์์์ ๋ณด์ฌ์คฌ๋ค. ๊ทธ๋ฆฌ๊ณ ๋๊ฐ์ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง, ์ค๋์ค, ํ ์คํธ๋ฅผ ๋ชจ๋ธ๋งํ์๋ค. ๋ ผ๋ฌธ์์๋ ๊ธ๋ก๋ฒ ์ผ๊ด์ฑ๊ณผ ๋ค์์ฑ์ ์ ์ฆํ๋ unconditional sample์ ์์ฑํ๊ณ ๊ธธ์ด๊ฐ 100๋ง ์ด์์ธ ๋ชจ๋ธ ์ํ์ค์ self-attention์ ์ฌ์ฉํ๋ ๊ฒ์ด ๊ฐ๋ฅํจ์ ๋ณด์ฌ์คฌ๋ค.
Table of Contents
1. Introduction
2. Background
3. Factorized Self-Attention
3-1. Qualitative assessment of learned attention patterns
3-2. Factorized self-attention
3-3. Two-dimensional factorized attention
4. Sparse Transformer
4-1. Factorized attention heads
4-2. Scaling to hundreds of layers
4-3. Modeling diverse data types
4-4. Saving memory by recomputing attention weights
4-5. Mixed-precision training
5. Training
6. Experiments
1. Introduction
๋ณต์กํ๊ณ , ๊ณ ์ฐจ์์ ๋ฐ์ดํฐ ๋ถํฌ๋ฅผ ์ธก์ ํ๋ ๊ฒ์ unsupervised learning์์ ์ค์ฌ ๋ฌธ์ ์ด๋ค. ์ถ๊ฐ์ ์ผ๋ก ์ด๊ฒ์ unsupervised representation learning์ ์ค์ํ ์์๋ผ๊ณ ์ฌ๊ฒจ์ง๋ค. neural autoregressive ๋ฐฉ์์ ๊ฒฐํฉ ํ๋ฅ ๋ถํฌ๋ฅผ ์กฐ๊ฑด๋ถ ๋ถํฌ์ ๊ณฑ์ผ๋ก ๋ถํดํ๋ค. ์ด๋ฌํ ์กฐ๊ฑด๋ถ ๋ถํฌ๋ฅผ ๋ชจ๋ธ๋งํ๋ ๊ฒ์ ๋งค์ฐ ์ด๋ ต์ง๋ง, ๊ทธ๋งํผ ์ด ๋ถํฌ๋ ๋ณต์กํ๊ณ , long-range dependenciesํ๊ณ , ํ์ตํ๊ธฐ ์ํด ์ ์ ํ expressive model์ ํ์๋ก ํ๋ค.
๋ณ๋๋ก, Transformer๋ ์ฌ๋ฌ NLP task์์ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๊ณ ์๊ณ , ์ด๋ ๋ถ๋ถ์ ์ผ๋ก ์ผ์ ํ ์์ ๋ ์ด์ด์์ ์์์ ์ข ์์ฑ์ ๋ชจ๋ธ๋งํ๋ ๊ธฐ๋ฅ ๋๋ฌธ์ผ ์ ์๋ค. ๊ฐ๊ฐ์ self-attention ๋ ์ด์ด๊ฐ global receptive ํ๋๋ฅผ ๊ฐ์ง๊ณ ์๋ ๊ฒ์ฒ๋ผ, ๋คํธ์ํฌ๋ representational ์์ฉ๋ ฅ์ ์ ๋ ฅ ์์ญ์ ๊ฐ์ฅ ์ ์ฉํ ๋ถ๋ถ์ ํ ๋นํ ์ ์๋ค. ๋ฐ๋ผ์ architecture๋ ๊ณ ์ ๋ ์ฐ๊ฒฐ ํจํด์ ๊ฐ์ง๊ณ ์๋ ๋คํธ์ํฌ๋ณด๋ค ๋์ฑ ์ ์ฐํ๊ฒ ๋ค์ํ ์ ํ์ ๋ฐ์ดํฐ๋ฅผ ์์ฑํ ์ ์์๋ ๊ฑธ ์๋ ์๋ค.
ํ์ง๋ง, Transformer๊ฐ ํ์๋ก ํ๋ ๋ฉ๋ชจ๋ฆฌ์ ๊ณ์ฐ๋์ sequence์ ๊ธธ์ด์ ๋ฐ๋ผ์ ๊ณฑ์ ๋ก ๋์ด๋๋ค. ๋ฐ๋ผ์ Transformer๋ ๊ธด sequence๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๋ฐฐ์ ํ๋ค.
์ด ๋ ผ๋ฌธ์ ์ฃผ๋ contribution์ ์ฑ๋ฅ์ ์ ํ ์์ด sequence ๊ธธ์ด์ ๋ฐ๋ผ ์๊ฐ ๋ณต์ก๋ $O(n \sqrt[p]{n})$์ ๊ฐ์ง๋ attnetion ํ๋ ฌ์ ์ฌ๋ฌ sparse factorization์ ์๊ฐํ๋ค๋ ๊ฒ์ด๋ค. ์ด๋ฌํ ์์ ๋ค์ full attention ๊ณ์ฐ์ ์ฌ๋ฌ ๊ฐ์ faster attention์ผ๋ก ๋ถ๋ฆฌํ์ฌ ์๋ํ๋ฉฐ, ๊ฒฐํฉ ๋ ์์ dense attention ์ฐ์ฐ์ ๊ทผ์ ํ ์ ์๋ค. ๋ ผ๋ฌธ์์๋ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ๊ธธ์ด๋ฅผ ๋ชจ๋ฅด๋ sequence์ self-attention์ ์ ์ฉํ์๋ค.
์ถ๊ฐ์ ์ผ๋ก ๋ ผ๋ฌธ์์๋ Transformer์ ๋ค์์ ์ฌ๋ฌ ๋ณํ๋ค์ ๊ฐํ์๋ค.
- residual block & ๊ฐ์ค์น ์ด๊ธฐํ ์ฌ๊ตฌ์กฐํ
- sparse attention kernel์ attention ํ๋ ฌ์ ์๋ธ์ ์ ํจ์จ์ ์ผ๋ก ๊ณ์ฐํจ
- backward pass์ผ ๋, attention ๊ฐ์ค์น๋ฅผ ์ฌ๊ณ์ฐํด์ ๋ฉ๋ชจ๋ฆฌ์ ์ฌ์ฉ๋์ ์ค์์.
๋ ผ๋ฌธ์์๋ ๋ชจ๋ธ์ด ์ด๋ฐ ๋ฐฉ์์ผ๋ก ์ฆ๊ฐ๋ ๋ชจ๋ธ์ด ์์ฐ์ด, ์ค๋์ค, ์ด๋ฏธ์ง์ ์์ฑ๊ณผ ์์ถ์์ SOTA๋ฅผ ๋ฌ์ฑํ ์ ์์์ ์คํ์ ์ผ๋ก ๊ฒ์ฆํ์๋ค. architecture์ ๋จ์์ฑ์ผ๋ก ์ธํด ๋ง์ task์ ๋ํด์ ์ ์ฉํ ๊ฒ์ด๋ผ๊ณ ์๊ฐํ๊ฒ ๋ง๋ ๋ค.
2. Background
๋ ผ๋ฌธ์์๋ autoregressive sequence ์์ฑ์ task๋ฅผ ๊ณ ๋ คํ์๋ค. ์ฌ๊ธฐ์ sequence $x = {x_1, x_2, \cdots, x_n}$์ ๊ฒฐํฉ ์กฐ๊ฑด๋ถ ํ๋ฅ ์ ์กฐ๊ฑด๋ถ ํ๋ฅ ๋ถํฌ์ ๊ณฑ์ผ๋ก ๋ชจ๋ธ๋ง ๋์๊ณ , ๋คํธ์ํฌ $\theta$์ ์ํด์ ํ๋ผ๋ฏธํฐํ ๋์๋ค.
์ด๋ฏธ์ง, ํ ์คํธ, ์ค๋์ค๋ฅผ ๋ณ๊ฐ์ ํ ํฐ์ฒ๋ผ ๋ค๋ฃจ์ด์ผ ํ๋ค. ๋คํธ์ํฌ $\theta$๋ token์ sequence๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์์ next token์ผ๋ก ๊ฐ๋ฅํ $v$๊ฐ๋ค์ ๋ฒ์ฃผ ๋ถํฌ๋ฅผ ์ถ๋ ฅํ๋ค. ์ฌ๊ธฐ์ $v$๋ vocabulary์ ํฌ๊ธฐ์ด๋ค. ํ๋ จ ๋ชฉํ๋ ๋ฐ์ดํฐ์ ๋ก๊ทธ ํ๋ฅ ์ $\theta$์ ๊ดํ์ฌ ๊ทน๋ํ์ํค๋ ๊ฒ์ด๋ค.
๋ชจ๋ธ $\theta$์ ๋ํ ๊ฐ๋จํ๊ณ ๊ฐ๋ ฅํ ์ ํ์ decoder-only Transformer์ด๋ค. ์ด๋ฌํ ๋ชจ๋ธ์ ์ ์ฒด ์ํ์ค์ ๋ํ multihead self-attention ๋ธ๋ก์ ์ฌ์ฉํ์ฌ ์ ๋ ฅ ์ํ์ค๋ฅผ ๋ณํํ ๋ค์์ ๊ฐ ์ํ์ค ์์์ ๋ํ ์กฐ๋ฐํ ๋ณํ์ ์ํํ๋ค. ๋คํธ์ํฌ์ self-attention ๋ถ๋ถ์ ๊ฐ $n$ ์์์ ๋ํด $n$ ๊ฐ์ค์น๋ฅผ ๊ณ์ฐํด์ผ ํ์ง๋ง, ์ํ์ค์ ๊ธธ์ด๊ฐ ๋์ด๋จ์ ๋ฐ๋ผ ๋น ๋ฅด๊ฒ ๋ค๋ฃจ๊ธฐ ์ด๋ ค์์ง๋ค.
๋ค์ ์น์ ์์๋ Transformer architecture์ ๊ธด sequence์ ์ ์ ํ๊ฒ ๋ชจ๋ธ๋งํ ์ ์๋๋ก ๊ฐํด์ง ๋ณํ์ ๋ํด ์ค๋ช ํ๊ฒ ๋ค.
3. Factorized Self-Attention
Sparse Transformer๋ full self-attention ์ฐ์ฐ์ ๋ค์ ๊ทธ๋ฆผ 1์ b์ c์ฒ๋ผ ์ฌ๋ฌ ๋จ๊ณ์ attention์ผ๋ก ๋๋์๋ค.
3-1. Qualitative assessment of learned attention patterns
๋ค์์ ๊ทธ๋ฆผ 2์ 128-layer self-attention ๋คํธ์ํฌ๋ก ํ์ต๋ attention ํจํด์ CIFAR-10์ ํํํ์๋ค. ์๊ฐ์ ๊ฒ์ฌ๋ ๋๋ถ๋ถ์ ๋ ์ด์ด๊ฐ ๋๋ถ๋ถ์ ๋ฐ์ดํฐ ํฌ์ธํธ์์ sparse attention ํจํด์ ๊ฐ์ง๋ ๊ฒ์ ๋ณด์ฌ์คฌ๋ค. ์ด๋ ์ฑ๋ฅ์ ํฐ ์ํฅ์ ๋ฏธ์น์ง ์๊ณ ์ด๋ค ํํ์ sparsity๊ฐ ๋์ ๋ ์ ์์์ ๋ณด์ฌ์ค๋ค. ๊ทธ๋ฆผ 2์ c์์ ๋ช ๊ฐ์ ๋ ์ด์ด๋ ๋ช ํํ๊ฒ global ํจํด์ ๋ณด์ฌ์ฃผ์๊ณ , ๊ทธ๋ฆผ 2์ d๋ ๋ฐ์ดํฐ ์์กด์ ์ธ sparsity๋ฅผ ๋ณด์ฌ์คฌ๋ค. ๊ทธ๋ฆฌ๊ณ ์ด ๋์ ๋ชจ๋ attention ํ๋ ฌ์ ๋ฏธ๋ฆฌ ๊ฒฐ์ ๋ sparsity ํจํด์ ๋์ ํจ์ผ๋ก์จ ์ํฅ์ ๋ฐ๋๋ค.
3-2. Factorized self-attention
self-attention ๋ ์ด์ด๋ ์ ๋ ฅ ์๋ฒ ๋ฉ $X$์ ํ๋ ฌ์ ์ถ๋ ฅ ํ๋ ฌ๋ก ๋งคํํ๊ณ , ์ฐ๊ฒฐ ํจํด $S = {S_1, \cdots, S_n}$์ ์ํด ํ๋ผ๋ฏธํฐํ ๋๋ค. ์ฌ๊ธฐ์ $S_i$๋ $i$๋ฒ์งธ ์ถ๋ ฅ ๋ฒกํฐ๊ฐ ์ฐธ์กฐํ๋ ์ ๋ ฅ ๋ฒกํฐ์ ์ธ๋ฑ์ค ์งํฉ์ ๋ํ๋ธ๋ค. ์ถ๋ ฅ ๋ฒกํฐ๋ ๋ณํ ์ ๋ ฅ ๋ฒกํฐ์ ๊ฐ์คํฉ์ด๋ค.
์ฌ๊ธฐ์ $W_q, W_k, W_v$๋ ๊ฐ๊ฐ $\mathbf{x}_i$๊ฐ ์ฃผ์ด์ง๋ฉด query, key, value๋ก ๋ณํํ๋ ๊ฐ์ค์น ํ๋ ฌ์ ๋ํ๋ธ๋ค. ๊ทธ๋ฆฌ๊ณ ์ฌ๊ธฐ์ $d$๋ query์ key์ ๋ด๋ถ ์ฐจ์์ด๋ค. ๊ฐ ํฌ์ง์ ์์์ ์ถ๋ ฅ์ key์ query์ scaled dot-product์ ์ํด ๊ฐ์ค์น๊ฐ ๋ถ์ฌ๋ ๊ฐ์ ํฉ๊ณ์ด๋ค.
autoregressive model์ ์ํ full self-attetnion์ $S_i = {j: j \leq i}$์ ์ ์ํด์ ๋ชจ๋ ์์๋ค์ด ์ง์ธ์ ๋ชจ๋ ํฌ์ง์ ๊ณผ ์๊ธฐ ์์ ์ ํฌ์ง์ ์ ์ฐธ์กฐํ ์ ์๊ฒ ํด์ค๋ค.
factorized self-attention์ ๋์ ์ separate attention head $p$๋ฅผ ๊ฐ์ง๋ค. ์ฌ๊ธฐ์ $m$ ๋ฒ์งธ head๋ ์ธ๋ฑ์ค $A_{i}^{(m)} \subset {j : j \leq i}$์ ์๋ธ์ ์ ์ ์ํ๊ณ $S_i = A_{i}^{(m)}$์ ํ์ฉํด์ค๋ค. ๋ ผ๋ฌธ์์๋ ์ฃผ๋ก ์๋ธ์ $A$์ ๋ํ ํจ์จ์ ์ธ ์ ํ์ ๊ด์ฌ์ด ์๋ค. ์ฌ๊ธฐ์ $|A_{i}^{(m)}| \propto \sqrt[p]{n}$์ด๋ค.
๋ํ ๋น๋ถ๊ฐ ๋ชจ๋ ์ ๋ ฅ ์์น๊ฐ attention์ $p$๋จ๊ณ์ ๊ฑธ์ณ ๋ชจ๋ ๋ฏธ๋ ์ถ๋ ฅ ์์น์ ์ฐ๊ฒฐ๋๋ $A$์ ์ ํจํ ์ ํ์ ๊ณ ๋ คํ๋ค.
๋ชจ๋ $j \leq i$ ์์ ๋ํ์ฌ, ๋ ผ๋ฌธ์์๋ ์ต๋ ๊ธธ์ด๊ฐ $p + 1$์ธ ์์น์ ๊ฒฝ๋ก๋ฅผ ํตํด $i$๊ฐ $j$๋ฅผ ์ฐธ์กฐํ ์ ์๋๋ก ๋ชจ๋ $A$๋ฅผ ์ค์ ํ์๋ค. ํนํ $(j, a, b, c, \cdots, i)$๊ฐ ์ธ๋ฑ์ค๋ค์ ๊ฒฝ๋ก์ด๋ฉด $j \in A_{a}^{(1)}, a \in A_{b}^{(2)}, b \in A_{c}^{(3)}$์ด๋ค.
์ด๋ฌํ ๋ ๊ฐ์ ๊ธฐ์ค์ Transformer๊ฐ ์ผ์ ํ ์์ ๋จ๊ณ์์ ์์์ ์ ๋ ฅ ํฌ์ง์ ์ผ๋ก๋ถํฐ ์์์ ์ถ๋ ฅ ํฌ์ง์ ์ผ๋ก ์ ํธ๋ฅผ ์ ํํ ์ ์๋๋ก ํ๋ ๋ฅ๋ ฅ์ ์ ์งํ ์ ์๊ฒ ํด์ฃผ๋ฉด์ ์ด ๊ณ์ฐ๋์ $O(n \sqrt[p]{n})$์ผ๋ก ์ค์ฌ์คฌ๋ค. ๋ ผ๋ฌธ์์๋ ํ๋น์ฑ ๊ธฐ์ค์ ๋๊ทธ๋ฌ๋จ๋ฆฌ๋ ๊ฒ์ด ํน์ ๋๋ฉ์ธ์ ๋ํด ์ ์ฉํ inductive bias๊ฐ ๋ ์๋ ์๋ค๊ณ ๋งํ์๋ค.
3-3. Two-dimensional factorized attention
factorized attention ํจํด์ ๋ ๊ฐ์ ์ฐจ์์ผ๋ก ์ ์ํ๋ ์์ฐ์ค๋ฌ์ด ๋ฐฉ์์ ์ด์ $l$ location์ ์ฐธ์กฐํ๋ ํ๋์ head๋ฅผ ๊ฐ์ ธ์ผ ํ๊ณ , ๋ค๋ฅธ head๋ ๋ชจ๋ $l$๋ฒ์งธ location์ ์ฐธ์กฐํด์ผ ํ๋ค. ์ฌ๊ธฐ์ $l$์ stride์ด๊ณ $\sqrt{n}$์ ๊ฐ๊น๊ฒ ์ ํ๋์ด์ผ ํ๋ค. ์ด๋ฌํ ๋ฐฉ์์ strided attention์ด๋ผ๊ณ ๋ถ๋ฅธ๋ค.
๊ณต์์ ์ผ๋ก ๋ํ๋ด๋ฉด $t = max(0, i-l)$์ ๋ํด $A_{i}^{(1)} = {t, t+1, \cdots, i}$์ด๊ณ , $A_{i}^{(2)} = {j : (i - j) mod l = 0}$์ด๋ค. ์ด๋ฌํ ํจํด์ ์ ๊ทธ๋ฆผ 3์ b์ ๋ํ๋์๋ค.
์ด ๊ณต์์ ๋ฐ์ดํฐ๊ฐ ์ด๋ฏธ์ง ๋๋ ์์ ๊ฐ์ ์ ํ์ ์์ฐ์ ์ผ๋ก stride์ ์๋ง์ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์๋ค๋ฉด ํธ๋ฆฌํ๋ค. ํ ์คํธ ๊ฐ์ด ์ฃผ๊ธฐ์ ๊ตฌ์กฐ๊ฐ ์๋ ๋ฐ์ดํฐ๋ ๋คํธ์ํฌ๊ฐ strided ํจํด์ ์ฌ์ฉํ์ฌ ์ ์ ํ๊ฒ ์ ๋ณด๋ฅผ ๋ณด๋ผ ์ ์๋ค. ๋ง์น ์์์ ๋ํ ๊ณต๊ฐ ์ขํ๋ ์์๊ฐ ๋ฏธ๋์ ๊ฐ์ฅ ๊ด๋ จ๋ ์ ์๋ ํฌ์ง์ ๊ณผ ๋ฐ๋์ ์๊ด๊ด๊ณ๊ฐ ์๋ ๊ฒ์ ์๋ ๊ฒ์ฒ๋ผ ๋ง์ด๋ค.
์ด๋ฌํ ๊ฒฝ์ฐ์๋, fixed attention ํจํด$($๊ทธ๋ฆผ 3์ c$)$์ ์ฌ์ฉํ๋ค. ์ฌ๊ธฐ์ ํน์ ์ ์ ์ด์ ์ location์ ์์ฝํ๊ณ ์ด ์ ๋ณด๋ฅผ ๋ชจ๋ ๋ฏธ๋ ์ ๋ก ์ ํํ๋ค.
๊ณต์์ ์ผ๋ก ๋ํ๋ด๋ฉด $A_{i}^{(1)} = {j : (\left \lfloor j/l \right \rfloor = \left \lfloor i/l \right \rfloor)}$์ธ๋ฐ, ์ฌ๊ธฐ์ ๊ดํธ๋ ๋ฐ๋ฅ ํจ์๋ฅผ ๋ํ๋ธ๋ค. ๊ทธ๋ฆฌ๊ณ $A_{i}^{(2)} = {j : j mod l \in {t, t+1, \cdots, l}}$์ธ๋ฐ, ์ฌ๊ธฐ์ $t = l - c$์ $c$๋ ํ์ดํผ ํ๋ผ๋ฏธํฐ์ด๋ค.
๋ง์ฝ stride๊ฐ 128์ด๊ณ $c = 8$์ด๋ฉด, ๋ชจ๋ ๋ฏธ๋ ํฌ์ง์ ์ 128๋ณด๋ค ํฌ๊ณ ํฌ์ง์ 120-128์ ์ฐธ์กฐํ ์ ์๋ค. ๋ง์ฝ ๋ชจ๋ ํฌ์ง์ ์ด 256๋ณด๋ค ํฌ๋ฉด 248-256์ ์ฐธ์กฐํ ์ ์๋ค.
$c=1$์ ์ฌ์ฉํ๋ fixed-attention ํจํด์ ๋คํธ์ํฌ์ ๋ง์ representation๋ค์ ์ค์ง ํ๋์ ๋ธ๋ก์๋ง ์ฌ์ฉ๋๋ ๋ฐ๋ฉด ์์์ ์์น๋ ๋ชจ๋ ๋ธ๋ก์์ ์ฌ์ฉ๋๋ ๊ฒ์ฒ๋ผ ๋คํธ์ํฌ์ ํํ์ฑ์ ์๋นํ ์ ํํ๋ค. ๊ทธ๋์ ๋ ผ๋ฌธ์์๋ ๋ณดํต $l \in {128, 256}$์ ๊ฐ์ ๊ฐ์ง ๋ $c \in {8, 16, 32}$๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ ๋นํ๋ค๊ณ ๋งํ๋ค. ๋น๋ก ์ด ๋ฐฉ์์ strided attention์ ๋นํด $c$์ ์ํด ๊ณ์ฐ ๋น์ฉ์ด ์ฆ๊ฐํ์ง๋ง ๋ง์ด๋ค.
๋ํ ์ฌ๋ฌ head๋ฅผ ์ฌ์ฉํ ๋ ํฌ๊ธฐ $l$์ ๋ธ๋ก ๋ด์์ ๊ธธ์ด๊ฐ $c$์ธ ๋ณ๊ฐ์ ํ์ ๋ธ๋ก์ ์ฐธ์กฐํ๋๋ก ํ๋ ๊ฒ์ด ๋์ผํ ํ์ ๋ธ๋ก์ ์ฐธ์กฐํ๋ ๊ฒ๋ณด๋ค ๋ ๋ซ๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ๋ค.
4. Sparse Transformer
4-1. Factorized attention heads
์ผ๋ฐ์ ์ธ dense attention์ ์์ attend ํจ์์ ์ ํ ๋ณํ์ ๊ฐ๋จํ๊ฒ ์ํํ๋ค.
์ฌ๊ธฐ์ $W_p$๋ post-attention ๊ฐ์ค์น ํ๋ ฌ์ ๋ํ๋ธ๋ค. factorized self-attention์ ํตํฉํ๋ ๊ฐ์ฅ ๊ฐ๋จํ ๊ธฐ์ ์ ๊ฐ residual block์ ํ๋์ attention type์ ์ฌ์ฉํ๊ณ , ์ด๋ค์ ์์ฐจ์ ํน์ ํ์ดํผํ๋ผ๋ฏธํฐ์ ์ํด ๊ฒฐ์ ๋๋ ๋น์จ๋ก ๋ผ์๋ฃ๋ ๊ฒ์ด๋ค.
์ฌ๊ธฐ์ $r$์ ํ์ฌ residual block์ ์ธ๋ฑ์ค์ด๊ณ $p$๋ factorized attention head์ ์์ด๋ค.
๋ ๋ฒ์งธ ๋ฐฉ์์ ํ๋์ head๊ฐ factorized๋ ๋ head๊ฐ ์ฐธ์กฐํ ํฝ์ ์ ์์น์ ์ฐธ์กฐํ๋๋ก ํ๋ ๊ฒ์ด๋ค. ์ด๋ฅผ merged head๋ผ๊ณ ํ๋ค.
์ด๊ฒ์ ์ข ๋ ๊ณ์ฐ ์ง์ฝ์ ์ด์ง๋ง, ์ผ์ ํ factor์ ์ํด์๋ง ๊ฐ๋ฅํ๋ค. ์ธ ๋ฒ์งธ ๋ฐฉ์์ multi-head attention์ ์ฌ์ฉํ๋ ๊ฒ์ด๋ค. ์ฌ๊ธฐ์ $n_h$ attention ๊ณฑ์ ๋ณ๋ ฌ๋ก ๊ณ์ฐ๋ ๋ค์ feature dimension๊ณผ ํจ๊ป ํฉ์ณ์ง๋ค.
์ฌ๊ธฐ์ $A$๋ ๋ถ๋ฆฌ๋ attention ํจํด, merged ํจํด ํน์ attend ํจ์์ ๊ฐ์ด ๋ผ์์ง ํจํด์ผ ์๋ ์๋ค. ๋ํ attend ํจ์์ ์์ ์๋ ๊ฐ์ค์น ํ๋ ฌ์ ์ฐจ์์ $1/n_h$์ ๊ณ์๋ก ์ค์ด๋ค์ด ํ๋ผ๋ฏธํฐ์ ์๊ฐ $n_h$์ ๊ฐ์ ๋ฐ๋ผ ๋ณํ์ง ์๋๋ค. ๋ ผ๋ฌธ์์๋ ์ผ๋ฐ์ ์ผ๋ก ์ฌ๋ฌ head๊ฐ ์ ์๋ํ๋ ๊ฒ์ ๋ฐ๊ฒฌํ์์ง๋ง, ๋งค์ฐ ๊ธด ์ํ์ค์ ๊ฒฝ์ฐ ํ ๋ฒ์ ํ๋์ฉ ์์ฐจ์ ์ผ๋ก ์ํํ๋ ๊ฒ์ด ๋ ๊ฐ์น๊ฐ ์๋ค.
4-2. Scaling to hundreds of layers
Transformer๊ฐ ์ฌ๋ฌ ๋ ์ด์ด๋ฅผ ์ฌ์ฉํ์ฌ ํ์ตํ๋ ๊ฒ์ ์ด๋ ค์์ด ์๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์๋ค. auxillary loss๋ฅผ ํฌํจํ๋ ๋์ ์, ๋ค์์ ๊ตฌ์กฐ์ ๋ณํ๋ฅผ ์ ์ฉํ์๋ค.
์ฒซ ๋ฒ์งธ๋ก, ๋ ผ๋ฌธ์์๋ pre-activation residual block์ ์ฌ์ฉํ์๋ค. ๊ทธ๋์ $N$ ๋ ์ด์ด์ ๋คํธ์ํฌ๋ ๋ค์๊ณผ ๊ฐ์ ๋ฐฉ๋ฒ์ผ๋ก ์ ์๋์๋ค.
์ฌ๊ธฐ์ embed๋ ๋ค์ ์น์ ์์ ์ค๋ช ํ ํจ์์ด๊ณ , $W_{out}$์ ๊ฐ์ค์น ํ๋ ฌ, $resblock(h)$๋ attention ๋ธ๋ก์ ์ ๋ ฅ๊ณผ position wise feedforward ๋คํธ์ํฌ๋ฅผ ๋ค์๊ณผ ๊ฐ์ ๋ฐฉ์์ผ๋ก ์ ๊ทํํ๋ค.
norm ํจ์๋ layer Normalization์ ๋ํ๋ด๊ณ , $ff(x) = W_{2}f(W_{1}x + b_{1}) + b_{2}$์ด๋ค. $f$๋ก๋ Gaussian Error Linear Unit$($GELU$)$๊ฐ ์ฌ์ฉ๋์๋๋ฐ, $f(X) = X \odot sigmoid(1.702 \cdot X)$์ ํ์์ผ๋ก ์ฌ์ฉ๋๋ค. ๋ฐ๋ก ์ ์๋์ด ์์ง ์์ผ๋ฉด $W_{1}$์ ์ถ๋ ฅ ์ฐจ์์ ์ ๋ ฅ ์ฐจ์์ 4.0๋ฐฐ์ด๋ค.
$H_{N}$์ด ํจ์ $a$์ $b$์ $N$ ์์ฉ์ ํฉ์ด๊ณ , ๋ฐ๋ผ์ ๊ฐ๊ฐ์ ํจ์ ๋ธ๋ก์ ์ถ๋ ฅ ๋ ์ด์ด๋ก๋ถํฐ ์ง์ ๊ธฐ์ธ๊ธฐ๋ฅผ ๋ฐ๊ฒ ๋๋ค. ๋ ผ๋ฌธ์์๋ $W_2$์ $W_{p}$์ ์ด๊ธฐํ๋ฅผ $\frac {1}{\sqrt{2N}}$์ ์ํด scaleํ๋ฉด์ input embedding scale ๋ residual block scale์ ๋น์จ์ $N$ ๊ฐ์ ๊ฑธ์ณ ๋ถ๋ณ์ผ๋ก ์ ์งํ๋ค.
4-3. Modeling diverse data types
input symbol์ ์๋ฒ ๋ฉ์ ๋ํ์ฌ positional encoding์ ๋ณดํต Transformer์ ๋ค๋ฅธ location-agnostic architecture์ ๋ฐ์ดํฐ์ ๊ณต๊ฐ์ ๊ด๊ณ๋ฅผ ์ธ์ฝ๋ํ๊ธฐ ์ํด ์ฌ์ฉ๋๋ค. ๋ฐ์ดํฐ์ ๊ตฌ์กฐ ๋๋ factorized attention ํจํด์ ์ธ์ฝ๋ํ๋ ํ์ต๋ ์๋ฒ ๋ฉ์ ์ฌ์ฉํ๋ ๊ฒ์ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋งค์ฐ ์ค์ํ๋ค.
๋ ผ๋ฌธ์์๋ $n_{emb} = d_{data}$ ๋๋ $n_{emb} = d_{attn}$ ์๋ฒ ๋ฉ์ ๊ฐ input location์ ์ถ๊ฐํ๋ค. ์ฌ๊ธฐ์ $d_{data}$๋ ๋ฐ์ดํฐ์ ์ฐจ์์ ์์ด๊ณ , $d_{attn}$์ factorized attention์ ์ฐจ์์ ์์ด๋ค. ๋ง์ฝ $\mathbf{x}_{i}$๊ฐ ์ํ์ค์์ $i$๋ฒ์งธ ์์์ one-hot ์ธ์ฝ๋ฉ๋ ๊ฐ์ด๊ณ , $\mathbf{o}_{i}^{(j)}$๋ $j$๋ฒ์งธ ์ฐจ์์์ $\mathbf{x}_{i}$์ one-hot ์ธ์ฝ๋ฉ๋ ํฌ์ง์ ์ ๋ํ๋ธ๋ค.
์ด๋ฏธ์ง์ ๋ํด์๋ ์ด๊ณผ ํ, ์ฑ๋์ ํํํ๊ธฐ ์ํด $d_{data} = 3$์ ์ฌ์ฉํ๊ณ , ํ ์คํธ์ ์ค๋์ค์ ๋ํด์๋ 2์ฐจ์ attention embedding์ ์ฌ์ฉํด์ $d_{attn} = 2$๋ฅผ ์ฌ์ฉํ๋ค.
4-4. Saving memory by recomputing attention weights
๊ธฐ์ธ๊ธฐ๋ฅผ ์ฒดํฌํฌ์ธํธ ํด๋๋ ๊ฒ์ ์ฌ์ธต ์ ๊ฒฝ๋ง์ ํ์ตํ๊ธฐ ์ํด ์๊ตฌ๋๋ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ค์ด๋ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ผ๋ก ๋ณด์ฌ์ง๋ค. ์ด ๊ธฐ์ ์ ๋ฑํ ํจ๊ณผ๊ฐ ์์ง๋ง, ๊ธด ์ํ์ค๊ฐ ์ฒ๋ฆฌ๋ ๋ self-attention ๋ ์ด์ด์ ๋ํด ํนํ ํจ๊ณผ์ ์ด๋ค. ์ด๋ฌํ ๋ ์ด์ด์ ์ปดํจํ ๋น์ฉ์ ๋นํด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๋๊ธฐ ๋๋ฌธ์ด๋ค.
์ฌ๊ณ์ฐ๋ง์ ํผ์ ์ฌ์ฉํ๋ฉด 16,384์ ์ํ์ค ๊ธธ์ด์ ์๋ฐฑ๊ฐ์ ๋ ์ด์ด๋ฅผ ์ฌ์ฉํ๋ dense attention ๋คํธ์ํฌ๋ฅผ ํ์ต์ํฌ ์ ์๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด ์ต์ ํ๋์จ์ด์์๋ ์คํ์ด ๋ถ๊ฐ๋ฅํ๋ค.
์คํ์์๋ backward pass ๋์ค์ attention๊ณผ feedforward block์ ์ฌ๊ณ์ฐํ์๋ค. ์ํ์ ๊ฐ๋จํ๊ฒ ํ๊ธฐ ์ํด attention block ๋ด์์ dropout์ ์ ์ฉํ์ง ์์๊ณ , ๋์ ์ ๊ฐ residual ์ถ๊ฐ์ ๋ง์ง๋ง์ ์ ์ฉํ์๋ค. ์ด๋ ๊ทธ๋ฆผ 3์ ๋ํ๋์๋ค.
4-5. Mixed-precision training
๋ ผ๋ฌธ์์๋ ๋คํธ์ํฌ์ ๊ฐ์ค์น๋ฅผ single-precision floating-point์ ์ ์ฅํ์๋๋ฐ, ๊ทธ๋ ์ง ์์ ๊ฒฝ์ฐ ๋คํธ์ํฌ activation ๋ฐ ๊ธฐ์ธ๊ธฐ๋ฅผ half-precision์ผ๋ก ๊ณ์ฐํ๋ค. ์ด๊ฒ์ ํ์ต์ ๊ฐ์ํ์ํจ๋ค.
5. Training
๋ชจ๋ ์๋ฒ ๋ฉ์ ์ผ์ ํ ์ฐจ์ $d$์ด๊ณ , ๋ณดํต ${256, 512, 1024}$์ค ํ๋์ด๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก ๋ชจ๋ ์ ํ ๋ณํ์ ์ ๋ ฅ์ 4d๋ก ํฌ์ํ๋ feedforward ๋คํธ์ํฌ๋ฅผ ์ ์ธํ๊ณ ๋ ๋์ผํ ์ฐจ์์ด๋ค. ๋ํ, ๊ฐ๋ query์ key ๋ณํ์ ํฌ๊ธฐ๋ ๋ฐ์ผ๋ก ์ค์ธ๋ค.
6. Experiments
๋ ผ๋ฌธ์์๋ architecture์ ์ด๋ฏธ์ง, ์์ฐ์ด, ์ค๋์ค๋ฅผ ํฌํจํ๋ density modeling task์ ๋ํด ํ๊ฐํ์๋ค. ์ด์ ๋ํ ๊ฒฐ๊ณผ์ ์์ฝ์ ๋ค์์ ํ 1์ ๋ํ๋์๋ค.
์ถ์ฒ
https://arxiv.org/abs/1904.10509
https://sh-tsang.medium.com/review-sparse-transformer-80cbba4ebaa4
'Paper Reading ๐ > Natural Language Processing' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
GPT-4 Techinal Report Review (0) | 2023.03.28 |
---|---|
BigBird: Transformers for Longer Sequences ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ (0) | 2023.03.25 |
GPT-3: Language Models are Few-Shot Learners ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ (0) | 2023.03.21 |
TinyBERT: Distilling BERT for Natural Language Understanding ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ (0) | 2023.03.12 |
Pre-LN Transformer: On Layer Normalization in the Transformer Architecture ๋ ผ๋ฌธ ๋ฆฌ๋ทฐ (2) | 2023.03.09 |