2024년 2월 4일 일요일

트랜스포머 디코더 핵심 코드 구현을 통한 동작 메커니즘 이해하기

이 글은 생성AI의 핵심인 딥러닝 모델 트랜스포머(Transformer) 구현 메커니즘을 코드 수준에서 이해하기 위해, 트랜스포머의 작동 과정을 상세히 설명한다. 이 글은 앞서 설명한 인코더에 이어, 디코더를 구현하는 방법을 나눔한다.
트랜스포머 인코더-디코더 모델 적용된 생성 AI 멀티모달 Stable Diffusion 모델(Generative AI Models in Image Generation: Overview - Synthesis AI)

이 과정을 통해, 텍스트 번역기 등 앞뒤 문맥 관계가 있는 딥러닝 모델을 직접 개발하거나, 이와 유사한 멀티모델 데이터 스트림을 다른 형식으로 변환(트랜스폼)시킬 수 있는 스테이블 디퓨전(Stable Diffusion)같은 생성AI 모델을 개발할 수 있다. 실제, 트랜스포머는 텍스트에서 비전, 음성, 비디오 데이터로 맵핑하는 주요 컴포넌트 중 하나로 사용된다. 

관련 내용이 좀 많아, 글을 인코더 부분(참고)과 디코더 부분크게 두 개로 나누어 진행하도록 한다. 이 글은 트랜스포머 디코더 구현 방법에 대한 글이다.  

이 글에서 표시된 트랜스포머 내부 실행 소스 코드는 다음 링크에서 다운로드할 수 있다. 
딥러닝 및 컴퓨터 비전에 대한 개념은 다음 링크를 참고한다. 
이 글은 많은 레퍼런스들을 참고해, 가능한 트랜스포머 동작방식을 이해하기 쉽도록 정리한 것이다. 관련해 궁금하다면, 이 글 마지막에 있는 레퍼런스들을 살펴보길 바란다. 

디코딩 처리 순서
디코더는 앞서 인코더에서 설명하였듯이, 텍스트 입력 토큰 각각 임베딩 벡터로 처리되어 디코더에 입력된다. 

영어 > 프랑스어 번역의 경우, 프랑어 텍스트 토큰이 입력된다고 생각하면 된다. 토큰 당 임베딩벡터 크기는 구글 논문에서 언급한 512를 사용한다.

리마인드 차원에서 앞서 설명한 디코드 단계를 한번 더 확인해 보자.
  1. Output Embedding: 입력 데이터를 토큰으로 구분하고 임베딩한다.  
  2. Positional Encoding: 단어의 순서를 표현하는 위치를 인코딩해, 임베딩 벡터에 포함해준다.
  3. Masked Multi-Head Attention: 디코딩에서는 다음 단어가 예측되도록, 앞의 단어 임베딩 벡터에 해당하는 계산은 Mask 처리해야 한다. 입력된 인코딩 벡터를 8개의 Multi-Head 벡터로 나누고, 다시 Query, Key, Value 벡터로 나누어, 어텐션을 코사인 벡터 유사도 함수 계산으로 해결한다. 
  4. Add & Normal: 잔차 연결과 정규화를 수행한다. 
  5. Multi-Head Attention: 인코더 출력은 Key, Value 값으로 디코더 Multi-Head Attention에 입력한다. Q는 앞에서 출력을 입력한다. 이를 통해, 인코더 어텐션과 디코딩될 Q가 함께 고려되어, 학습된다. 
  6. Add & Normal: 잔차 연결과 정규화를 수행한다. 
  7. Forward Feedback: 포워드 신경망을 연결해, 가중치를 계산한다. 
  8. Add & Normal: 6번 단계와 동일하다.
  9. Linear: 선형 레이어로 가중치를 계산한다. 
  10. Softmax: 과거 토큰 A에 대한 미래 토큰 B의 확률을 얻기 위해, softmax를 적용한다.
  11. 단어 예측 결과를 출력한다.
이 과정은 다음 그림과 같이, 인코더와 유사하다. 단, 3, 5번 과정에서 인코더에서 출력된 K, V 값을 입력받는 다는 점만 차이가 있다. 

트랜스포머 디코더 학습 예시

멀티헤드 크로스 어텐션 처리
디코더는 인코더의 K, V값을 입력받고, 디코더의 4번에서 인코더 출력된 Q값을 입력받아, 어텐션 계산을 한다. 인코더와 수식은 동일하다. 

코드 구현 및 설명은 다음과 같다('()'안은 차원 shape를 말함).
class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model                       # 차원수
        self.num_heads = num_heads                # 멀티헤더 수. 논문에선 8 사용
        self.head_dim = d_model // num_heads   # 헤더 당 차원수. 512/8=64가 됨.
        self.kv_layer = nn.Linear(d_model , 2 * d_model)  # 가중치 학습용 KV텐서 신경망 준비. (512, 2*512)
        self.q_layer = nn.Linear(d_model , d_model)        # 가중치 학습용 Q신경망 준비. (512,512) 
        self.linear_layer = nn.Linear(d_model, d_model)    # 선형 레이어 준비. (512,512)
    
    def forward(self, x, y, mask):  # x는 인코더의 KV 벡터값, y는 Q벡터값 입력임
        batch_size, sequence_length, d_model = x.size() # 배치크기, 시퀀스크기=200, 512
        kv = self.kv_layer(x)   # 입력값(예. 영어) X에서 KV 레이어 계산
        q = self.q_layer(y)     # 목표값(예. 독일어) Y에서 Q레이어 계산
        # 어텐션 계산을 위해 텐서 행렬 모양 reshape
        kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)  # 계산을 위해 KV텐서를 K와 V로 나눔
        values, attention = scaled_dot_product(q, k, v, mask) # KQ 유사도 계산. V와 어텐션 스코어 계산됨
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, d_model)
        out = self.linear_layer(values)   # 출력된 V를 입력받아, 선형 레이어 계산
        return out

나머지 코드는 인코더와 거의 동일하다. 

디코더 처리
앞서 정의된 멀티헤더 크로스 어텐션 처리만 인코더 KV, 디코더 Q벡터값을 받아 처리하고, 마스크 행렬을 입력하는 하는 것 이외에는 논문 내용과 동일하다. 

여기서, Q = Query (what I'm looking for). 질의어. 디코더의 출력값
K = Key (what I can offer). 토큰 간 관계 유사도 계산을 위해 Query와 비교할 때 사용됨
V = Value (what I actually offer). Query, Key에 대한 최종 출력으로 관계성 계산에 사용
dk = k 벡터의 차원
M = Mask (미래 데이터 토큰만 학습데이터로 고려함)

코드는 다음과 같다. 
class DecoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(DecoderLayer, self).__init__()
        # 멀티헤드, 레이어 정규화, 드롭아웃 정의
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)

        # 인코더 KV와 디코더 Q입력받아 멀티헤드 클로스 어텐션하는 모듈 정의
        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

        # 논문의 FF 레이어 정의
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        _y = y.clone()
        y = self.self_attention(y, mask=self_attention_mask)
        y = self.dropout1(y)
        y = self.layer_norm1(y + _y)

        _y = y.clone()
        y = self.encoder_decoder_attention(x, y, mask=cross_attention_mask) # 인코더 KV 벡터 입력, 디코더 Q 벡터 입력. 마스크 처리. 
        y = self.dropout2(y)
        y = self.layer_norm2(y + _y)

        _y = y.clone()
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.layer_norm3(y + _y)
        return y

다음과 같이 디코더를 실행해본다. 
d_model = 512
num_heads = 8
drop_prob = 0.1
batch_size = 30
max_sequence_length = 200
ffn_hidden = 2048
num_layers = 5

decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)

x = torch.randn((batch_size, max_sequence_length, d_model)) 
y = torch.randn((batch_size, max_sequence_length, d_model)) 
mask = torch.zeros(max_sequence_length, max_sequence_length)

out = decoder(x, y, mask)
     
print(out)

결과는 다음과 같다. 
트랜스포머 디코더 출력 결과 일부

트랜스포머 전체 코드 구현
인코더와 디코더 전체를 포함한 트랜스포머 코드는 다음과 같다. 
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F

def get_device():
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def scaled_dot_product(q, k, v, mask=None):  # 논문의 cosine 유사도 계산 로직
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled = scaled.permute(1, 0, 2, 3) + mask
        scaled = scaled.permute(1, 0, 2, 3)
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

class PositionalEncoding(nn.Module):            # 포지션 인코딩 구현 모듈
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i/self.d_model)
        position = (torch.arange(self.max_sequence_length)
                          .reshape(self.max_sequence_length, 1))
        even_PE = torch.sin(position / denominator)
        odd_PE = torch.cos(position / denominator)
        stacked = torch.stack([even_PE, odd_PE], dim=2)
        PE = torch.flatten(stacked, start_dim=1, end_dim=2)
        return PE

class SentenceEmbedding(nn.Module):
    def __init__(self, max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
        super().__init__()
        self.vocab_size = len(language_to_index)
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.language_to_index = language_to_index
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)
        self.START_TOKEN = START_TOKEN
        self.END_TOKEN = END_TOKEN
        self.PADDING_TOKEN = PADDING_TOKEN
    
    def batch_tokenize(self, batch, start_token, end_token):
        def tokenize(sentence, start_token, end_token):
            sentence_word_indicies = [self.language_to_index[token] for token in list(sentence)]
            if start_token:
                sentence_word_indicies.insert(0, self.language_to_index[self.START_TOKEN])
            if end_token:
                sentence_word_indicies.append(self.language_to_index[self.END_TOKEN])
            for _ in range(len(sentence_word_indicies), self.max_sequence_length):
                sentence_word_indicies.append(self.language_to_index[self.PADDING_TOKEN])
            return torch.tensor(sentence_word_indicies)

        tokenized = []
        for sentence_num in range(len(batch)):
           tokenized.append( tokenize(batch[sentence_num], start_token, end_token) )
        tokenized = torch.stack(tokenized)
        return tokenized.to(get_device())
    
    def forward(self, x, start_token, end_token): # sentence
        x = self.batch_tokenize(x, start_token, end_token)   # 입력 문장 토큰화
        x = self.embedding(x)                                      # 임베딩
        pos = self.position_encoder().to(get_device())         # 포지션 인코딩
        x = self.dropout(x + pos)                                  # 임베딩 + 포지션 임코딩 결과
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model , 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask):        # 논문에서 설명한 것과 동일
        batch_size, sequence_length, d_model = x.size()
        qkv = self.qkv_layer(x)
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask)  # QK 간 cosine 유사도 계산. V와 어텐션 스코어 리턴
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out

class LayerNormalization(nn.Module):   # 레이어 가중치 정규화
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape=parameters_shape
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (inputs - mean) / std
        out = self.gamma * y + self.beta
        return out

  
class PositionwiseFeedForward(nn.Module):   # FF 레이어. 라벨 결과 출력으로 맵핑하는 역할
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, self_attention_mask):
        residual_x = x.clone()
        x = self.attention(x, mask=self_attention_mask)   # 어텐션 계산 후 V리턴. 
        x = self.dropout1(x)                                      # 드롭 아웃
        x = self.norm1(x + residual_x)                         # 잔차 연결 및 정규화
        residual_x = x.clone()
        x = self.ffn(x)                                               # FF 레이어 계산
        x = self.dropout2(x)                                      # 드롭 아웃
        x = self.norm2(x + residual_x)                         # 잔차 연결 및 정규화
        return x
    
class SequentialEncoder(nn.Sequential):
    def forward(self, *inputs):
        x, self_attention_mask  = inputs
        for module in self._modules.values():
            x = module(x, self_attention_mask)
        return x

class Encoder(nn.Module):
    def __init__(self, 
                 d_model, 
                 ffn_hidden, 
                 num_heads, 
                 drop_prob, 
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 START_TOKEN,
                 END_TOKEN, 
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialEncoder(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
                                      for _ in range(num_layers)])  # 입력 레이어 수만큼 레이어 생성

    def forward(self, x, self_attention_mask, start_token, end_token):
        x = self.sentence_embedding(x, start_token, end_token)
        x = self.layers(x, self_attention_mask)
        return x

class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model , 2 * d_model)
        self.q_layer = nn.Linear(d_model , d_model)
        self.linear_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x, y, mask):  # 정의한 멀티헤드 어텐션과 거의 동일. 단, y는 디코더 출력, x는 인코더 출력임
        batch_size, sequence_length, d_model = x.size() # in practice, this is the same for both languages...so we can technically combine with normal attention
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask) # 인코더 출력 KV와 디코더 출력 Q가 함께 고려된 V출력. 어텐션 스코더 계산 출력.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, d_model)
        out = self.linear_layer(values)
        return out

class DecoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        _y = y.clone()
        y = self.self_attention(y, mask=self_attention_mask)  # 멀티헤드 어텐션 계산
        y = self.dropout1(y)
        y = self.layer_norm1(y + _y)

        _y = y.clone()
        y = self.encoder_decoder_attention(x, y, mask=cross_attention_mask)  # 멀티헤드 크로스 어텐션 계산
        y = self.dropout2(y)
        y = self.layer_norm2(y + _y)

        _y = y.clone()
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.layer_norm3(y + _y)
        return y

class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, y, self_attention_mask, cross_attention_mask = inputs
        for module in self._modules.values():
            y = module(x, y, self_attention_mask, cross_attention_mask)
        return y

class Decoder(nn.Module):
    def __init__(self, 
                 d_model, 
                 ffn_hidden, 
                 num_heads, 
                 drop_prob, 
                 num_layers,
                 max_sequence_length,
                 language_to_index,
                 START_TOKEN,
                 END_TOKEN, 
                 PADDING_TOKEN):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.layers = SequentialDecoder(*[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])

    def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
        y = self.sentence_embedding(y, start_token, end_token)
        y = self.layers(x, y, self_attention_mask, cross_attention_mask)
        return y

class Transformer(nn.Module):
    def __init__(self, 
                d_model, 
                ffn_hidden, 
                num_heads, 
                drop_prob, 
                num_layers,
                max_sequence_length, 
                kn_vocab_size,
                english_to_index,
                kannada_to_index,
                START_TOKEN, 
                END_TOKEN, 
                PADDING_TOKEN
                ):
        super().__init__()
        self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, english_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_sequence_length, kannada_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
        self.linear = nn.Linear(d_model, kn_vocab_size)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def forward(self, 
                x, 
                y, 
                encoder_self_attention_mask=None, 
                decoder_self_attention_mask=None, 
                decoder_cross_attention_mask=None,
                enc_start_token=False,
                enc_end_token=False,
                dec_start_token=False, # We should make this true
                dec_end_token=False): # x, y are batch of sentences
        x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, end_token=enc_end_token)
        out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, start_token=dec_start_token, end_token=dec_end_token) # 인코더는 소스 언어(영어), 디코더는 목표언어(독일어)를 학습하고, 인코더에서 계산된 값이 KV벡터로 디코더에 입력됨. 디코더 어텐션 출력은 앞의 KV값과 함께 어텐션 계산되어, 어텐션 스코어가 출력됨. 이 과정을 반복학습함.
        out = self.linear(out)
        return out

트랜스포머 학습하기
이제 트랜스포머에 인코더와 디코더가 구현되었으므로, 다음과 같은 순서로 학습을 진행하면 된다. 학습은 소스(입력) 언어(영어)의 문장과 이에 일치되어 번역되어야 할 목표 언어(독일어)가 배치 데이터로 동일하게 각각 인코더, 디코더에 입력된다. 이후, 어텐션 수식에 따라 인코더 학습 후, 디코더가 인코더 출력을 받아 학습한다. 
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
  print(f"Epoch {epoch}")
  iterator = iter(train_loader)
  for batch_num, batch in enumerate(iterator):  # 배치마다 30개 소스 언어, 목표 언어 문장이 있음. 이를 학습하게 됨.
    transformer.train()
    eng_batch, target_batch = batch
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, target_batch)
    optim.zero_grad()
    target_predictions = transformer(eng_batch,
               target_batch,
               encoder_self_attention_mask.to(device), 
               decoder_self_attention_mask.to(device), 
               decoder_cross_attention_mask.to(device),
               enc_start_token=False,
               enc_end_token=False,
               dec_start_token=True,
               dec_end_token=True)  # 학습 시작. 인코더는 소스 언어(영어), 디코더는 목표언어(독일어)를 학습하고, 인코더에서 계산된 값이 KV벡터로 디코더에 입력됨. 디코더 어텐션 출력은 앞의 KV값과 함께 어텐션 계산되어, 어텐션 스코어가 출력됨. 이 과정을 반복학습함. 
    labels = transformer.decoder.sentence_embedding.batch_tokenize(target_batch, start_token=False, end_token=True)
    loss = criterian(
      target_predictions.view(-1, target_vocab_size).to(device),
      labels.view(-1).to(device)
    ).to(device)
    valid_indicies = torch.where(labels.view(-1) == kannada_to_index[PADDING_TOKEN], False, True)
    loss = loss.sum() / valid_indicies.sum()
    loss.backward()
    optim.step()
    #train_losses.append(loss.item())
    if batch_num % 100 == 0:
      print(f"Iteration {batch_num} : {loss.item()}")
      print(f"English: {eng_batch[0]}")
      target_sentence_predicted = torch.argmax(target_predictions[0], axis=1)
      predicted_sentence = ""
      for idx in target_sentence_predicted:
        if idx == kannada_to_index[END_TOKEN]:
          break
        predicted_sentence += index_to_kannada[idx.item()]
      print(f"Target Prediction: {predicted_sentence}")

      transformer.eval()
      target_sentence = ("",)
      for word_counter in range(max_sequence_length):
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, target_sentence)
        predictions = transformer(eng_sentence,
                      target_sentence,
                      encoder_self_attention_mask.to(device), 
                      decoder_self_attention_mask.to(device), 
                      decoder_cross_attention_mask.to(device),
                      enc_start_token=False,
                      enc_end_token=False,
                      dec_start_token=True,
                      dec_end_token=False)
        next_token_prob_distribution = predictions[0][word_counter] # not actual probs
        next_token_index = torch.argmax(next_token_prob_distribution).item()
        next_token = index_to_kannada[next_token_index]
        target_sentence = (target_sentence[0] + next_token, )
        if next_token == END_TOKEN:
          break

결과적으로 소스 언어(영어)의 텍스트에서 토큰 간 상호관계를 유사도를 통해 계산하여 학습된 결과는 목표 언어(독일어)에서 동일한 과정을 거쳐 출력된 Q값과 디코더 KV를 유사도 계산하여, 어텐션하는 과정을 반복 학습한다. 결과적으로, 문맥에서 초점이 되는 단어는 목표 언어에서도 동일하게 초점이 어텐션되므로, 이를 합해 유사도를 계산해, Query에 대한 Value를 얻는 어텐션 스코어를 얻을 수 있다.

트랜스포머 전체 로직 호출 순서 분석
학습 부분의 트랜스포머 호출 순서를 다음과 같이 분석해 본다.
트랜스포머 처리 분석 화면 일부

결론적으로, 다음 순서와 같이 논문에서 설계한 방식과 동일하다. 학습 부분은 다른 딥러닝 학습 루프와 크게 다르지 않다. 다만, transformer 호출 부분만 차이가 있다. 여기서, eng_batch는 소스(입력) 언어 문장들 배치 데이터셋, target_batch는 목표 언어 문장들 배치 데이터셋이다. 
  1. Epoch 만큼 학습
    1. 배치 데이터셋 학습
      1. target_prediction = transformer(eng_batch, target_batch, encoder_mask, decoder_mask, decoder_cross_attention_mask)
        1. x = encoder(eng_batch, encoder_mask)
          1. x = sentence_embedding(eng_batch)
          2. x = laysers(x, encoder_mask)
          3. return x
        2. out = decoder(x, target_batch, decoder_mask, decoder_cross_attention_mask)
          1. y = sentence_embedding(target_batch)
          2. y = layers(x, y, decoder_mask, decoder_cross_attention_mask)
          3. return y
        3. out = linear(out)
      2. labels = transformer.batch_tokenize(target_batch)
      3. loss = CrossEntropyLoss(target_prediction, labels)
      4. loss.backward()
      5. optim.step()
인코더의 출력 KV와 디코더 Q를 이용해 어텐션 스코어를 계산하는 부분은 앞의 코드에서 적색 부분에 해당한다. 이 디코더의 멀티헤드 크로스 어텐션 부분만 좀 더 자세히 보자. 우선, 여기에서 사용되는 인코더의 출력 KV값이 계산되는 부분을 살펴보면 다음과 같다.  
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        ...
    
    def forward(self, x, mask):
        batch_size, sequence_length, d_model = x.size() # (1,200,512)
        qkv = self.qkv_layer(x)  # (1,200,1536)
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim) # (1,200,8,192)
        qkv = qkv.permute(0, 2, 1, 3)   # (1,8,200,192)
        q, k, v = qkv.chunk(3, dim=-1) # (1,8,200,64), (1,8,200,64), (1,8,200,64)
        values, attention = scaled_dot_product(q, k, v, mask)  # (1,8,200,64), (1,8,200,200)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, self.num_heads * self.head_dim) # (1,200,512)
        out = self.linear_layer(values)   # (1,200,512)
        return out

class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, self_attention_mask):
        residual_x = x.clone()
        x = self.attention(x, mask=self_attention_mask)  # (1,200,512)
        x = self.dropout1(x)                                     # (1,200,512)
        x = self.norm1(x + residual_x)                        # (1,200,512)
        residual_x = x.clone()                                   # ...
        x = self.ffn(x) 
        x = self.dropout2(x)
        x = self.norm2(x + residual_x)                        # (1,200,512)
        return x      # (1,200,512)

class Encoder(nn.Module):
    def __init__(self, ...
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, language_to_index)
        self.layers = SequentialEncoder(*[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
                                      for _ in range(num_layers)])

    def forward(self, x, self_attention_mask, start_token, end_token):
        x = self.sentence_embedding(x, start_token, end_token)  # (1,200,512)
        x = self.layers(x, self_attention_mask)                          # (1,200,512)
        return x

단순히, 계산과정을 보았을 때는 소스 텍스트 입력 받아, 토큰으로 분리하고, 임베딩한 후, 이를 멀티헤드 어텐션 계산해, linear 레이어를 통과해준 것에 불과하다. 이 값을 디코더에서 받아, 멀티헤드 크로스 어텐션을 계산한다. 텐서 계산 과정 이해를 위해, 입력에 대해 계산 중간 텐서 모양이 어떻게 변화하는 지 확인해 보자.

class Transformer(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, ...)
        self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, ...)
        self.linear = nn.Linear(d_model, kn_vocab_size)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def forward(self, x, y, ...):
        x = self.encoder(x, encoder_self_attention_mask, start_token=enc_start_token, ...)  # (1,200,512)
        out = self.decoder(x, y, decoder_self_attention_mask, decoder_cross_attention_mask, ...)   # 
        out = self.linear(out)
        return out

class Decoder(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, ...)
        self.layers = SequentialDecoder(*[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)])

    def forward(self, x, y, self_attention_mask, cross_attention_mask, start_token, end_token):
        y = self.sentence_embedding(y, start_token, end_token)          # y=(1,200,512) from y input(token string)
        y = self.layers(x, y, self_attention_mask, cross_attention_mask)  # 
        return y

class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.kv_layer = nn.Linear(d_model , 2 * d_model)  # (512) => (2 * 512)
        self.q_layer = nn.Linear(d_model , d_model)        # (512) => (512)
        ...
    
    def forward(self, x, y, mask):
        batch_size, sequence_length, d_model = x.size() # in practice, this is the same for both languages...so we can technically combine with normal attention
        kv = self.kv_layer(x)  # x=(1, 200, 512), kv=(1,200,1024)
        q = self.q_layer(y) # y=(1,200,512), kv=(1, 200, 1024)
        kv = kv.reshape(batch_size, sequence_length, self.num_heads, 2 * self.head_dim) # (1, 200, 8, 128)
        q = q.reshape(batch_size, sequence_length, self.num_heads, self.head_dim) # (1, 200, 8, 64)
        kv = kv.permute(0, 2, 1, 3) # (1, 8, 200, 128)
        q = q.permute(0, 2, 1, 3)   # (1, 8, 200, 64)
        k, v = kv.chunk(2, dim=-1) # k=(1, 8, 200, 64), v=(1, 8, 200, 64)
        values, attention = scaled_dot_product(q, k, v, mask) # q=(1,8,200,64), k=(1,8,200,64), v=(1,8,200,64), values=(1,8,200,64), attention=(1,8,200,200)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, d_model) # (1, 200, 512)
        out = self.linear_layer(values)  # (1,200,512)
        return out

트랜스포머에서 계산된 라벨값은 다음 수식으로 loss를 계산한 후 역전파하여, 학습 에포크가 진행될 수도록 loss가 줄어들도록 각 신경망층의 가중치를 조정해 나간다.

   loss = CrossEntropyLoss(target_prediction, labels)
   loss.backward()

결론적으로 디코더의 어텐션 스코어는 입력 텍스트와 출력 텍스트 간의 label 오차가 최소가 되도록 계산되게 된다. 

마무리
이 글에서 트랜스포머 디코더 구현 과정을 살펴보았다. 지금까지 깊게 트랜스포머의 동작 원리를 코드 수준에서 구현하고 확인해 보았다. 허깅페이스 서비스 등에서는 관련된 코드와 예제를 모두 구현해 제공하고 있어, 간단한 호출로 트랜스포머를 사용할 수 있다.
이 글은 많은 레퍼런스들을 참고해, 가능한 트랜스포머 동작방식을 이해하기 쉽도록 정리한 것이다. 관련해 궁금하다면, 이 글 마지막에 있는 레퍼런스들을 살펴보길 바란다. 

레퍼런스







댓글 없음:

댓글 쓰기