엣지컴퓨팅

LLaMA2의 GQA 코드 살펴보기

ima9ine4 2024. 10. 30. 00:26
728x90

지난 번에 공부했던 GQA(Grouped Query Attention)를 더 알아보기 위해 GQA가 적용된 LLaMA2 모델의 코드를 살펴보았다. 

2024.10.07 - [엣지컴퓨팅] - [논문읽기] GQA: Training Generalized Multi-Query Transformer Models fromMulti-Head Checkpoints

 

[논문읽기] GQA: Training Generalized Multi-Query Transformer Models fromMulti-Head Checkpoints

Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). Gqa: Training generalized multi-query transformer models from multi-head checkpoints. arXiv preprint arXiv:2305.13245. (291회 인용)https://arxiv.org/abs/2305.1

ima9ine.tistory.com

▲ GQA paper 정리


https://github.com/meta-llama/llama

 

GitHub - meta-llama/llama: Inference code for Llama models

Inference code for Llama models. Contribute to meta-llama/llama development by creating an account on GitHub.

github.com

▲ LLaMA2 Code


아래는 llama/llama/model.py 안의 코드 중 GQA가 활용되는 일부를 정리한 것이다.

1. ModelArgs 클래스

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    ...

먼저 ModelArgs 클래스에서 모델의 하이퍼파라미터를 설정한다. dim은 모델의 차원 수, n_layers는 레이어 수, n_heads는 어텐션 헤드 수, n_kv_heads는 key와 value의 헤드 수이다. n_kv_heads가 Optional로 지정되어 있고 값이 들어가있지 않다. 기본적으로는 n_heads를 사용하지만 GQA를 사용하는 경우에는 n_kv_heads 값을 다르게 설정한다.


2. Attention 클래스

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        ...
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        ...
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        ...

Attention 클래스에서 Multi head attention을 구현한다. 일반적인 MHA일 경우 Query, Key, Value는 동일한 수의 헤드를 사용한다. 따라서 n_kv_heads에 할당된 값이 없다면 n_kv_heads는 n_heads가 된다. 

n_kv_heads가 None이 아닐 경우 GQA로 구현된다. 

n_rep에서는 Key와 Value의 헤드를 반복하는 횟수를 정의한다. 쿼리 헤드 수를 키-밸류 헤드 수로 나눈 값이다.


3.  repeat_kv 클래스

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

repeat_kv 함수는 Key 헤드와 Value 헤드를 여러 번 사용할 수 있도록 반복한다. 


▶ line 별 코드 설명

bs, slen, n_kv_heads, head_dim = x.shape

 

입력 텐서 x의 모양을 설정한다. 

  • bs: 배치 크기
  • slen: 시퀀스 길이
  • n_kv_heads: 키와 밸류의 어텐션 헤드 수
  • head_dim: 각 어텐션 헤드의 차원 수

ex) (bs, slen, 8, 64)
n_kv_heads이고 8이고, head_dim은 64이다. 

 

 

x[:, :, :, None, :]

입력 텐서 x의 4번째 차원(3번째 인덱스)에 차원을 추가한다. n_rep이 들어갈 공간이다. n_rep은 쿼리 헤드 수를 키-밸류 헤드 수로 나눈 값이다. 위에서 보았듯 __init__메서드에서 self.n_rep = self.n_local_heads // self.n_local_kv_heads 로 정의된다.

ex) (bs, slen, n_kv_heads, n_rep, 1, head_dim)

 

.expand(bs, slen, n_kv_heads, n_rep, head_dim)

expand 메소드를 사용해 n_kv_heads의 차원을 n_rep만큼 반복해서 n_heads와 동일한 수의 헤드를 만든다. expand 메소드는 특정 차원을 반복해서 텐서의 크기를 확장하기 위해 사용된다. 실제로 메모리를 복사해서 같은 텐서를 4개 할당하는 것은 아니고, 기존 데이터를 여러 번 참조하도록 한다. 각 Query 헤드가 동일한 Key, Value를 참조할 수 있도록 하는 것이다. 

ex) (bs, slen, n_kv_heads, n_rep, head_dim)에서 n_kv_heads가 8, n_heads가 32이면 n_rep가 4이므로 Key, Value 헤드는 4번씩 반복된다.

 

.reshape(bs, slen, n_kv_heads * n_rep, head_dim)

reshape 메소드를 사용해 텐서의 세 번째, 네 번째 차원을 합쳐 하나의 차원으로 만든다. n_kv_heads * n_rep이 곱해져 이는 n_heads와 같은 값이 된다. 그래야 이제 쿼리 텐서의 차원과 맞아 연산이 가능해진다.

 


4.  transformerBlock 클래스

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        ...
        self.attention = Attention(args)
        ...

TransformerBlock 클래스는 Attention 레이어와 FeedForward 레이어로 구성된 하나의 transformer 블록이다. 여기서 어텐션 매커니즘이 활용된다. 

 


큰 그림이 잘 그려지지 않아서 예시를 통해 살펴보았다.

"The cat sat on the mat" 이라는 문장이 입력으로 주어졌다고 하자. 쿼리의 어텐션 헤드 수(n_heads)는 4, 키-밸류의 어텐션 헤드 수(n_kv_heads)는 2이다. n_rep은 4/2인 2가 된다. 

# 임의의 쿼리, 키, 밸류 텐서를 생성
bsz, seqlen, head_dim = 1, 6, 3  # 배치 크기 1, 시퀀스 길이 6 ("The cat sat on the mat"), 헤드 차원 3
n_heads, n_kv_heads = 4, 2

# (1, 6, 4, 3): 쿼리 텐서
xq = torch.rand(bsz, seqlen, n_heads, head_dim)
# (1, 6, 2, 3): 키-밸류 텐서 (n_kv_heads < n_heads)
xk = torch.rand(bsz, seqlen, n_kv_heads, head_dim)
xv = torch.rand(bsz, seqlen, n_kv_heads, head_dim)

각 단어에 대해서 쿼리, 키, 밸류를 생성한다. 각 단어마다 3차원의 4개의 쿼리 헤드를 가지고, 2개의 키와 2개의 밸류를 가진다. 

 

def repeat_kv(x, n_rep):
    # 차원 확장: (1, 6, 2, 1, 3)
    x = x[:, :, :, None, :]
    # n_rep = 2만큼 복제: (1, 6, 2, 2, 3)
    x = x.expand(bsz, seqlen, n_kv_heads, n_rep, head_dim)
    # 차원 결합: (1, 6, 4, 3)
    x = x.reshape(bsz, seqlen, n_kv_heads * n_rep, head_dim)
    return x

# 키와 밸류를 쿼리 헤드 수에 맞게 반복
xk = repeat_kv(xk, n_rep=2)
xv = repeat_kv(xv, n_rep=2)

키-밸류 텐서의 헤드 수가 쿼리의 헤드 수와 같지 않기 때문에 repeat_kv 함수로 키-밸류 헤드를 반복한다. 반복 후에는 쿼리의 각 헤드가 하나의 키-밸류 헤드와 대응할 수 있게 된다. 

 

# 쿼리와 키 간 어텐션 점수 계산
scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(head_dim)  # (1, 6, 4, 4)

# 소프트맥스로 어텐션 확률 구하기
attention_weights = torch.softmax(scores, dim=-1)

# 각 쿼리 헤드가 대응하는 키-밸류 헤드에 접근하여 최종 값을 구함
output = torch.matmul(attention_weights, xv)  # (1, 6, 4, 3)

쿼리 텐서와 반복된 키-밸류 텐서 간의 어텐션 연산이 가능해진다. 이렇게 하여 쿼리의 각 헤드가 키-밸류의 반복된 헤드와 매칭될 수 있다. 

728x90
반응형