Как стать автором
Обновить

Пишем свой Transformer

Время на прочтение12 мин
Количество просмотров7.1K

Захотелось более детально разобраться и попробовать самостоятельно написать Transformer на PyTorch, а результатом поделиться здесь. Надеюсь, так же как и мне, это поможет ответить на какие-то вопросы в данной архитектуре.

Оставляю ссылку на свой канал: not_magic_neural_networks

0 Intro

Впервые архитектуру трансформер предложили использовать в 2017 году в статье Google "Attention is all you need" (именно для задачи машинного перевода). Вскоре, Google начал впервые в истории использовать нейросети для перевода (в google translate).

Архитектура оказалась настолько эффективна, что ее адаптировали и под другие задачи, а весь NLP начал так активно развиваться именно с 2017 года.

Итак, transformer - архитектура, основа которой механизм внимания attention. Он так же, как и RNN, состоит из encoder и decoder части.

Основная статья: "Attention is all you need"
Пост с разбором: Transformer
Еще можно почитать тут: 6.3. Трансформеры
Или посмотреть: Лекция. Архитектура Transformer. Decoder, QKV Attention

Далее, я буду полагать, что читатели немного знакомы с transformer и буду рассуждать больше в контексте реализации.

Transformer
Transformer

1 Multi-head Attention

An Explanation of Self-Attention mechanism in Transformer
What is Multi-Head Attention (MHA)
PyTorch: MultiheadAttention

Начнем с главной части трансформера - Multi-head Attention.

Multi-head Attention присутствует в трех местах:

  1. Self Multi-head Attention в энкодере

  2. Masked Multi-head Attention в декодере

  3. Multi-head Attention, соединяющий энкодер и декодер части (так же в декодере)

Multi-head Attention
Multi-head Attention

1.1. QKV - attention

Attention (или механизм внимания) несет информацию о связях слов друг с другом.

QKV-attention - один из способов вычисления связей:

Attention(Q, K, V) = softmax\Bigg(\frac{QK^T}{\sqrt{d_k}}\Bigg) \cdot V

Понятия Q(query), K(key) и V(value) пришли из поиска:
Q - запрос, K - краткая сводка по документу и V - сам документ. Задача состоит в том чтобы по Q и K понять насколько документ релевантен (какой V выбирать).

В attention, Q, K и V были придуманы для того чтобы отобразить вход x в три линейный пространства со следующими смыслами: Q отобразит слово в пространство "откуда"; K отобразит в пространство "куда"; V отобразит в пространство в "что важно".

Пусть на вход нам приходит тензор x.

Получаем матрицы запросов(Q), ключей(K) и значений(V), пропуская эмбеддинги x через линейные слои W_q, W_k и W_v.
q = x * W_q
k = x * W_k
v = x * W_v

Интуиция такая: query-вектор спрашивает у key-вектора, есть ли у него какая-то полезная информация чтобы обновить информацию о себе. То есть, когда мы делаем перемножение q ✕ k.T, то мы фильтруем нужную информацию (выбираем наиболее релевантную) из k для q.

Считаем релевантность какrelevance = q @ k.T / math.sqrt(head_size), где math.sqrt(head_size) - нормировочная константа.

Получаем вероятности:relevance = softmax(relevance)

И считаем выход:head = relevance @ v

1.2. Multi-head Attention

Одну операцию attention(Q, K, V) принято называть "головой". Именно поэтому переменная выше была названа head. Multi-head Attention подразумевает что таких голов будет несколько и их количество будет определяться гиперпараметром num_head.

Подразумевается, что разные головы attention могли бы обращать внимание на разные типы связей между словами. Это и есть multi-head attention или attention с несколькими головами.

MultiHead(Q, K, V) = Concat(head_1, ..., head_H) \cdot W^O

head_i = Attention(Q_i, K_i, V_i)

Multi-head attention - конкатенация результатов нескольких отдельных attention операций (head_i), отображенная в заданную размерность с помощью еще одного линейного слоя.

1.3. Encoder Multi-Head Attention (или Self)

Self-attention работает только с вектором энкодера, то есть энкодер учится обращать внимание сам на себя (потому self). Его идея заключается в следующем: self-attention обновляет embedding-и каждого токена, добавляя в них полезную информацию на основе контекста (всех остальных токенов в предложении) в котором эти embedding-и находятся.

Self Multi-Head Attention
Self Multi-Head Attention

1.4 Encoder-Decoder Multi-Head Attention

Multi-head Attention соединяющий энкодер и декодер трансформера должен учитывать, что размерность Q (пришедшей из декодера) может отчитаться от V и K (пришедшей из энкодера).

Encoder-Decoder Multi-Head Attention
Encoder-Decoder Multi-Head Attention

1.5. Decoder Multi-head Attention (или Masked)

Последнее место, где используется Multi-head Attention:

Masked Multi-head Attention
Masked Multi-head Attention

Изначально, Transformer применялся к задачу перевода текста, в которой выходы (собственно перевод) должны генерироваться авторегрессионно, слово за словом.

Однако, Masked Multi-Head Attention реализует более хитрый вариант: к attention (матрице, которая отвечает "кто на кого смотрит") применяется авторегрессивная маска, обращаюая в −∞ веса до softmax для токенов из будущего, чтобы после softmax их вероятности стали нулевыми. Эта маска имеет нижнетреугольный вид (левый нижний треугольник включая главную диагональ).

Таким образом, каждый элемент смотрит только на те элементы, которые были в прошлом и не заглядывает в будущее. Это позволяет подавать данные не авторегрессионно, а все сразу.

При этом, на инференсе Masked Multi-Head Attention работает как обычный Multi-Head Attention.

В функции Multi-Head Attention применим к relevance маску, если она задана, до функции softmax: в тех местах, где маска содержит True, ставим -inf, чтобы после применения softmax значение в этой ячейке занулилось (так как экспонента возведенная в степень -inf будет равняться нулю).

1.6. Пишем Multi-head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, input_size, head_size, num_heads, out_size, query_input_size=None, masked=False):
        super(MultiHeadAttention, self).__init__()
        
        self.input_size = input_size
        self.head_size = head_size
        self.num_heads = num_heads
        self.out_size = out_size
        # для attention, который смешивает информацию энкодера и декодера
        self.query_input_size = self.input_size if query_input_size is None else query_input_size
        
        self.W_Q = nn.Linear(self.query_input_size, self.num_heads * self.head_size, bias=False)
        self.W_K = nn.Linear(self.input_size, self.num_heads * self.head_size, bias=False)
        self.W_V = nn.Linear(self.input_size, self.num_heads * self.head_size, bias=False)
        
        # последний линейный слой
        self.out = nn.Linear(self.head_size * self.num_heads, self.out_size) 
        self.masked = masked

    # создает маску для Masked Multi-Head Attention
    def make_mask(self, embedding):
        batch_size, emb_len, _ = embedding.shape
        mask = torch.tril(torch.ones((emb_len, emb_len))).expand(batch_size, 1, emb_len, emb_len).bool()
        return mask
    
    def forward(self, query, key, value):
        # batch_size, emb_len, input_size

        batch_size = key.size(0)
        emb_len = key.size(1)
        query_emb_len = query.size(1)
       
        # Применяем линейные преобразования на входе
        q = self.W_Q(query)
        k = self.W_K(key)
        v = self.W_V(value)
        # batch_size, emb_len, self.num_heads * self.head_size
        
        # разделяем головы
        q = q.view(batch_size, query_emb_len, self.num_heads, self.head_size)
        k = k.view(batch_size, emb_len, self.num_heads, self.head_size)
        v = v.view(batch_size, emb_len, self.num_heads, self.head_size)
        # batch_size, emb_len, num_heads, head_size
        
        q = q.transpose(1,2) 
        k = k.transpose(1,2) 
        v = v.transpose(1,2) 
        # batch_size, num_heads, emb_len, head_size
        
        k_T = k.transpose(2, 3) 
        # batch_size, num_heads, head_size, emb_len 
        
        relevance = q @ k_T / math.sqrt(self.head_size)  
        # batch_size, num_heads, query_emb_len, emb_len

        if self.masked:
            mask = self.make_mask(key)
            relevance = relevance.masked_fill(~mask, -torch.inf)
        
        # softmax вдоль последней размерности
        relevance = F.softmax(relevance, dim=-1)
        
        heads = relevance @ v
        # batch_size, num_heads, query_emb_len, head_size
        
        heads = heads.transpose(1, 2)
        # batch_size, query_emb_len, num_heads, head_size
        
        concat = heads.reshape(batch_size, query_emb_len, self.head_size * self.num_heads)  
        # batch_size, query_emb_len, num_heads * head_size
                
        out = self.out(concat)
        # batch_size, query_emb_len, out_size
        
        return out

2 Positional Encoding

Главное преимущество RNN моделей - это то, что они сохраняют временнУю связь между словами, читая текст слово за словом.

Positional Encoding добавляет к изначальным эмбеддингам информацию о позиции эмбеддинга в предложении (позиционный вектор такого же размера, что и сам эмбеддинг):

Существуют разные подходы получения позициональных векторов. Один из самых распространенных - через тригонометрические функции. Он и был предложен в статье "Attention is all you need" (см. 3.5 Positional Encoding):

PE_{(pos,2i)} = \sin{(pos/10000^{2i/d_{model}})}PE_{(pos,2i+1)} = \cos{(pos/10000^{2i/d_{model}})}

d_model - размер эмбеддинга, i - номер элемента позиционного вектора, pos - позиция токена.

class PositionalEncoding(nn.Module):
    def __init__(self, max_emb_len, d_model):
        super(PositionalEncoding, self).__init__()
        
        self.max_emb_len = max_emb_len
        self.d_model = d_model
        
        pos = torch.arange(max_emb_len)[:, None] # [[0], [1], [2], [3], [4], ...[max_emb_len]]]
        i = torch.arange(d_model)[None, :] # i = [[0, 1, 2, 3, 4, ..., d_model]]

        pe = torch.zeros(self.max_emb_len, self.d_model)
        # max_emb_len, d_model
        
        sin = torch.sin(pos / (10000 ** (i[:, ::2] / self.d_model)))
        # max_emb_len, d_model // 2
        
        cos = torch.cos(pos / (10000 ** (i[:, 1::2] / self.d_model)))
        # max_emb_len, d_model // 2
        
        pe[:, ::2] = sin
        pe[:, 1::2] = cos
        
        pe = pe.unsqueeze(0)
        # 1, max_emb_len, d_model
        
        # Данный тензор будем хранить в stage dict этого модуля
        # Optimizer его не будет оптимизировать
        self.register_buffer('pe', pe)

    def forward(self, emb):
        # batch_size, emb_len, input_size
        
        emb_len = emb.size(1)
        
        emb = emb + self.pe[:, :emb_len]
        # batch_size, emb_len, input_size
        
        return emb

3. Собираем Encoder

Encoder
Encoder

Enсoder Block повторяется N раз и состоит из:

  1. Self Multi-Head Attention

  2. Add & Norm(Add - прибавляет к embedding-ам self-attention слой; Norm - нормализует значения embedding-ов по слоям)

  3. Feed Forward(два полносвязных слоя, с функцией активацией ReLU между ними)

  4. Add & Norm

class EncoderBlock(nn.Module):
    # оставляем query_input_size для декодера
    # (ниже объяснение зачем)
    def __init__(self, input_size, head_size, num_heads, out_size, ff_hidden_size, query_input_size=None):
        super(EncoderBlock, self).__init__()
        
        self.input_size = input_size
        self.head_size = head_size
        self.num_heads = num_heads
        self.out_size = out_size
        self.query_input_size = input_size if query_input_size is None else query_input_size
        
        # для feed forward
        self.ff_hidden_size = ff_hidden_size
        
        self.attention = MultiHeadAttention(
            input_size=self.input_size, 
            head_size=self.head_size, 
            num_heads=self.num_heads, 
            out_size=self.out_size,
            query_input_size=self.query_input_size
        )
        
        if self.query_input_size != self.out_size:
            self.adapt = nn.Linear(self.query_input_size, self.out_size) 
        else:
            self.adapt = nn.Identity() # возвращает входные данные без изменений
        
        self.norm_1 = nn.LayerNorm(self.out_size) 
        
        self.feed_forward = nn.Sequential(OrderedDict([
            ("Linear_1", nn.Linear(self.out_size, self.ff_hidden_size)),
            ("Activation", nn.ReLU()),
            ("Linear_2", nn.Linear(self.ff_hidden_size, self.out_size)),
        ]))
        
        self.norm_2 = nn.LayerNorm(self.out_size)

    def forward(self, query, key, value):
        # batch_size, seq_len, in_size
        
        # Self Multi-Head Attention
        attention_out = self.attention(query, key, value)  
        # batch_size, seq_len, out_size
        
        # Add + Norm
        # add_1_out = attention_out + query
        # out_size - гиперпараметр, потому in_size (из query) может отличаться от out_size (из attention_out)
        # Потому воспользуемся дополнительным линейным слоем adapt чтобы привести к одной размерности out_size
        add_1_out = attention_out + self.adapt(query)
        norm_1_out = self.norm_1(add_1_out)
        # batch_size, seq_len, out_size
        
        # Feed Forward
        feed_forward_out = self.feed_forward(norm_1_out)
        # batch_size, seq_len, out_size
        
        # Add + Norm
        add_out = feed_forward_out + norm_1_out
        norm_2_out = self.norm_2(add_out)
        
        return norm_2_out

Повторяем EncoderBlock N раз:

  1. На вход подается embedding размером batch_size, seq_len, input_size.

  2. Применяем слой позиционного кодирования PositionalEncoding

  3. Прогоняем информацию N раз EncoderBlock

class TransformerEncoder(nn.Module):
    def __init__(self, N, max_seq_len, num_embeddings, emb_size, att_out_size, att_head_size, num_heads, ff_hidden_size):
        super(TransformerEncoder, self).__init__()

        self.N = N
        self.max_seq_len = max_seq_len
        
        # для embedding_layer
        self.num_embeddings = num_embeddings # количество уникальных индексов (например, количество слов в вашем корпусе данных)
        self.emb_size = emb_size # размерность векторов эмбеддингов
        
        self.att_out_size = att_out_size
        self.att_head_size = att_head_size
        self.num_heads = num_heads
        
        self.ff_hidden_size = ff_hidden_size
        
        self.embedding_layer = nn.Embedding(
            num_embeddings=self.num_embeddings, 
            embedding_dim=self.emb_size
        )
        self.positional_encoder = PositionalEncoding(
            max_emb_len=self.max_seq_len, 
            d_model=self.emb_size
        )

        self.encoder_blocks = nn.ModuleDict({
            f"encoder_block_{i}": EncoderBlock(
                input_size=self.emb_size if i==0 else self.att_out_size,
                head_size=self.att_head_size,
                num_heads=self.num_heads,
                out_size=self.att_out_size,
                ff_hidden_size=self.ff_hidden_size,
            ) for i in range(self.N)
        })
    
    def forward(self, encoder_input):
        # batch_size, seq_len
        
        encoder_emb = self.embedding_layer(encoder_input)  
        # batch_size, seq_len, emb_size
        
        out = self.positional_encoder(encoder_emb)
        # batch_size, seq_len, emb_size
        
        for block in self.encoder_blocks.values():
            out = block(out, out, out)  
        # batch_size, seq_len, att_out_size

        return out

4. Собираем Decoder

Сначала, аналогично энкодеру, соберем декодер блок, который будем повторять N раз:

Причем, вместо переписывания слое, можно повзаимствовать Encoder-блок для Decoder-а:

Часть декодера, где можно переиспользовать энкодер
Часть декодера, где можно переиспользовать энкодер
class DecoderBlock(nn.Module):
    def __init__(self, input_size, head_size, num_heads, out_size, ff_hidden_size, query_input_size=None, encoder_out_size=None):
        super(DecoderBlock, self).__init__()
        """
        input_size = query_input_size -> decoder_input
        encoder_out_size -> encoder_input
        """
        self.input_size = input_size
        self.head_size = head_size
        self.num_heads = num_heads
        self.out_size = out_size
        self.ff_hidden_size = ff_hidden_size
        self.encoder_out_size = input_size if encoder_out_size is None else encoder_out_size
        
        self.masked_attention = MultiHeadAttention(
            input_size=self.input_size, 
            head_size=self.head_size, 
            num_heads=self.num_heads, 
            out_size=self.out_size,
            query_input_size=None,
            masked=True
        )
        
        if self.input_size != self.out_size:
            self.adapt_1 = nn.Linear(self.input_size, self.out_size)  
        else:
            self.adapt_1 = nn.Identity()
            
        self.norm_1 = nn.LayerNorm(self.out_size)
        
        '''
        могло бы быть, но мы переиспользуем энкодер

        self.attention = MultiHeadAttention(
            input_size=self.encoder_out_size, 
            head_size=self.head_size, 
            num_heads=self.num_heads, 
            out_size=self.out_size, 
            query_input_size=self.out_size
        )
            
        self.norm_2 = nn.LayerNorm(self.out_size)
              
        self.feed_forward = nn.Sequential(OrderedDict([
            ("Linear_1", nn.Linear(self.out_size, self.ff_hidden_size)),
            ("Activation", nn.ReLU()),
            ("Linear_2", nn.Linear(self.ff_hidden_size, self.out_size)),
        ]))
        
        self.norm_3 = nn.LayerNorm(self.out_size)
        '''
        
        # Вместо того, что выше, переиспользуем часть Encoder-а
        self.encoder_block = EncoderBlock(
            input_size=self.encoder_out_size, # K, V из encoder
            head_size=self.head_size,
            num_heads=self.num_heads,
            out_size = self.out_size,
            ff_hidden_size=self.ff_hidden_size,
            query_input_size=self.out_size, # Q из decoder
        )
        
    def forward(self, decoder_emb, encoder_output):
        # decoder_emb.size() = batch_size, seq_len, input_size
        # encoder_output.size() = batch_size, encoder_seq_len, encoder_out_size
        
        
        out_mask_att = self.masked_attention(decoder_emb, decoder_emb, decoder_emb)
        # batch_size, emb_len, out_size
        
        out_add_1 = out_mask_att + self.adapt_1(decoder_emb)
        
        out_norm_1 = self.norm_1(out_add_1)
        # batch_size, emb_len, out_size
        
        '''
        attention = self.attention(out_norm_1, encoder_output, encoder_output)
        # batch_size, emb_len, out_size
        out_add_2 = attention + out_norm_1
        # batch_size, emb_len, out_size
        out_norm_2 = self.norm_2(out_add_2)
        # batch_size, emb_len, out_size
        
        out_feed_forward = self.feed_forward(out_norm_2)
        out_add_3 = masked_attention + out_norm_2
        out_norm_3 = self.norm_3(out_add_2)
        # batch_size, seq_len, out_size
        return out_norm_3
        '''
        # переиспользуем часть Encoder-а
        out_encoder_block = self.encoder_block(
            query=out_norm_1, 
            key=encoder_output, 
            value=encoder_output
        )
        # batch_size, seq_len, out_size
        
        return  out_norm_1
  1. На вход декодеру приходят эмбеддинги энкодера

  2. Добавляем к ним PositionalEnсoding

  3. Проходим N раз через декодер-блок

  4. Завершаем последним линейным слоем и softmax-ом

class TransformerDecoder(nn.Module):
    def __init__(self, N, max_seq_len, num_embeddings, emb_size, att_out_size, att_head_size, num_heads, ff_hidden_size, encoder_out_size=None):
        super(TransformerDecoder, self).__init__()
        
        self.N = N
        self.max_seq_len = max_seq_len
        
        # для embedding_layer
        self.num_embeddings = num_embeddings # количество уникальных индексов (например, количество слов в вашем корпусе данных)
        self.emb_size = emb_size # размерность векторов эмбеддингов
        
        self.att_out_size = att_out_size
        self.att_head_size = att_head_size
        self.num_heads = num_heads
        self.ff_hidden_size = ff_hidden_size
        self.encoder_out_size = input_size if encoder_out_size is None else encoder_out_size
        
        self.embedding_layer = nn.Embedding(self.num_embeddings, self.emb_size)
        self.positional_encoder = PositionalEncoding(self.max_seq_len, self.emb_size)

        self.decoder_blocks = nn.ModuleDict({
            f"decoder_block_{i}": DecoderBlock(
                input_size=self.emb_size if i==0 else self.att_out_size,
                head_size=self.att_head_size,
                num_heads=self.num_heads,
                out_size=self.att_out_size,
                ff_hidden_size=self.ff_hidden_size,
                encoder_out_size=self.encoder_out_size,
            ) for i in range(self.N)
        })
        
        self.fc = nn.Linear(self.att_out_size, self.num_embeddings)

    def forward(self, decoder_input, encoder_output):
        # decoder_input.size() = batch_size, seq_len
        # encoder_output.size() = batch_size, encoder_seq_le, encoder_out_size
        
        decoder_emb = self.embedding_layer(decoder_input)  
        # batch_size, seq_len, emb_size
        
        decoder_emb = self.positional_encoder(decoder_emb)
        # batch_size, seq_len, emb_size
        
        out = decoder_emb
     
        for block in self.decoder_blocks.values():
            out = block(out, encoder_output) 
        # batch_size, seq_len, att_out_size
            
        # out = self.fc(out)
        return out

5. Собираем Transformer

class Transformer(nn.Module):
    def __init__(self, max_seq_len, num_embeddings, emb_size,
                 N_encoder, enc_att_out_size, enc_att_head_size, enc_num_heads, enc_ff_hidden_size,
                 N_decoder, dec_att_out_size, dec_att_head_size, dec_num_heads, dec_ff_hidden_size,):
        super(Transformer, self).__init__()

        self.max_seq_len = max_seq_len
        self.num_embeddings = num_embeddings
        self.emb_size = emb_size
        
        self.N_encoder = N_encoder
        self.enc_att_out_size = enc_att_out_size
        self.enc_att_head_size = enc_att_head_size
        self.enc_num_heads = enc_num_heads
        self.enc_ff_hidden_size = enc_ff_hidden_size
        
        self.N_decoder = N_decoder
        self.dec_att_out_size = dec_att_out_size
        self.dec_att_head_size = dec_att_out_size
        self.dec_num_heads = dec_num_heads
        self.dec_ff_hidden_size = dec_ff_hidden_size

        # Encoder
        self.encoder = TransformerEncoder(
            N=self.N_encoder,
            max_seq_len=self.max_seq_len,
            num_embeddings=self.num_embeddings,
            emb_size=self.emb_size,
            att_head_size=self.enc_att_head_size,
            num_heads=self.enc_num_heads,
            att_out_size=self.enc_att_out_size,
            ff_hidden_size=self.enc_ff_hidden_size,
        )
        
        # Decoder
        self.decoder = TransformerDecoder(
            N=self.N_decoder,
            max_seq_len=self.max_seq_len,
            num_embeddings=self.num_embeddings,
            emb_size=self.emb_size,
            att_head_size=self.dec_att_head_size,
            num_heads=self.dec_num_heads,
            att_out_size=self.dec_att_out_size,
            ff_hidden_size=self.dec_ff_hidden_size,
            encoder_out_size=self.enc_att_out_size,
        )
    
    def forward(self, encoder_input, decoder_input):
        # encoder_input.size() = batch_size, enc_seq_len
        # decoder_input.size() = batch_size, dec_seq_len
        
        encoder_output = self.encoder(encoder_input)  
        # batch_size, enc_seq_len, enc_att_out_size
        
        decoder_output = self.decoder(decoder_input, encoder_output) 
        # batch_size, dec_seq_len, num_embeddings
        
        return decoder_output

Замечания

  1. Нет DropOut слоев


Ноутбук лежит в телеграм канале: not_magic_neural_networks

Теги:
Хабы:
Всего голосов 19: ↑18 и ↓1+20
Комментарии14

Публикации

Работа

Data Scientist
51 вакансия

Ближайшие события