Search
Write a publication
Pull to refresh

Xe vs He: кого брать-то?

Level of difficultyEasy
Reading time5 min
Views1.3K

Как известно из простеньких курсов по DS, есть два алгоритма инициализации весов в скрытых слоях нейронных сетей, будто свертки или полносвязные слои. В первые, когда лично я про них узнал, не сразу понял суть различия принципов их работы. В этой статье я попробую обойтись без сложной математики ( базовые выкладки мат.стата все равно будут, крепитесь) и показать на простых примерах разницу между этими двумя.

Почему нам вообще нужны какие-либо алгоритмы?

Как изменяется BCELoss с каждой эпохой при инициализации весов случайными числами
Как изменяется BCELoss с каждой эпохой при инициализации весов случайными числами

Допустим, вы решили проинициализировать веса случайными числами, в силу того, что это кажется вам простым и элегантным решением. Давайте обучим простой перцептрон на датасете MNIST c такой инициализацией и посмотрим, к чему это может привести. Возьмем за бейслайн результат изменения лоса ( я взял кросс энтропию ). Вроде неплохо, но все познается в сравнении. Теперь применим Xe, но сначала пару слов про принцип его работы

Инициализация Глорота ( Xe )

Справедливости ради сначала нужно уточнить, что инициализация весов случайными числами в целом неплоха для ситуации, когда размерности выходного и входного слоев совпадают, но в датасете MNIST на вход перцептрону мы передаем 784 фичей, а на выходе хотим видеть 10 нейронов, отвечающих за 10 цифр, которые мы и классифицируем.
Итак, суть алгоритма предложенного Xavier Glorot и Yoshua Bengio (hence the name btw) заключается в том, что веса беруться из распределения, дисперсия которого задается следующей формулой

Var[w_{i}] = \frac{2}{n_{in} + n_{out}}

Следовательно для, допустим, нормального распределения N это будет выглядить следующим образом

w _{i}\sim N  \left( 0,\frac{2}{n_{in} + n_{out}} \right)

Где w итая - это итый вес в сети соответственно, ребята. Давайте теперь применим на практике и посмотрим как меняется Loss

Оранжевая - Xe. Синяя - randint
Оранжевая - Xe. Синяя - randint

Видно, что, во первых, значения первого лосса стабильно начинаются с более низкого числа, да и в целом тенденция хорошая.
Вам все еще может быть не понятно, почему Xe улучшает качество обучения модели.Проще будет рассуждать от противного. Случайная инициализация весов ( даже если мы можем, к примеру, регулировать ее размерность) может допустить слишком большие значения или слишкмо малые, что приведет или к затуханию нейронов или к их взрыву, т.е.

Если веса слишком большие → активации взрываются (значения уходят в ±∞).Если веса слишком маленькие → активации затухают (все значения стремятся к 0).

Xe как бы регулирует размерность начальных весов в соответствии с кол-вом нейронов на входы и выходе сети и позволяет сохранить дисперсию активаций на прямом проходе / дисперсию градиентов на обратном проходе для линейных и симметричных насыщаемых (tanh) функций. Для примера приведу график из оригинальной статьи (https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)



Основная слабость Xe все же в его адоптированности к симметричным функциям в силу его распределения ( посмотрите на формулы выше ). В следствии чего он может выдавать отрицательные числа, которые, в свою очередь, плохо сочетаются с функцией активации ReLU.

Тут видно, что все отрицательные числа обнуляются, что замедляет обучение сети при backpropogation.
Почему нам вообще важно думать о существовании этой функции активации? Все просто - модификации ReLU (GELU, Leaky ReLU) используются в огромном кол-ве моделей, так как в них симмертричные функции, такие как тот же тангенс, просто выдают результаты хуже.
Достаточно учесть, что половина активаций будет равна нулю или около нулевому значению, что можно сделать, просто увеличив дисперсию. Так и появилась He инициализация

Инициализация Кайминга  ( He )

Суть таже, но дисперсия считается иначе ( почему так, я описал сверху )

Var[w_{i}] = \frac{2}{n_{in}}

И для нормального распределения N это будет выглядить следующим образом

w _{i}\sim N  \left( 0,\frac{2}{n_{in} } \right)

На предыдущих графиках я сравнивал результаты работы модели, функцией активации которой был tanh. Теперь же, чтобы показать, что He лучше показывает себя в работе с ReLU, я поменяю функцию активации на нее, соответственно.

оранж. - Xe, син. - He
оранж. - Xe, син. - He

Все в целом очевидно, в добавок приведу графики из оригинальной статьи (https://arxiv.org/abs/1502.01852v1)

Выводы

Отвечая на вопрос в названии статьи, можем прийти к простому выводу

He — де-факто стандарт для современных сетей c ReLU/ReLU-like. Xavier — для сетей c Tanh/Sigmoid, НО
Современные архитектуры ( которые в основе имеют ResNet, ViT) используют He (Kaiming) для весов Conv/Linear слоев перед ReLU/GELU. Это стало стандартом и поэтому подытоживая можно сказать: Выбирайте He!

Немного о коде

Ну, не могу не привести примеры того, как в pytorch использовать инициализацию весов. Все достаточно просто

# в разделе nn библиотеки pytorch есть две такие функии
nn.init.xavier_uniform_() #Xe
nn.init.kaiming_uniform_() #He
# Использование можно описать следующим образом
nn.init.xavier_uniform_(self.lin.weight)
nn.init.kaiming_uniform_(self.lin.weight,nonlinearity='relu')
# где self.lin это полносвязный слой в классе модели перцептрона, а переменная weight
# это очевидно, веса слоя.

Ну и оставлю весь код, который я использовал. Конкретно тут приведен код для сравнения случайной инициализации против Xe алгоритма, когда модель использует tanh ф.а.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

class Perc(nn.Module):
    def __init__(self, hid_l_size: int, activation: str, classes: int, in_size: int, init_type: str):
        super().__init__()

        self.lin = nn.Linear(in_size, hid_l_size)
        self.lin2 = nn.Linear(hid_l_size, classes)
        self.init_type = init_type
        self.init_weights()

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            raise AttributeError('No such activation func')

    def forward(self, x):
        x = self.lin(x)
        x = self.activation(x)
        x = self.lin2(x)
        return x

    def init_weights(self):
        if self.init_type == 'xe':
            nn.init.xavier_uniform_(self.lin.weight)
            nn.init.xavier_uniform_(self.lin2.weight)
        elif self.init_type == 'he':
            nn.init.kaiming_uniform_(self.lin.weight, mode='fan_in', nonlinearity='relu')
            nn.init.kaiming_uniform_(self.lin2.weight, mode='fan_in', nonlinearity='relu')
        else:
            nn.init.normal_(self.lin.weight, mean=0.0, std=0.01)
            nn.init.normal_(self.lin2.weight, mean=0.0, std=0.01)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)


model = Perc(128, 'tanh', 10, 784, 'randint')
opt = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

data_for_graph_1 = []

for epoch in range(30):
    epoch_loss = 0.0
    for X, target in train_loader:
        X = X.view(X.size(0), -1)

        opt.zero_grad()
        res = model(X)
        loss = loss_fn(res, target)
        loss.backward()
        opt.step()

        epoch_loss += loss.item()


    avg_loss = epoch_loss / len(train_loader)
    data_for_graph_1.append(avg_loss)
    print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')
model = Perc(128, 'tanh',10, 784, 'xe')
opt = optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

data_for_graph = []

for epoch in range(30):
    epoch_loss = 0.0
    for X, target in train_loader:
        X = X.view(X.size(0), -1)

        opt.zero_grad()
        res = model(X)
        loss = loss_fn(res, target)
        loss.backward()
        opt.step()

        epoch_loss += loss.item()


    avg_loss = epoch_loss / len(train_loader)
    data_for_graph.append(avg_loss)
    print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')


plt.plot(data_for_graph_1,label = 'randint')
plt.plot(data_for_graph,label = 'xe')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.grid(True)
plt.show()

Статьи

Tags:
Hubs:
+7
Comments0

Articles