Vision Transformer (ViT) — это архитектура, которая буквально произвела революцию в том, как машины «видят» мир.
В этой статье я не просто объясню, что такое ViT — я покажу вам, как создать эту магию своими руками, шаг за шагом, даже если вы никогда раньше не работали с трансформерами для задач с изображениями.
Для начала давайте взглянем на архитектуру Vision Transformer:

Мы напишем код полностью с нуля, а затем обучим модель на датасете CIFAR-10.
Давайте начнём с реализации Patch Embedding:
class PatchEmbedding(nn.Module): def __init__(self, img_size = 32, patch_size = 4, in_channels = 3, embed_dim=256): super().__init__() assert img_size % patch_size == 0, "img_size must be divisible by patch_size" self.patch_size = patch_size self.num_patches = (img_size//patch_size)**2 self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.conv(x) #(B, embed_dim, H/patch_size, W/patch_size) x = x.flatten(2).transpose(1, 2) #(B, num_patches, embed_dim) return x
Изображение будет разделено на патчи, и размер каждого патча можно задать с помощью параметра patch_size. При этом изображение не просто разбивается на патчи, но и пропускается через свёрточные ядра (CNN). В итоге мы получаем не просто патчи изображения — а встраивания (эмбеддинги) этих патчей.
Следующий шаг — реализовать самую интересную часть этой модели — механизм внимания (attention).

Q (Query) формально задаёт вопрос от каждого патча к другим патчам, K (Key) показывает, есть ли у каждого патча ответ на этот вопрос, а V (Value) содержит «значения» — фактические данные каждого патча, которые используются для формирования итогового представления.
Предположим, у нас есть X и Y, и мы хотим, чтобы X обращал внимание на Y. В этом случае матрица Query умножается на X, а матрицы Key и Value — на Y. Вместо прямого умножения на матрицы мы используем линейные слои.
attn_probs — это матрицы внимания, которые показывают, насколько токен i должен «обращать внимание» на токен j. Далее мы умножаем их на V, чтобы получить эмбеддинги изображения с учётом весов внимания attn_probs. V фактически хранит значения каждого патча изображения, а attn_probs показывает, сколько информации каждый патч должен получить от остальных патчей.
Вот как работает одна голова внимания; затем значения с всех голов объединяются. Такая конструкция основана на идее, что каждая голова фокусируется на разных аспектах.
class MultiHeadAttention(nn.Module): def __init__(self, dim, num_heads, dropout): super().__init__() assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads" self.num_heads = num_heads self.head_dim = dim // num_heads self.qkv = nn.Linear(dim, dim * 3, bias = False) self.out = nn.Linear(dim, dim, bias = False) self.scale = 1.0 / (self.head_dim ** 0.5) self.attn_dropout = nn.Dropout(dropout) def forward(self, x, mask = None, return_attn=False): B, num_patches, embed_dim = x.shape qkv = self.qkv(x) # (B, num_patches, 3*embed_dim) qkv = qkv.reshape(B, num_patches, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) #(3, B, num_heads, num_patches, head_dim) q, k, v = qkv[0], qkv[1], qkv[2] # each (B, num_heads, num_patches, head_dim) #How important it is for token i to pay attention to token j. attn_scores = (q @ k.transpose(-2, -1)) * self.scale #[B, num_heads, N, N] if mask is not None: # mask: (B, 1, N, N) or (1, 1, N, N) attn_scores = attn_scores.masked_fill(mask == 0, float("-inf")) attn_probs = attn_scores.softmax(dim=-1) #[B, num_heads, N, N] attn_probs = self.attn_dropout(attn_probs) attn_output = attn_probs @ v # (B, num_heads, num_patches, head_dim) attn_output = attn_output.transpose(1, 2).reshape(B, num_patches, embed_dim) if return_attn: return self.out(attn_output), attn_probs else: return self.out(attn_output) #(B, num_patches, embed_dim)
Давайте перейдём к сборке блока Transformer Encoder:
class TransformerEncoderBlock(nn.Module): def __init__(self, dim, num_heads, mlp_dim, dropout): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = MultiHeadAttention(dim, num_heads, dropout) self.norm2 = nn.LayerNorm(dim) self.mlp = nn.Sequential( nn.Linear(dim, mlp_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_dim, dim), nn.Dropout(dropout) ) self.dropout = nn.Dropout(dropout) def forward(self, x, return_attn=False): if return_attn: attn_out, attn_weights = self.attn(self.norm1(x), return_attn=True) x = x + self.dropout(attn_out) x = x + self.dropout(self.mlp(self.norm2(x))) return x, attn_weights else: x = x + self.dropout(self.attn(self.norm1(x))) x = x + self.dropout(self.mlp(self.norm2(x))) return x
Здесь мы просто следуем архитектуре нашей сети — все необходимые блоки мы уже реализовали.
Мы уже почти на финишной прямой — теперь соберём сам Vision Transformer:
class VisualTransformer(nn.Module): def __init__(self,num_classes, img_size=32, patch_size=4, in_channels=3, embed_dim=256, num_layers=6, num_heads=7, mlp_dim=512, dropout=0.1): super().__init__() self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim) self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.randn(1, 1 + self.patch_embed.num_patches, embed_dim)) self.dropout = nn.Dropout(dropout) self.encoder_blocks = nn.ModuleList([ TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(num_layers) ]) self.norm = nn.LayerNorm(embed_dim) self.mlp_head = nn.Sequential( nn.LayerNorm(embed_dim), nn.Linear(in_features=embed_dim, out_features=num_classes) ) def forward(self, x, return_attn = False): B = x.size(0) x = self.patch_embed(x) # (B, N, D) cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, D) x = torch.cat((cls_tokens, x), dim=1) # (B, 1+N, D) x = x + self.pos_embed x = self.dropout(x) attn_maps = [] for block in self.encoder_blocks: if return_attn: x, attn = block(x, return_attn=True) attn_maps.append(attn) # (B, heads, N, N) else: x = block(x) # (B, 1+N, D) x = self.norm(x) out = self.mlp_head(x[:, 0, :]) if return_attn: return out, attn_maps else: return out
Здесь нужно уточнить несколько моментов. Что такое cls_token? Это специальный токен, который мы добавляем вручную, и он имеет тот же размер, что и патчи изображения. Его задача — использоваться позже для классификации изображения. Идея в том, что, проходя через блоки внимания, этот токен собирает информацию обо всём изображении.
Далее посмотрим на pos_embed. Поскольку мы делим изображение на патчи и выстраиваем их в последовательность — как будто работаем с текстом — модель изначально не понимает пространственные взаимосвязи между патчами. Чтобы это исправить, мы добавляем позиционную информацию к патчам. В нашем случае pos_embed — это обучаемый параметр.
Что касается mlp_head, здесь всё просто: он берёт cls_token, пропускает его через линейный слой и классифицирует изображение.
После сборки нашей модели давайте перейдём к обучению.
Для обучения мы будем использовать следующие гиперпараметры:
BATCH_SIZE = 128 EPOCHS = 80 LEARNING_RATE = 3e-4 PATCH_SIZE = 4 NUM_CLASSES = 10 IMAGE_SIZE = 32 CHANNELS = 3 EMBED_DIM = 256 NUM_HEADS = 8 DEPTH = 6 MLP_DIM = 512 DROP_RATE = 0.1
Давайте посмотрим на количество параметров:
Total parameters: 3,189,514 Trained parameters: 3,189,514
А также следующие аугментации:
train_transforms = transforms.Compose([ transforms.Resize((70, 70)), transforms.RandomResizedCrop(32, scale=(0.8, 1.0)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(15), transforms.ColorJitter(0.2, 0.2, 0.2, 0.1), transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'), ])
Мы обучаем модель и получаем следующие результаты:

Давайте посмотрим на предсказания модели:

Вот метрики получившейся модели:


А теперь к самой интересной части — вниманию. Давайте посмотрим, на что наша модель обращает внимание во время классификации:


В этой статье мы подробно рассмотрели реализацию Vision Transformer и его механизма внимания. Мы изучили, на что способна эта модель и как она «смотрит» на изображение с помощью механизма внимания. Vision Transformer открыл новые направления в исследовании компьютерного зрения, объединив идеи из NLP и обработки изображений. В будущем мы обязательно применим эту модель для задачи генерации подписей к изображениям (Image Captioning).
Полный код и процесс обучения вы можете найти на моём Kaggle:
https://www.kaggle.com/code/nickr0ot/visual-transformer-from-scratch
