Paper Review

[Paper Review] Transformer Review (with Pytorch)

jiachoi 2022. 8. 25. 20:19

Transformer

  • Transformer's contribution
    • 이전의 RNN model이 불가능했던 병렬 처리를 가능하게 함
    • Attention 개념을 도입해 특정 시점에 집중
    • Positional Encoding을 사용해 sequential한 위치 정보 보존 
    • masking을 적용해 이전 시점의 값만이 이후에 영향을 미치도록 제한

Transformer 모델 

Transformer의 구조

: input sentence를 넣어서 output sentence를 생성해내는 모델. input과 동일한 sentence, input의 역방향인 sentence, 같은 의미의 다른 언어로된 sentence를 만들 수 있음. 이는 train 과정에서 어떤 것을 label로 두고 학습하느냐에 따라 다름. 

Transformer는 input을 사용해 sentence의 output을 만들어내는 함수
Transformer는 Encoder와 Decoder로 이루어짐

 

Encoder 

  • Encoder는 sentence를 input으로 받아서 하나의 vector를 생성하는 함수 (=Encoding)
  • Encoding으로 생성된 vector를 context라고 부름. 문맥을 함축해 담은 vector
  • Encoder는 이러한 context를 생성해내는 것을 목표로 학습 

Decoder

  • Decoder는 Encoder와 방향이 반대
  • context를 input으로 받아 sentence를 output으로 생성 (=Decoding)
  • Decoder는 sentence, contexxt를 input으로 받아 sentence를 만들어내는 함수 

Encoder & Decoder 

  • Encoder와 Decoder에는 모두 context vector가 들어감
  • Encoder는 Context를 생성, Decoder는 context를 사용함. 이런 흐름으로 Transofrmer가 구성됨 
# Simple Transformer Model 

class Transformer(nn.Module):
    
    def __init__(self, encoder, decoder): 
        super(Transformer, self).__init__()
        self.encoder=encoder
        self.decoder=decoder
        
    def encode(self, x):   # x: encoder의 input sentence 
        out=self.encoder(x)
        return out 
    
    def decode(self, z, c): # c: context vector, z: decoder의 input sentence 
        out=self.encoder(z, c)
        return out 
    
    def forward(self, x, z):
        c=self.encode(x)   # c: encoder의 output context vector 
        y=self.decode(z,c) # y: decoder의 output sentence 
        return y

 

Encoder

Encoder and Encoder Blocks

  • Encoder는 Encoder Block이 N개 쌓여진 형태
  • Encoder Block은 input, output 형태가 동일함 
  • 첫번쨰 Encoder Block의 input은 전체 Encoder의 input으로 들어오는 sentence의 embedding. 첫번째 Block이 해당 embedding을 생성해내면 이를 두번쨰 block이 input으로 사용되는 식으로 연결됨. 가장 마지막 N번째 Block의 output이 context vector가 되는 형식
  • context vector도 input sentence와 동일한 shape을 가짐
class Encoder(nn.Module):
    
    def __init__(self, encoder_block, n_layer): # n_layer=block의 수
        super(Encoder, self).__init__()
        self.layers=[]
        for i in range(n_layer):
            self.layers.append(copy.deepcopy(encoder_block))
    
    def forward(self, x): # 각 block의 output이 다음 block의 input으로 들어가는 형식 
        out=x
        for layer in self.layers: 
            out=layer(out)
        return out

 

Encoder Block 

  • Encoder Block은 Multi-Head Attention Layer, Position-wise Feed-Forward Layer로 구성됨

Encoder Block's configuration

class EncoderBlock(nn.Module):
    
    def __init__(self, self_attention, position_ff):
        super(EncoderBlock, self).__init__()
        self.self_attention=self_attention  # Encoder Block의 구성요소 (multi-head attention and positional feed-forward)
        self.position_ff=position_ff
    
    def forward(self, x):
        out=x
        out=self.self_attention(out)
        out=self.position_ff(out)
        return out

 

Attention 

  • Multi-Head Attention은 Scaled Dot-Product-Attention을 병렬적으로 여러개 수행하는 layer
  • Attention은 넓은 범위의 전체 data에서 특정한 부분에 집중한다는 의미임. Scaled Dot-Product-Attention을 줄여서 Attention으로 부르기도 함
  • Attention: 두 단어 사이의 연관 정도를 계산해내는 방법론
    • 같은 문장 내 두 단어의 Attention을 계산 --> Self-Attention
    • 서로 다른 문장에 각각 존재하는 두 token 사이 Attention 계산 --> Cross-Attention 

Query, Key, Value

  1. Query: 현재 시점의 token
  2. Key: attention 을 구하고자 하는 대상 token 
  3. Value: attention을 구하고자 하는 대상 token (Key와 동일한 token) 
The animal didn't cross the street, because it was too tired. 
  • 위 문장에서 "it"이 어떤 것을 지칭하는지 알고싶을 경우, "it"이 Query가 되며 이는 고정됨   
  • "it"과 "The"사이의 Attention을 알고 싶은 경우, "The"는 Key, Value가 됨.  
  • Query는 고정되어 하나의 token을 가리키고, Query와 가장 부합하는 (Attention이 가장 높은) token을 찾기 위해서 Key, Value가 문장의 처음부터 끝까지 탐색함

 

Scaled Dot-Product Attention의 흐름 

Query의 Attention
Attention 계산 구조

 

Scaled Dot-Product in Pytorch

def calculate_attention(query, key, value, mask):
    # query, key, value: (n_batch, seq_len, d_k)  
    # 실제 모델에 들어오는 input은 한 개의 문장이 아니라 mini-batch이므로, Q, K, V의 Shape에 n_batch가 추가됨 
    # Encoder의 input shape: n_batch×seq_len×d_k 
    
    # mask: (n_batch, seq_len, seq_len)
    d_k = key.shape[-1]
    attention_score = torch.matmul(query, key.transpose(-2, -1)) # Q x K^T, (n_batch, seq_len, seq_len)
    attention_score = attention_score / math.sqrt(d_k) # Scaling 
    if mask is not None:
        attention_score = attention_score.masked_fill(mask==0, -1e9)
    attention_prob = F.softmax(attention_score, dim=-1) # (n_batch, seq_len, seq_len) # Softmax로 Attention Prob로 변환 
    out = torch.matmul(attention_prob, value) # (n_batch, seq_len, d_k)  # Attention Prob * Value = Attention Score 도출 
    return out

 

Multi-Head Attention Layer 

  • Multi-Head Attention: Scaled-Dot Attention을 한 Encoder Layer마다 1회씩 수행하는 것이 아니라 병렬적으로 h회 수행한 뒤, 결과를 종합해서 사용함 --> 다양한 Attention을 잘 반영하기 위함
  • Q, K, V 자체를 n×d_k×d_k가 아닌, n×d_model×d_model로 생성해내서 한 번의 Self-Attention 계산으로 output을 만들어냄

 

Multi-Head Attention in Pytorch

class MultiHeadAttentionLayer(nn.Module):

    def __init__(self, d_model, h, qkv_fc, out_fc):
        super(MultiHeadAttentionLayer, self).__init__()
        self.d_model = d_model
        self.h = h
        self.q_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model) # deepcopy 호출 -> 실제로 다른 weight를 갖고 별개로 진행 
        self.k_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model)
        self.v_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model)
        self.out_fc = out_fc              # (d_model, d_embed) # attention 계산 이후 거쳐가는 FC Layer, d_model X d_embed의 weight matrix를 갖음 
        
        def forward(self, *args, query, key, value, mask=None):
        # query, key, value: (n_batch, seq_len, d_embed) -> 인자로 받은 Q,K,V는 input sentence embedding -> 이를 3개의 FC Layer에 넣어 QKV를 구하는 것 
        # mask: (n_batch, seq_len, seq_len)
        # return value: (n_batch, h, seq_len, d_k)
        n_batch = query.size(0)
        
            # Transform 함수는 Q,K,V를 구하는 함수 
            def transform(x, fc):  # (n_batch, seq_len, d_embed)
                out = fc(x)        # (n_batch, seq_len, d_model)
                out = out.view(n_batch, -1, self.h, self.d_model//self.h) # (n_batch, seq_len, h, d_k)
                out = out.transpose(1, 2) # (n_batch, h, seq_len, d_k)
                return out

            query = transform(query, self.q_fc) # (n_batch, h, seq_len, d_k)
            key = transform(key, self.k_fc)     # (n_batch, h, seq_len, d_k)
            value = transform(value, self.v_fc) # (n_batch, h, seq_len, d_k)

            out = self.calculate_attention(query, key, value, mask) # (n_batch, h, seq_len, d_k)
            out = out.transpose(1, 2) # (n_batch, seq_len, h, d_k)
            out = out.contiguous().view(n_batch, -1, self.d_model) # (n_batch, seq_len, d_model)
            out = self.out_fc(out) # (n_batch, seq_len, d_embed)
            return out

        
        def calculate_attention(self, query, key, value, mask):
            # query, key, value: (n_batch, h, seq_len, d_k)
            # mask: (n_batch, 1, seq_len, seq_len)
            d_k = key.shape[-1]
            attention_score = torch.matmul(query, key.transpose(-2, -1)) # Q x K^T, (n_batch, h, seq_len, seq_len)
            attention_score = attention_score / math.sqrt(d_k)
            if mask is not None:
                attention_score = attention_score.masked_fill(mask==0, -1e9)
            attention_prob = F.softmax(attention_score, dim=-1) # (n_batch, h, seq_len, seq_len)
            out = torch.matmul(attention_prob, value) # (n_batch, h, seq_len, d_k)
            return out