Как стать автором
Обновить
84.71

Mixture of Experts: когда нейросеть учится делегировать

Уровень сложностиСредний
Время на прочтение8 мин
Количество просмотров918

Привет, чемпионы!

Представьте, что у вас есть большой и сложный проект, и вы наняли двух управленцев: Кабан-Кабаныча и Руководителева. Вы даете им одинаковую задачу: набрать штат сотрудников и выполнить ваш проект. Вся прибыль вместе с начальным бюджетом останется у них.

Кабан-Кабаныч решил, что нет смысла платить отдельным специалистам по DevOps, backend, ML и другим направлениям, и нанял всего одного сотрудника за 80 монеток. Этот бедняга работал в стиле «один за всех» и, естественно, быстро выгорел и «умер». Кабан-Кабаныч, не долго думая, нанял еще одного такого же сотрудника. В итоге вы вернулись и увидели печальную картину: задачу никто не решил, остался лишь Кабан-Кабаныч и кладбище несчастных сотрудников.

А вот Руководителев поступил иначе: он распределил бюджет на несколько похожих сотрудников, но сначала не понимал, кто из них в чём лучше. Тогда он стал давать им небольшие задачи и внимательно наблюдать за результатами. Через некоторое время он понял, что сотрудник №1 на 70% лучше справляется с задачами по ML, сотрудник №2 на 80% эффективнее в backend-разработке и так далее. Так Руководителев постепенно сформировал команду экспертов, сам став управляющим (или "gating"-узлом), который распределяет задачи на основе знаний о возможностях каждого сотрудника. Сотрудники углубляли экспертизу в своих направлениях, а Руководителев становился всё эффективнее в распределении задач.

Внезапно мы пришли к интересному решению:

  • Руководителев — это gating network, который распределяет задачи, исходя из предыдущих успехов сотрудников.

  • Сотрудники — это local experts, каждый из которых специализируется на своей части задач.

Таким образом, мы экономим ресурсы, получаем сильных специалистов и достигаем отличных результатов за короткое время.

Именно так в 1991 году и появилось решение Adaptive Mixtures of local Experts

Этот подход доказал эффективность, сокращая время обучения моделей почти вдвое.

Как работает MoE?

Представьте модель, у которой есть входные и выходные данные, а между ними набор экспертов. Этих экспертов организует управляющая сеть (gating network), определяющая, какие эксперты могут лучше справиться с конкретной задачей. Gating-сеть, которая присваивает веса результату каждого эксперта, объединяя их в итоговый ответ.

Звучит красиво, но не всё так просто... Во время обучения возникают интересные и даже «ломающие мозг» ситуации, особенно когда осознаёшь, что созданная тобой модель может «вынести» тебя самого.

Conditional Computation одна из фишек MoE: возможность отключать или частично использовать экспертов. Это позволяет комбинировать разные архитектуры, каждая из которых выявляет уникальные паттерны в данных. Модель становится гибкой: сама решает, каких экспертов задействовать активно, кого игнорировать, а кого подключить чуть-чуть.

Ключевая особенность — разреженность. С помощью MoE можно масштабировать модель без пропорционального увеличения вычислительной нагрузки. Это очень важно, ведь позволяет обучать огромное количество экспертов, используя при этом только нужных. В этом нам помогает важный гиперпараметр — top_k, определяющий, сколько лучших экспертов будет выбрано для каждого входа.

Но основные сложности начинаются с настройки гиперпараметров и архитектурных решений. Самая большая проблема MoE — это «прилипание гейта», когда маршрутизатор начинает постоянно выбирать одних и тех же экспертов. Эти избранные эксперты получают больше данных и быстрее обучаются, в то время как остальные «скучают и пьют кофе».

Возникает закономерный вопрос: зачем тогда вообще нужны остальные эксперты?

Как с этим бороться? В своём коде я добавил трекер распределения данных по экспертам, чтобы контролировать, не «залип» ли гейт. Также я внедрил несколько хитрых решений, подсмотренных на профессиональных форумах.

Давайте кратко резюмируем:

Технология MoE выгодна за счёт разреженности и гибкости использования экспертов. Однако это «сделка с дьяволом», поскольку возникают сложности:

  • Сложная балансировка работы экспертов.

  • Функция потерь должна учитывать как производительность экспертов, так и маршрутизатора.

  • Количество гиперпараметров (количество экспертов, архитектура gating-сети) усложняет настройку модели.

Где сейчас используют MoE?

Почти все современные LLM используют MoE. Например, недавно вышедшая модель Llama4 Scout с 16x17B параметрами — это 16 экспертов по 17 миллиардов параметров каждый. То есть на инференсе вы используете не все 272 млрд параметров, а только top_k выбранных. Впечатляющее снижение вычислительных затрат, правда?

Также технология активно применяется в компьютерном зрении, и сейчас мы её протестируем на простом примере V-MoEs.

Тест драйв технологии

Итак, для обучения возьмем простенький датасет CIFAR100 и обучим на нем нашу кастомную V-MoEs для классификации изображений.

Сама по себе архитектура будет состоять из следующего:

Классический VIT, но ее часть классификатора мы обернем в decoder блок, где у нас будет применена MOE

Начнем с маршрутизатора, в нашем случае он был реализован следующим образом

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingNetwork(nn.Module):
    def __init__(self,
                 input_dim = 151296,
                 num_experts=4,
                 top_k=2,
                 use_noise=True,
                 noise_std=1e-2,
                 temperature=1.0):
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k
        self.use_noise = use_noise
        self.noise_std = noise_std
        self.temperature = temperature

        self.gate = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_experts)
        )

    def forward(self, x):
        logits = self.gate(x)  # (B, num_experts)

        if self.use_noise and self.training:
            scale = logits.std(dim=1, keepdim=True).clamp(min=1e-3)  
            noise = torch.randn_like(logits) * self.noise_std * scale
            logits = logits + noise

        topk_vals, topk_indices = torch.topk(logits, self.top_k, dim=1)

        gates = F.softmax(topk_vals / self.temperature, dim=1)  # (B, top_k)

        return topk_indices, gates

Он берёт входной вектор x, оценивает, какие эксперты из num_experts лучше подойдут для каждого примера в батче, и возвращает top_k лучших экспертов с их весами.

То есть это — классическая Gating Network, которая решает, каким экспертам дать поработать с входом.

Обратите внимание, что тут есть noisy gating — это один из способов избежать "залипания гейта" на одном и том же эксперте. Во время тренировки шум масштабируется и в зависимости от поставленной нами пропорции влияет на решение о том какого эксперта повыбирать. Иными словами мы влияем на "результатова", чтобы он давал шансы большему числу экспертов, а не выбирал любимчиков.

Создадим экспертов

import torch.nn as nn


class FFNExpert(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_prob=0.5):
        super(FFNExpert, self).__init__()

        layers = []
        self.linears = nn.ModuleList()

        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            linear = nn.Linear(prev_dim, hidden_dim)
            self.linears.append(linear)
            layers.append(linear)
            layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_prob))
            prev_dim = hidden_dim

        final_linear = nn.Linear(prev_dim, output_dim)
        self.linears.append(final_linear)
        layers.append(final_linear)

        self.network = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        for linear in self.linears:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, x):
        return self.network(x)


class FFNExpertSmall(FFNExpert):
    def __init__(self, input_dim, output_dim):
        super(FFNExpertSmall, self).__init__(input_dim, hidden_dims=[256, 128], output_dim=output_dim, dropout_prob=0.3)


class FFNExpertMedium(FFNExpert):
    def __init__(self, input_dim, output_dim):
        super(FFNExpertMedium, self).__init__(input_dim, hidden_dims=[512, 256, 128], output_dim=output_dim,
                                              dropout_prob=0.4)


class FFNExpertLarge(FFNExpert):
    def __init__(self, input_dim, output_dim):
        super(FFNExpertLarge, self).__init__(input_dim, hidden_dims=[1024, 512, 256, 128], output_dim=output_dim,
                                             dropout_prob=0.5)


class FFNExpertVeryLarge(FFNExpert):
    def __init__(self, input_dim, output_dim):
        super(FFNExpertVeryLarge, self).__init__(input_dim, hidden_dims=[2048, 1024, 512, 256, 128],
                                                 output_dim=output_dim, dropout_prob=0.6)

Тут в целом все просто, мы набросали 4 эксперта с разными параметрами и посмотрим на то как они будут обучаться.

Начнем собирать модель

import torch.nn as nn
import timm

class ViT_backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model('vit_base_patch16_224',
                                          pretrained=True)

        for param in self.backbone.parameters():
            param.requires_grad = False

        self.embed_dim = self.backbone.head.in_features

        self.backbone.reset_classifier(0)

        self.ln = nn.LayerNorm(self.embed_dim)
        self.ln2 = nn.LayerNorm(self.embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=self.embed_dim,
                                          num_heads=8,
                                          batch_first=True)

    def forward(self, x):
        skip = self.backbone.forward_features(x)  # [B, N, D]
        x_ln = self.ln(skip)
        attn_out, _ = self.attn(x_ln, x_ln, x_ln)
        x_attn = attn_out + skip
        x_final = self.ln2(x_attn).flatten(1)  # [B, N*D]
        return x_final

Тут все просто возьмем классическую модель VIT и добавим к ней слои нормализации после Multihead Attention и skip connection.

После сделаем наше объединение и наконец-то MOE

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.gating_network import GatingNetwork
from model.Vit_model import ViT_backbone

class MoECNN(nn.Module):
    def __init__(self,
                 experts,
                 input_for_gating = 151296,
                 top_k=2,
                 output_dim=100,
                 use_aux_loss=True,
                 aux_loss_weight=0.01,
                 warmup_iters=500,
                 noise_std = 0.5):
        super().__init__()
        self.num_experts = len(experts)
        self.top_k = top_k
        self.output_dim = output_dim
        self.use_aux_loss = use_aux_loss
        self.aux_loss_weight = aux_loss_weight
        self.warmup_iters = warmup_iters
        self.iter = 0

        self.backbone = ViT_backbone()
        self.experts = nn.ModuleList(experts)
        self.gating = GatingNetwork( input_dim = input_for_gating,
                                     num_experts=self.num_experts,
                                     top_k = top_k,
                                     noise_std=noise_std)

        self.register_buffer("expert_usage",
                             torch.zeros(self.num_experts))

    def forward(self, x):
        batch_size = x.size(0)
        device = x.device
        x = self.backbone(x)

        if self.training and self.iter < self.warmup_iters:
            random_indices = torch.randint(0,
                                           self.num_experts,
                                           (batch_size, self.top_k),
                                           device=device)
            gates = torch.full((batch_size, self.top_k),
                               1.0 / self.top_k,
                               device=device)
            topk_indices = random_indices
            self.iter += 1
        else:
            topk_indices, gates = self.gating(x)

        output = torch.zeros(batch_size, self.output_dim, device=device)
        self.expert_usage.zero_()

        for i in range(self.top_k):
            idx = topk_indices[:, i]
            for expert_idx in torch.unique(idx):
                expert_mask = (idx == expert_idx)
                if expert_mask.sum() == 0:
                    continue
                x_sel = x[expert_mask]
                y_sel = self.experts[expert_idx](x_sel)
                gate_weight = gates[expert_mask, i].unsqueeze(1)
                output[expert_mask] += gate_weight * y_sel

                self.expert_usage[expert_idx] += expert_mask.sum()

        aux_loss = None
        if self.use_aux_loss and self.training:
            usage = self.expert_usage / batch_size
            aux_loss = ((usage - usage.mean()) ** 2).mean() * self.aux_loss_weight

        return output, aux_loss

Первое на что обратим внимание это это warmup_iters. Тут у нас это число итераций где мы как-бы отключаем gating-сеть , чтобы избежать коллапса распределения (один эксперт выбирается чаще остальных до того, как сеть обучится разумно маршрутизировать входы). Это дает нам "разогреть" экспертов передавая им равномерно данные и далее мы уже начинаем более тонко избирать экспертов за счет gating network.

Второй момент это добавление use_aux_loss. Данный параметр позволяет нам учитывать в общем лоссе неравномерное распределение по экспертам в общий loss.

Как итог модель выбирает tok_k экспертов и на основе их предсказаний делает взвешенную сумму, после чего выдает результат и loss по распределению.

Что в итоге?

При простом "на коленке" мы смогли получить f1 на тесте 89.%. Более явно поиграв с гипперпараметрами, типами экспертов и некоторыми изощренностями думаю, что можно получить результат лучше. Самое главное, что

Давайте теперь проведем модель через один батч и посмотрим, что там на одном батче, что произошло по графику и посмотрим на первые 10 сэмплов батча.

Как можем увидеть, у нас 2 эксперт оказался в данной итерации не востребован, а использовали мы с 0,1,3 эксперта в разной пропорции.

Можно сказать, что вот: "второй эксперт переобучился или обучился плохо". Однако давайте глянем глубже! Мы ведь отслеживаем все через clearml :)

На тестовом датасете в среднем можно заметить, что все эксперты примерно вышли на какую-то свою зону ответственности. Хотя конечно от первого эксперта хотелось ожидать побольше!

Теперь давайте посмотрим на визуализацию результатов:

Несмотря на шакальность(мы работаем с CIFAR100 напоминаю) мы получили весьма неплохие результаты. И теперь вишенка на торте - это отслеживание по экспертам. Их собственно говоря мы итак логируем и сейчас можем провести на маленьком сэмпле аналитику. Если у нас есть очень большой эксперт и он не пригодился в использовании в вычислениях, то мы можем сэкономить очень много памяти.

Подводя итоги основным концептом было показать какие проблемы бывают и сложности при работе с технологией, а также ее возможности и потенциал, который уже сейчас очень успешно реализуется!

🔥 Ставьте лайк и напишите какие темы было бы интересно разобрать дальше! Самое главное — пробуйте и экспериментируйте!

✔️ Присоединяйтесь к нашему Telegram-сообществу @datafeeling, чтобы первыми применять на практике передовые технологии!

Теги:
Хабы:
Если эта публикация вас вдохновила и вы хотите поддержать автора — не стесняйтесь нажать на кнопку
+7
Комментарии2

Публикации

Информация

Сайт
t.me
Дата регистрации
Дата основания
Численность
2–10 человек
Местоположение
Россия