Pull to refresh

Я построил Vision Transformer с нуля — и научил его обращать внимание

Level of difficultyEasy
Reading time6 min
Views4.3K

Vision Transformer (ViT) — это архитектура, которая буквально произвела революцию в том, как машины «видят» мир.

В этой статье я не просто объясню, что такое ViT — я покажу вам, как создать эту магию своими руками, шаг за шагом, даже если вы никогда раньше не работали с трансформерами для задач с изображениями.

Для начала давайте взглянем на архитектуру Vision Transformer:

Vision Transformer architecture
Vision Transformer architecture

Мы напишем код полностью с нуля, а затем обучим модель на датасете CIFAR-10.

Давайте начнём с реализации Patch Embedding:

class PatchEmbedding(nn.Module):
    def __init__(self, img_size = 32, patch_size = 4, in_channels = 3, embed_dim=256):
        super().__init__()

        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.patch_size = patch_size
        self.num_patches = (img_size//patch_size)**2
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.conv(x) #(B, embed_dim, H/patch_size, W/patch_size)
        x = x.flatten(2).transpose(1, 2) #(B, num_patches, embed_dim)
        return x

Изображение будет разделено на патчи, и размер каждого патча можно задать с помощью параметра patch_size. При этом изображение не просто разбивается на патчи, но и пропускается через свёрточные ядра (CNN). В итоге мы получаем не просто патчи изображения — а встраивания (эмбеддинги) этих патчей.

Следующий шаг — реализовать самую интересную часть этой модели — механизм внимания (attention).

Self-Attention Mechanism
Self-Attention Mechanism

Q (Query) формально задаёт вопрос от каждого патча к другим патчам, K (Key) показывает, есть ли у каждого патча ответ на этот вопрос, а V (Value) содержит «значения» — фактические данные каждого патча, которые используются для формирования итогового представления.

Предположим, у нас есть X и Y, и мы хотим, чтобы X обращал внимание на Y. В этом случае матрица Query умножается на X, а матрицы Key и Value — на Y. Вместо прямого умножения на матрицы мы используем линейные слои.

attn_probs — это матрицы внимания, которые показывают, насколько токен i должен «обращать внимание» на токен j. Далее мы умножаем их на V, чтобы получить эмбеддинги изображения с учётом весов внимания attn_probs. V фактически хранит значения каждого патча изображения, а attn_probs показывает, сколько информации каждый патч должен получить от остальных патчей.

Вот как работает одна голова внимания; затем значения с всех голов объединяются. Такая конструкция основана на идее, что каждая голова фокусируется на разных аспектах.

class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        self.num_heads = num_heads

        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias = False)
        self.out = nn.Linear(dim, dim, bias = False)

        self.scale = 1.0 / (self.head_dim ** 0.5)

        self.attn_dropout = nn.Dropout(dropout)

    def forward(self,  x, mask = None,  return_attn=False):
        B, num_patches, embed_dim = x.shape

        qkv = self.qkv(x) # (B, num_patches, 3*embed_dim)
        qkv = qkv.reshape(B, num_patches, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4) #(3, B, num_heads, num_patches, head_dim)

        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, num_heads, num_patches, head_dim)

                                        #How important it is for token i to pay attention to token j.
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale #[B, num_heads, N, N]

        if mask is not None:
            # mask: (B, 1, N, N) or (1, 1, N, N)
            attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))

        attn_probs = attn_scores.softmax(dim=-1) #[B, num_heads, N, N]
        attn_probs = self.attn_dropout(attn_probs)
        attn_output = attn_probs @ v  # (B, num_heads, num_patches, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(B, num_patches, embed_dim)

        if return_attn:
          return self.out(attn_output), attn_probs
        else:
          return self.out(attn_output) #(B, num_patches, embed_dim)

Давайте перейдём к сборке блока Transformer Encoder:

class TransformerEncoderBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, return_attn=False):
      if return_attn:
        attn_out, attn_weights = self.attn(self.norm1(x), return_attn=True)
        x = x + self.dropout(attn_out)
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x, attn_weights
      else:
        x = x + self.dropout(self.attn(self.norm1(x)))
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x

Здесь мы просто следуем архитектуре нашей сети — все необходимые блоки мы уже реализовали.

Мы уже почти на финишной прямой — теперь соберём сам Vision Transformer:

class VisualTransformer(nn.Module):
    def __init__(self,num_classes, img_size=32, patch_size=4, in_channels=3, embed_dim=256,
                 num_layers=6, num_heads=7, mlp_dim=512, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, 1 + self.patch_embed.num_patches, embed_dim))
        self.dropout = nn.Dropout(dropout)

        self.encoder_blocks = nn.ModuleList([
                TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout)
                for _ in range(num_layers)
            ])

        self.norm = nn.LayerNorm(embed_dim)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x, return_attn = False):
        B = x.size(0)
        x = self.patch_embed(x)  # (B, N, D)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 1+N, D)
        x = x + self.pos_embed
        x = self.dropout(x)

        attn_maps = []
        for block in self.encoder_blocks:
          if return_attn:
            x, attn = block(x, return_attn=True)
            attn_maps.append(attn)  # (B, heads, N, N)
          else:
            x = block(x) # (B, 1+N, D)

        x = self.norm(x)

        out = self.mlp_head(x[:, 0, :])

        if return_attn:
          return out, attn_maps
        else:
          return out

Здесь нужно уточнить несколько моментов. Что такое cls_token? Это специальный токен, который мы добавляем вручную, и он имеет тот же размер, что и патчи изображения. Его задача — использоваться позже для классификации изображения. Идея в том, что, проходя через блоки внимания, этот токен собирает информацию обо всём изображении.

Далее посмотрим на pos_embed. Поскольку мы делим изображение на патчи и выстраиваем их в последовательность — как будто работаем с текстом — модель изначально не понимает пространственные взаимосвязи между патчами. Чтобы это исправить, мы добавляем позиционную информацию к патчам. В нашем случае pos_embed — это обучаемый параметр.

Что касается mlp_head, здесь всё просто: он берёт cls_token, пропускает его через линейный слой и классифицирует изображение.

После сборки нашей модели давайте перейдём к обучению.

Для обучения мы будем использовать следующие гиперпараметры:

BATCH_SIZE = 128
EPOCHS = 80
LEARNING_RATE = 3e-4
PATCH_SIZE = 4
NUM_CLASSES = 10
IMAGE_SIZE = 32
CHANNELS = 3
EMBED_DIM = 256
NUM_HEADS = 8
DEPTH = 6
MLP_DIM = 512
DROP_RATE = 0.1

Давайте посмотрим на количество параметров:

Total parameters: 3,189,514
Trained parameters: 3,189,514

А также следующие аугментации:

train_transforms = transforms.Compose([
    transforms.Resize((70, 70)),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
])

Мы обучаем модель и получаем следующие результаты:

 Графики процесса обучения
Графики процесса обучения

Давайте посмотрим на предсказания модели:

Предсказания модели
Предсказания модели

Вот метрики получившейся модели:

Метрики модели
Метрики модели
 Матрица ошибок (Confusion Matrix)
Матрица ошибок (Confusion Matrix)

А теперь к самой интересной части — вниманию. Давайте посмотрим, на что наша модель обращает внимание во время классификации:

 Карта внимания (Attention Map)
Карта внимания (Attention Map)
 Карта внимания (Attention Map)
Карта внимания (Attention Map)

В этой статье мы подробно рассмотрели реализацию Vision Transformer и его механизма внимания. Мы изучили, на что способна эта модель и как она «смотрит» на изображение с помощью механизма внимания. Vision Transformer открыл новые направления в исследовании компьютерного зрения, объединив идеи из NLP и обработки изображений. В будущем мы обязательно применим эту модель для задачи генерации подписей к изображениям (Image Captioning).

Полный код и процесс обучения вы можете найти на моём Kaggle:

https://www.kaggle.com/code/nickr0ot/visual-transformer-from-scratch

Tags:
Hubs:
+10
Comments7

Articles