
Привет, чемпионы!
Представьте, что у вас есть большой и сложный проект, и вы наняли двух управленцев: Кабан-Кабаныча и Руководителева. Вы даете им одинаковую задачу: набрать штат сотрудников и выполнить ваш проект. Вся прибыль вместе с начальным бюджетом останется у них.
Кабан-Кабаныч решил, что нет смысла платить отдельным специалистам по DevOps, backend, ML и другим направлениям, и нанял всего одного сотрудника за 80 монеток. Этот бедняга работал в стиле «один за всех» и, естественно, быстро выгорел и «умер». Кабан-Кабаныч, не долго думая, нанял еще одного такого же сотрудника. В итоге вы вернулись и увидели печальную картину: задачу никто не решил, остался лишь Кабан-Кабаныч и кладбище несчастных сотрудников.

А вот Руководителев поступил иначе: он распределил бюджет на несколько похожих сотрудников, но сначала не понимал, кто из них в чём лучше. Тогда он стал давать им небольшие задачи и внимательно наблюдать за результатами. Через некоторое время он понял, что сотрудник №1 на 70% лучше справляется с задачами по ML, сотрудник №2 на 80% эффективнее в backend-разработке и так далее. Так Руководителев постепенно сформировал команду экспертов, сам став управляющим (или "gating"-узлом), который распределяет задачи на основе знаний о возможностях каждого сотрудника. Сотрудники углубляли экспертизу в своих направлениях, а Руководителев становился всё эффективнее в распределении задач.
Внезапно мы пришли к интересному решению:
Руководителев — это
gating network
, который распределяет задачи, исходя из предыдущих успехов сотрудников.Сотрудники — это
local experts
, каждый из которых специализируется на своей части задач.

Таким образом, мы экономим ресурсы, получаем сильных специалистов и достигаем отличных результатов за короткое время.
Именно так в 1991 году и появилось решение Adaptive Mixtures of local Experts
Этот подход доказал эффективность, сокращая время обучения моделей почти вдвое.
Как работает MoE?
Представьте модель, у которой есть входные и выходные данные, а между ними набор экспертов. Этих экспертов организует управляющая сеть (gating network
), определяющая, какие эксперты могут лучше справиться с конкретной задачей. Gating-сеть, которая присваивает веса результату каждого эксперта, объединяя их в итоговый ответ.
Звучит красиво, но не всё так просто... Во время обучения возникают интересные и даже «ломающие мозг» ситуации, особенно когда осознаёшь, что созданная тобой модель может «вынести» тебя самого.
Conditional Computation одна из фишек MoE: возможность отключать или частично использовать экспертов. Это позволяет комбинировать разные архитектуры, каждая из которых выявляет уникальные паттерны в данных. Модель становится гибкой: сама решает, каких экспертов задействовать активно, кого игнорировать, а кого подключить чуть-чуть.
Ключевая особенность — разреженность. С помощью MoE можно масштабировать модель без пропорционального увеличения вычислительной нагрузки. Это очень важно, ведь позволяет обучать огромное количество экспертов, используя при этом только нужных. В этом нам помогает важный гиперпараметр — top_k
, определяющий, сколько лучших экспертов будет выбрано для каждого входа.

Но основные сложности начинаются с настройки гиперпараметров и архитектурных решений. Самая большая проблема MoE — это «прилипание гейта», когда маршрутизатор начинает постоянно выбирать одних и тех же экспертов. Эти избранные эксперты получают больше данных и быстрее обучаются, в то время как остальные «скучают и пьют кофе».
Возникает закономерный вопрос: зачем тогда вообще нужны остальные эксперты?
Как с этим бороться? В своём коде я добавил трекер распределения данных по экспертам, чтобы контролировать, не «залип» ли гейт. Также я внедрил несколько хитрых решений, подсмотренных на профессиональных форумах.
Давайте кратко резюмируем:
Технология MoE выгодна за счёт разреженности и гибкости использования экспертов. Однако это «сделка с дьяволом», поскольку возникают сложности:
Сложная балансировка работы экспертов.
Функция потерь должна учитывать как производительность экспертов, так и маршрутизатора.
Количество гиперпараметров (количество экспертов, архитектура gating-сети) усложняет настройку модели.
Где сейчас используют MoE?
Почти все современные LLM используют MoE. Например, недавно вышедшая модель Llama4 Scout с 16x17B параметрами — это 16 экспертов по 17 миллиардов параметров каждый. То есть на инференсе вы используете не все 272 млрд параметров, а только top_k выбранных. Впечатляющее снижение вычислительных затрат, правда?
Также технология активно применяется в компьютерном зрении, и сейчас мы её протестируем на простом примере V-MoEs.
Тест драйв технологии
Итак, для обучения возьмем простенький датасет CIFAR100 и обучим на нем нашу кастомную V-MoEs для классификации изображений.
Сама по себе архитектура будет состоять из следующего:
Классический VIT, но ее часть классификатора мы обернем в decoder блок, где у нас будет применена MOE

Начнем с маршрутизатора, в нашем случае он был реализован следующим образом
import torch
import torch.nn as nn
import torch.nn.functional as F
class GatingNetwork(nn.Module):
def __init__(self,
input_dim = 151296,
num_experts=4,
top_k=2,
use_noise=True,
noise_std=1e-2,
temperature=1.0):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
self.use_noise = use_noise
self.noise_std = noise_std
self.temperature = temperature
self.gate = nn.Sequential(
nn.Linear(input_dim, 512),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, num_experts)
)
def forward(self, x):
logits = self.gate(x) # (B, num_experts)
if self.use_noise and self.training:
scale = logits.std(dim=1, keepdim=True).clamp(min=1e-3)
noise = torch.randn_like(logits) * self.noise_std * scale
logits = logits + noise
topk_vals, topk_indices = torch.topk(logits, self.top_k, dim=1)
gates = F.softmax(topk_vals / self.temperature, dim=1) # (B, top_k)
return topk_indices, gates
Он берёт входной вектор x
, оценивает, какие эксперты из num_experts
лучше подойдут для каждого примера в батче, и возвращает top_k
лучших экспертов с их весами.
То есть это — классическая Gating Network, которая решает, каким экспертам дать поработать с входом.
Обратите внимание, что тут есть noisy gating — это один из способов избежать "залипания гейта" на одном и том же эксперте. Во время тренировки шум масштабируется и в зависимости от поставленной нами пропорции влияет на решение о том какого эксперта повыбирать. Иными словами мы влияем на "результатова", чтобы он давал шансы большему числу экспертов, а не выбирал любимчиков.
Создадим экспертов
import torch.nn as nn
class FFNExpert(nn.Module):
def __init__(self, input_dim, hidden_dims, output_dim, dropout_prob=0.5):
super(FFNExpert, self).__init__()
layers = []
self.linears = nn.ModuleList()
prev_dim = input_dim
for hidden_dim in hidden_dims:
linear = nn.Linear(prev_dim, hidden_dim)
self.linears.append(linear)
layers.append(linear)
layers.append(nn.LayerNorm(hidden_dim))
layers.append(nn.ReLU())
layers.append(nn.Dropout(dropout_prob))
prev_dim = hidden_dim
final_linear = nn.Linear(prev_dim, output_dim)
self.linears.append(final_linear)
layers.append(final_linear)
self.network = nn.Sequential(*layers)
self._initialize_weights()
def _initialize_weights(self):
for linear in self.linears:
nn.init.xavier_uniform_(linear.weight)
if linear.bias is not None:
nn.init.zeros_(linear.bias)
def forward(self, x):
return self.network(x)
class FFNExpertSmall(FFNExpert):
def __init__(self, input_dim, output_dim):
super(FFNExpertSmall, self).__init__(input_dim, hidden_dims=[256, 128], output_dim=output_dim, dropout_prob=0.3)
class FFNExpertMedium(FFNExpert):
def __init__(self, input_dim, output_dim):
super(FFNExpertMedium, self).__init__(input_dim, hidden_dims=[512, 256, 128], output_dim=output_dim,
dropout_prob=0.4)
class FFNExpertLarge(FFNExpert):
def __init__(self, input_dim, output_dim):
super(FFNExpertLarge, self).__init__(input_dim, hidden_dims=[1024, 512, 256, 128], output_dim=output_dim,
dropout_prob=0.5)
class FFNExpertVeryLarge(FFNExpert):
def __init__(self, input_dim, output_dim):
super(FFNExpertVeryLarge, self).__init__(input_dim, hidden_dims=[2048, 1024, 512, 256, 128],
output_dim=output_dim, dropout_prob=0.6)
Тут в целом все просто, мы набросали 4 эксперта с разными параметрами и посмотрим на то как они будут обучаться.
Начнем собирать модель
import torch.nn as nn
import timm
class ViT_backbone(nn.Module):
def __init__(self):
super().__init__()
self.backbone = timm.create_model('vit_base_patch16_224',
pretrained=True)
for param in self.backbone.parameters():
param.requires_grad = False
self.embed_dim = self.backbone.head.in_features
self.backbone.reset_classifier(0)
self.ln = nn.LayerNorm(self.embed_dim)
self.ln2 = nn.LayerNorm(self.embed_dim)
self.attn = nn.MultiheadAttention(embed_dim=self.embed_dim,
num_heads=8,
batch_first=True)
def forward(self, x):
skip = self.backbone.forward_features(x) # [B, N, D]
x_ln = self.ln(skip)
attn_out, _ = self.attn(x_ln, x_ln, x_ln)
x_attn = attn_out + skip
x_final = self.ln2(x_attn).flatten(1) # [B, N*D]
return x_final
Тут все просто возьмем классическую модель VIT и добавим к ней слои нормализации после Multihead Attention и skip connection.
После сделаем наше объединение и наконец-то MOE
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.gating_network import GatingNetwork
from model.Vit_model import ViT_backbone
class MoECNN(nn.Module):
def __init__(self,
experts,
input_for_gating = 151296,
top_k=2,
output_dim=100,
use_aux_loss=True,
aux_loss_weight=0.01,
warmup_iters=500,
noise_std = 0.5):
super().__init__()
self.num_experts = len(experts)
self.top_k = top_k
self.output_dim = output_dim
self.use_aux_loss = use_aux_loss
self.aux_loss_weight = aux_loss_weight
self.warmup_iters = warmup_iters
self.iter = 0
self.backbone = ViT_backbone()
self.experts = nn.ModuleList(experts)
self.gating = GatingNetwork( input_dim = input_for_gating,
num_experts=self.num_experts,
top_k = top_k,
noise_std=noise_std)
self.register_buffer("expert_usage",
torch.zeros(self.num_experts))
def forward(self, x):
batch_size = x.size(0)
device = x.device
x = self.backbone(x)
if self.training and self.iter < self.warmup_iters:
random_indices = torch.randint(0,
self.num_experts,
(batch_size, self.top_k),
device=device)
gates = torch.full((batch_size, self.top_k),
1.0 / self.top_k,
device=device)
topk_indices = random_indices
self.iter += 1
else:
topk_indices, gates = self.gating(x)
output = torch.zeros(batch_size, self.output_dim, device=device)
self.expert_usage.zero_()
for i in range(self.top_k):
idx = topk_indices[:, i]
for expert_idx in torch.unique(idx):
expert_mask = (idx == expert_idx)
if expert_mask.sum() == 0:
continue
x_sel = x[expert_mask]
y_sel = self.experts[expert_idx](x_sel)
gate_weight = gates[expert_mask, i].unsqueeze(1)
output[expert_mask] += gate_weight * y_sel
self.expert_usage[expert_idx] += expert_mask.sum()
aux_loss = None
if self.use_aux_loss and self.training:
usage = self.expert_usage / batch_size
aux_loss = ((usage - usage.mean()) ** 2).mean() * self.aux_loss_weight
return output, aux_loss
Первое на что обратим внимание это это warmup_iters
. Тут у нас это число итераций где мы как-бы отключаем gating-сеть , чтобы избежать коллапса распределения (один эксперт выбирается чаще остальных до того, как сеть обучится разумно маршрутизировать входы). Это дает нам "разогреть" экспертов передавая им равномерно данные и далее мы уже начинаем более тонко избирать экспертов за счет gating network.
Второй момент это добавление use_aux_loss
. Данный параметр позволяет нам учитывать в общем лоссе неравномерное распределение по экспертам в общий loss.
Как итог модель выбирает tok_k
экспертов и на основе их предсказаний делает взвешенную сумму, после чего выдает результат и loss по распределению.
Что в итоге?
При простом "на коленке" мы смогли получить f1 на тесте 89.%. Более явно поиграв с гипперпараметрами, типами экспертов и некоторыми изощренностями думаю, что можно получить результат лучше. Самое главное, что

Давайте теперь проведем модель через один батч и посмотрим, что там на одном батче, что произошло по графику и посмотрим на первые 10 сэмплов батча.

Как можем увидеть, у нас 2 эксперт оказался в данной итерации не востребован, а использовали мы с 0,1,3 эксперта в разной пропорции.
Можно сказать, что вот: "второй эксперт переобучился или обучился плохо". Однако давайте глянем глубже! Мы ведь отслеживаем все через clearml :)

На тестовом датасете в среднем можно заметить, что все эксперты примерно вышли на какую-то свою зону ответственности. Хотя конечно от первого эксперта хотелось ожидать побольше!
Теперь давайте посмотрим на визуализацию результатов:

Несмотря на шакальность(мы работаем с CIFAR100 напоминаю) мы получили весьма неплохие результаты. И теперь вишенка на торте - это отслеживание по экспертам. Их собственно говоря мы итак логируем и сейчас можем провести на маленьком сэмпле аналитику. Если у нас есть очень большой эксперт и он не пригодился в использовании в вычислениях, то мы можем сэкономить очень много памяти.

Подводя итоги основным концептом было показать какие проблемы бывают и сложности при работе с технологией, а также ее возможности и потенциал, который уже сейчас очень успешно реализуется!
🔥 Ставьте лайк и напишите какие темы было бы интересно разобрать дальше! Самое главное — пробуйте и экспериментируйте!
✔️ Присоединяйтесь к нашему Telegram-сообществу @datafeeling, чтобы первыми применять на практике передовые технологии!