Vision Transformer (ViT) — это архитектура, которая буквально произвела революцию в том, как машины «видят» мир.
В этой статье я не просто объясню, что такое ViT — я покажу вам, как создать эту магию своими руками, шаг за шагом, даже если вы никогда раньше не работали с трансформерами для задач с изображениями.
Для начала давайте взглянем на архитектуру Vision Transformer:

Мы напишем код полностью с нуля, а затем обучим модель на датасете CIFAR-10.
Давайте начнём с реализации Patch Embedding:
class PatchEmbedding(nn.Module):
def __init__(self, img_size = 32, patch_size = 4, in_channels = 3, embed_dim=256):
super().__init__()
assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
self.patch_size = patch_size
self.num_patches = (img_size//patch_size)**2
self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.conv(x) #(B, embed_dim, H/patch_size, W/patch_size)
x = x.flatten(2).transpose(1, 2) #(B, num_patches, embed_dim)
return x
Изображение будет разделено на патчи, и размер каждого патча можно задать с помощью параметра patch_size. При этом изображение не просто разбивается на патчи, но и пропускается через свёрточные ядра (CNN). В итоге мы получаем не просто патчи изображения — а встраивания (эмбеддинги) этих патчей.
Следующий шаг — реализовать самую интересную часть этой модели — механизм внимания (attention).

Q (Query) формально задаёт вопрос от каждого патча к другим патчам, K (Key) показывает, есть ли у каждого патча ответ на этот вопрос, а V (Value) содержит «значения» — фактические данные каждого патча, которые используются для формирования итогового представления.
Предположим, у нас есть X и Y, и мы хотим, чтобы X обращал внимание на Y. В этом случае матрица Query умножается на X, а матрицы Key и Value — на Y. Вместо прямого умножения на матрицы мы используем линейные слои.
attn_probs — это матрицы внимания, которые показывают, насколько токен i должен «обращать внимание» на токен j. Далее мы умножаем их на V, чтобы получить эмбеддинги изображения с учётом весов внимания attn_probs. V фактически хранит значения каждого патча изображения, а attn_probs показывает, сколько информации каждый патч должен получить от остальных патчей.
Вот как работает одна голова внимания; затем значения с всех голов объединяются. Такая конструкция основана на идее, что каждая голова фокусируется на разных аспектах.
class MultiHeadAttention(nn.Module):
def __init__(self, dim, num_heads, dropout):
super().__init__()
assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias = False)
self.out = nn.Linear(dim, dim, bias = False)
self.scale = 1.0 / (self.head_dim ** 0.5)
self.attn_dropout = nn.Dropout(dropout)
def forward(self, x, mask = None, return_attn=False):
B, num_patches, embed_dim = x.shape
qkv = self.qkv(x) # (B, num_patches, 3*embed_dim)
qkv = qkv.reshape(B, num_patches, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) #(3, B, num_heads, num_patches, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2] # each (B, num_heads, num_patches, head_dim)
#How important it is for token i to pay attention to token j.
attn_scores = (q @ k.transpose(-2, -1)) * self.scale #[B, num_heads, N, N]
if mask is not None:
# mask: (B, 1, N, N) or (1, 1, N, N)
attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
attn_probs = attn_scores.softmax(dim=-1) #[B, num_heads, N, N]
attn_probs = self.attn_dropout(attn_probs)
attn_output = attn_probs @ v # (B, num_heads, num_patches, head_dim)
attn_output = attn_output.transpose(1, 2).reshape(B, num_patches, embed_dim)
if return_attn:
return self.out(attn_output), attn_probs
else:
return self.out(attn_output) #(B, num_patches, embed_dim)
Давайте перейдём к сборке блока Transformer Encoder:
class TransformerEncoderBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_dim, dropout):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = MultiHeadAttention(dim, num_heads, dropout)
self.norm2 = nn.LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_dim, dim),
nn.Dropout(dropout)
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, return_attn=False):
if return_attn:
attn_out, attn_weights = self.attn(self.norm1(x), return_attn=True)
x = x + self.dropout(attn_out)
x = x + self.dropout(self.mlp(self.norm2(x)))
return x, attn_weights
else:
x = x + self.dropout(self.attn(self.norm1(x)))
x = x + self.dropout(self.mlp(self.norm2(x)))
return x
Здесь мы просто следуем архитектуре нашей сети — все необходимые блоки мы уже реализовали.
Мы уже почти на финишной прямой — теперь соберём сам Vision Transformer:
class VisualTransformer(nn.Module):
def __init__(self,num_classes, img_size=32, patch_size=4, in_channels=3, embed_dim=256,
num_layers=6, num_heads=7, mlp_dim=512, dropout=0.1):
super().__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.randn(1, 1 + self.patch_embed.num_patches, embed_dim))
self.dropout = nn.Dropout(dropout)
self.encoder_blocks = nn.ModuleList([
TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(embed_dim)
self.mlp_head = nn.Sequential(
nn.LayerNorm(embed_dim),
nn.Linear(in_features=embed_dim, out_features=num_classes)
)
def forward(self, x, return_attn = False):
B = x.size(0)
x = self.patch_embed(x) # (B, N, D)
cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, D)
x = torch.cat((cls_tokens, x), dim=1) # (B, 1+N, D)
x = x + self.pos_embed
x = self.dropout(x)
attn_maps = []
for block in self.encoder_blocks:
if return_attn:
x, attn = block(x, return_attn=True)
attn_maps.append(attn) # (B, heads, N, N)
else:
x = block(x) # (B, 1+N, D)
x = self.norm(x)
out = self.mlp_head(x[:, 0, :])
if return_attn:
return out, attn_maps
else:
return out
Здесь нужно уточнить несколько моментов. Что такое cls_token? Это специальный токен, который мы добавляем вручную, и он имеет тот же размер, что и патчи изображения. Его задача — использоваться позже для классификации изображения. Идея в том, что, проходя через блоки внимания, этот токен собирает информацию обо всём изображении.
Далее посмотрим на pos_embed. Поскольку мы делим изображение на патчи и выстраиваем их в последовательность — как будто работаем с текстом — модель изначально не понимает пространственные взаимосвязи между патчами. Чтобы это исправить, мы добавляем позиционную информацию к патчам. В нашем случае pos_embed — это обучаемый параметр.
Что касается mlp_head, здесь всё просто: он берёт cls_token, пропускает его через линейный слой и классифицирует изображение.
После сборки нашей модели давайте перейдём к обучению.
Для обучения мы будем использовать следующие гиперпараметры:
BATCH_SIZE = 128
EPOCHS = 80
LEARNING_RATE = 3e-4
PATCH_SIZE = 4
NUM_CLASSES = 10
IMAGE_SIZE = 32
CHANNELS = 3
EMBED_DIM = 256
NUM_HEADS = 8
DEPTH = 6
MLP_DIM = 512
DROP_RATE = 0.1
Давайте посмотрим на количество параметров:
Total parameters: 3,189,514
Trained parameters: 3,189,514
А также следующие аугментации:
train_transforms = transforms.Compose([
transforms.Resize((70, 70)),
transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(15),
transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
])
Мы обучаем модель и получаем следующие результаты:

Давайте посмотрим на предсказания модели:

Вот метрики получившейся модели:


А теперь к самой интересной части — вниманию. Давайте посмотрим, на что наша модель обращает внимание во время классификации:


В этой статье мы подробно рассмотрели реализацию Vision Transformer и его механизма внимания. Мы изучили, на что способна эта модель и как она «смотрит» на изображение с помощью механизма внимания. Vision Transformer открыл новые направления в исследовании компьютерного зрения, объединив идеи из NLP и обработки изображений. В будущем мы обязательно применим эту модель для задачи генерации подписей к изображениям (Image Captioning).
Полный код и процесс обучения вы можете найти на моём Kaggle:
https://www.kaggle.com/code/nickr0ot/visual-transformer-from-scratch