지난 번에 공부했던 GQA(Grouped Query Attention)를 더 알아보기 위해 GQA가 적용된 LLaMA2 모델의 코드를 살펴보았다.
▲ GQA paper 정리
https://github.com/meta-llama/llama
▲ LLaMA2 Code
아래는 llama/llama/model.py 안의 코드 중 GQA가 활용되는 일부를 정리한 것이다.
1. ModelArgs 클래스
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
...
먼저 ModelArgs 클래스에서 모델의 하이퍼파라미터를 설정한다. dim은 모델의 차원 수, n_layers는 레이어 수, n_heads는 어텐션 헤드 수, n_kv_heads는 key와 value의 헤드 수이다. n_kv_heads가 Optional로 지정되어 있고 값이 들어가있지 않다. 기본적으로는 n_heads를 사용하지만 GQA를 사용하는 경우에는 n_kv_heads 값을 다르게 설정한다.
2. Attention 클래스
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
...
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
...
self.n_rep = self.n_local_heads // self.n_local_kv_heads
...
Attention 클래스에서 Multi head attention을 구현한다. 일반적인 MHA일 경우 Query, Key, Value는 동일한 수의 헤드를 사용한다. 따라서 n_kv_heads에 할당된 값이 없다면 n_kv_heads는 n_heads가 된다.
n_kv_heads가 None이 아닐 경우 GQA로 구현된다.
n_rep에서는 Key와 Value의 헤드를 반복하는 횟수를 정의한다. 쿼리 헤드 수를 키-밸류 헤드 수로 나눈 값이다.
3. repeat_kv 클래스
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
repeat_kv 함수는 Key 헤드와 Value 헤드를 여러 번 사용할 수 있도록 반복한다.
▶ line 별 코드 설명
bs, slen, n_kv_heads, head_dim = x.shape
입력 텐서 x의 모양을 설정한다.
- bs: 배치 크기
- slen: 시퀀스 길이
- n_kv_heads: 키와 밸류의 어텐션 헤드 수
- head_dim: 각 어텐션 헤드의 차원 수
ex) (bs, slen, 8, 64)
n_kv_heads이고 8이고, head_dim은 64이다.
x[:, :, :, None, :]
입력 텐서 x의 4번째 차원(3번째 인덱스)에 차원을 추가한다. n_rep이 들어갈 공간이다. n_rep은 쿼리 헤드 수를 키-밸류 헤드 수로 나눈 값이다. 위에서 보았듯 __init__메서드에서 self.n_rep = self.n_local_heads // self.n_local_kv_heads 로 정의된다.
ex) (bs, slen, n_kv_heads, n_rep, 1, head_dim)
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
expand 메소드를 사용해 n_kv_heads의 차원을 n_rep만큼 반복해서 n_heads와 동일한 수의 헤드를 만든다. expand 메소드는 특정 차원을 반복해서 텐서의 크기를 확장하기 위해 사용된다. 실제로 메모리를 복사해서 같은 텐서를 4개 할당하는 것은 아니고, 기존 데이터를 여러 번 참조하도록 한다. 각 Query 헤드가 동일한 Key, Value를 참조할 수 있도록 하는 것이다.
ex) (bs, slen, n_kv_heads, n_rep, head_dim)에서 n_kv_heads가 8, n_heads가 32이면 n_rep가 4이므로 Key, Value 헤드는 4번씩 반복된다.
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
reshape 메소드를 사용해 텐서의 세 번째, 네 번째 차원을 합쳐 하나의 차원으로 만든다. n_kv_heads * n_rep이 곱해져 이는 n_heads와 같은 값이 된다. 그래야 이제 쿼리 텐서의 차원과 맞아 연산이 가능해진다.
4. transformerBlock 클래스
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
...
self.attention = Attention(args)
...
TransformerBlock 클래스는 Attention 레이어와 FeedForward 레이어로 구성된 하나의 transformer 블록이다. 여기서 어텐션 매커니즘이 활용된다.
큰 그림이 잘 그려지지 않아서 예시를 통해 살펴보았다.
"The cat sat on the mat" 이라는 문장이 입력으로 주어졌다고 하자. 쿼리의 어텐션 헤드 수(n_heads)는 4, 키-밸류의 어텐션 헤드 수(n_kv_heads)는 2이다. n_rep은 4/2인 2가 된다.
# 임의의 쿼리, 키, 밸류 텐서를 생성
bsz, seqlen, head_dim = 1, 6, 3 # 배치 크기 1, 시퀀스 길이 6 ("The cat sat on the mat"), 헤드 차원 3
n_heads, n_kv_heads = 4, 2
# (1, 6, 4, 3): 쿼리 텐서
xq = torch.rand(bsz, seqlen, n_heads, head_dim)
# (1, 6, 2, 3): 키-밸류 텐서 (n_kv_heads < n_heads)
xk = torch.rand(bsz, seqlen, n_kv_heads, head_dim)
xv = torch.rand(bsz, seqlen, n_kv_heads, head_dim)
각 단어에 대해서 쿼리, 키, 밸류를 생성한다. 각 단어마다 3차원의 4개의 쿼리 헤드를 가지고, 2개의 키와 2개의 밸류를 가진다.
def repeat_kv(x, n_rep):
# 차원 확장: (1, 6, 2, 1, 3)
x = x[:, :, :, None, :]
# n_rep = 2만큼 복제: (1, 6, 2, 2, 3)
x = x.expand(bsz, seqlen, n_kv_heads, n_rep, head_dim)
# 차원 결합: (1, 6, 4, 3)
x = x.reshape(bsz, seqlen, n_kv_heads * n_rep, head_dim)
return x
# 키와 밸류를 쿼리 헤드 수에 맞게 반복
xk = repeat_kv(xk, n_rep=2)
xv = repeat_kv(xv, n_rep=2)
키-밸류 텐서의 헤드 수가 쿼리의 헤드 수와 같지 않기 때문에 repeat_kv 함수로 키-밸류 헤드를 반복한다. 반복 후에는 쿼리의 각 헤드가 하나의 키-밸류 헤드와 대응할 수 있게 된다.
# 쿼리와 키 간 어텐션 점수 계산
scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(head_dim) # (1, 6, 4, 4)
# 소프트맥스로 어텐션 확률 구하기
attention_weights = torch.softmax(scores, dim=-1)
# 각 쿼리 헤드가 대응하는 키-밸류 헤드에 접근하여 최종 값을 구함
output = torch.matmul(attention_weights, xv) # (1, 6, 4, 3)
쿼리 텐서와 반복된 키-밸류 텐서 간의 어텐션 연산이 가능해진다. 이렇게 하여 쿼리의 각 헤드가 키-밸류의 반복된 헤드와 매칭될 수 있다.
'엣지컴퓨팅' 카테고리의 다른 글
Llama3 한국어 요약 task 실습 (Colab) (6) | 2024.11.08 |
---|---|
[논문 읽기] LLaMA: Open and Efficient Foundation Language Models (4) | 2024.10.14 |
[논문읽기] GQA: Training Generalized Multi-Query Transformer Models fromMulti-Head Checkpoints (0) | 2024.10.07 |
[논문 읽기] Gemma: Open Models Based on GeminiResearch and Technology (4) | 2024.09.25 |
[MIT 6.5940] EfficientML.ai Lec03: Pruning and Sparsity (1) | 2024.09.18 |