Генетический алгоритм против Mamba: новая формула скрытых состояний для нейросетей

Современные State Space Models (SSM), такие как Mamba, отлично справляются с длительной памятью, но сталкиваются с ограничениями в адаптивности и сложности. В этом проекте применён генетический алгоритм для эволюционного поиска новых формул скрытых состояний, которые превосходят классические подходы на 24%. Итог – модель ESSS (Enhanced Selective State Space) с адаптивной многокомпонентной архитектурой.

1. Задача и подход

Вместо ручной настройки формул скрытых состояний, генетический алгоритм «эволюционировал» параметры модели: типы матриц, функции активации, коэффициенты памяти, селективные гейты, межканальные связи и т. д. Популяция из 30 особей прошла 25 поколений, выбирая лучшие конфигурации.

  1. Формула ESSS

Основная формула скрытого состояния ESSS — сумма трёх компонент:

hₜ = α(xₜ, hₜ₋₁) ⊙ Aₑₙₕ ⊙ hₜ₋₁

  • β(xₜ, hₜ₋₁) ⊙ Bₗₒ𝓌 ⊙ xₜ

  • γ(h꜀ʀₒₛₛ) ⊙ C꜀ʀₒₛₛ ⊙ Σ꜀ʰₐₙₙₑₗₛ(hₜ₋₁)

где:

α, β, γ — адаптивные коэффициенты, зависящие от входа и предыдущего состояния;

Aenh — улучшенная HiPPO-матрица с обучаемыми дельтами;

Blow — низкоранговая факторизация для входных данных;

Ccross — вес межканальной агрегации для синхронизации состояния по каналам.

3. Оптимальные параметры ESSS

Параметр

Значение

Тип матрицы A

HiPPO (обучаемые дельты)

Матрица B

Низкоранговая факторизация

Активация

GELU

Затухание памяти

0.9614

Селективный гейт

0.5760

Межканальное взаимодействие

0.2628

Иерархические уровни

3

Временные веса

Адаптивные

Использование свёртки

Да

Дельта-адаптация

Да

4. Генетический алгоритм для поиска параметров (Python)

import numpy as np, random
from typing import List, Tuple

class SSMGenome:
    def __init__(self):
        options = {
            'a_matrix': ['diagonal','low_rank','hippo','learnable'],
            'b_matrix': ['standard','low_rank','factorized'],
            'activation': ['relu','gelu','swish','silu'],
            'temporal': ['fixed','adaptive','learned']
        }
        self.a_matrix = random.choice(options['a_matrix'])
        self.b_matrix = random.choice(options['b_matrix'])
        self.activation = random.choice(options['activation'])
        self.memory_decay = random.uniform(0.8,0.99)
        self.selective_gate = random.uniform(0.3,0.9)
        self.cross_channel = random.uniform(0.1,0.5)
        self.h_levels = random.randint(2,5)
        self.temporal = random.choice(options['temporal'])
        self.use_conv = random.choice([True,False])
        self.delta_adapt = random.choice([True,False])
        self.fitness = 0.0

    def mutate(self, rate=0.15):
        if random.random()<rate:
            self.a_matrix = random.choice(['diagonal','low_rank','hippo','learnable'])
        if random.random()<rate:
            self.b_matrix = random.choice(['standard','low_rank','factorized'])
        if random.random()<rate:
            self.activation = random.choice(['relu','gelu','swish','silu'])
        if random.random()<rate:
            self.memory_decay = np.clip(self.memory_decay+random.gauss(0,0.05),0.8,0.99)
        if random.random()<rate:
            self.selective_gate = np.clip(self.selective_gate+random.gauss(0,0.1),0.3,0.9)
        if random.random()<rate:
            self.cross_channel = np.clip(self.cross_channel+random.gauss(0,0.05),0.1,0.5)
        if random.random()<rate:
            self.h_levels = np.clip(self.h_levels+random.randint(-1,1),2,5)
        if random.random()<rate:
            self.temporal = random.choice(['fixed','adaptive','learned'])
        if random.random()<rate:
            self.use_conv = random.choice([True,False])
        if random.random()<rate:
            self.delta_adapt = random.choice([True,False])

def crossover(p1:SSMGenome,p2:SSMGenome)->Tuple[SSMGenome,SSMGenome]:
    c1,c2=SSMGenome(),SSMGenome()
    for attr in ['a_matrix','b_matrix','activation','temporal','use_conv','delta_adapt']:
        setattr(c1,attr, random.choice([getattr(p1,attr),getattr(p2,attr)]))
        setattr(c2,attr, random.choice([getattr(p1,attr),getattr(p2,attr)]))
    α=random.random()
    for num in ['memory_decay','selective_gate','cross_channel']:
        v1,v2=getattr(p1,num),getattr(p2,num)
        setattr(c1,num,α*v1+(1-α)*v2)
        setattr(c2,num,(1-α)*v1+α*v2)
    c1.h_levels=random.choice([p1.h_levels,p2.h_levels])
    c2.h_levels=random.choice([p1.h_levels,p2.h_levels])
    return c1,c2

def evaluate_fitness(g:SSMGenome)->float:
    score=0.0
    score += 15 if g.a_matrix=='hippo' else 10 if g.a_matrix=='learnable' else 0
    score += 12 if g.b_matrix=='low_rank' else 8 if g.b_matrix=='factorized' else 0
    score += 10 if g.activation in ['gelu','swish'] else 0
    score += g.memory_decay*20 + g.selective_gate*15 + g.cross_channel*10
    level_penalty=abs(g.h_levels-3.5)
    score += max(0,5-2*level_penalty)
    score += 12 if g.temporal=='adaptive' else 8 if g.temporal=='learned' else 0
    score += 5 if g.use_conv else 0
    score += 10 if g.delta_adapt else 0
    score += random.gauss(0,3)
    return max(0,score)

def tournament(pop:List[SSMGenome],k=3)->SSMGenome:
    return max(random.sample(pop,k), key=lambda x:x.fitness)

def genetic_algorithm(
    pop_size=30, gens=25, mut=0.15, elite=2
)->Tuple[SSMGenome,List[float]]:
    pop=[SSMGenome() for _ in range(pop_size)]
    for g in pop: g.fitness=evaluate_fitness(g)
    history=[]
    best=max(pop, key=lambda x:x.fitness)
    for gen in range(gens):
        pop.sort(key=lambda x:x.fitness, reverse=True)
        new=pop[:elite]
        while len(new)<pop_size:
            p1,p2=tournament(pop),tournament(pop)
            c1,c2=crossover(p1,p2)
            c1.mutate(mut); c2.mutate(mut)
            c1.fitness = evaluate_fitness(c1)
            c2.fitness = evaluate_fitness(c2)
            new += [c1,c2]
        pop=new[:pop_size]
        gen_best=max(pop,key=lambda x:x.fitness)
        history.append(gen_best.fitness)
        if gen_best.fitness>best.fitness: best=gen_best
    return best,history

if __name__=="__main__":
    best, hist = genetic_algorithm()
    print("Лучшие параметры:", vars(best))

5. ESSS модель на PyTorch

import torch, torch.nn as nn, torch.nn.functional as F

class EnhancedSelectiveSSM(nn.Module):
    def __init__(self, d_model=768, d_state=64, d_conv=4, expand=2,
                 memory_decay=0.9614, selective_gate=0.5760,
                 cross_channel=0.2628, hierarchical_levels=3):
        super().__init__()
        self.d_state, self.d_inner = d_state, d_model*expand
        self.memory_decay, self.selective_gate = memory_decay, selective_gate
        self.cross_channel, self.hier_levels = cross_channel, hierarchical_levels

        self.in_proj = nn.Linear(d_model, self.d_inner*2, bias=False)
        self.conv1d = nn.Conv1d(self.d_inner,self.d_inner,d_conv,
                               groups=self.d_inner, padding=d_conv-1)
        self.x_proj = nn.Linear(self.d_inner, d_state*2 + self.d_inner, bias=False)

        A = torch.zeros(d_state)
        for n in range(d_state): A[n]=-(2*n+1)
        self.A_log = nn.Parameter(torch.log(-A))

        rank = max(1,d_state//4)
        self.B1 = nn.Parameter(torch.randn(d_state,rank))
        self.B2 = nn.Parameter(torch.randn(rank,self.d_inner))

        self.cross_w = nn.Parameter(torch.randn(self.d_inner,self.d_inner)*0.02)
        self.hierarchy = nn.ModuleList([nn.Linear(self.d_inner,self.d_inner,bias=False)
                                        for _ in range(hierarchical_levels)])
        self.temporal = nn.Linear(self.d_inner,3,bias=True)
        self.out_proj = nn.Linear(self.d_inner,d_model,bias=False)
        self.D = nn.Parameter(torch.ones(self.d_inner))

    def forward(self,x, h_prev=None):
        B,L,D = *x.shape, x.shape[-1]
        xz = self.in_proj(x)
        x_inner,z = xz.chunk(2,-1)
        x_conv = F.silu(self.conv1d(x_inner.transpose(1,2))[:,:, :L].transpose(1,2))
        x_dbl = self.x_proj(x_conv)
        delta, B_sel, C_sel = torch.split(x_dbl,[self.d_state,self.d_state,self.d_inner],-1)
        alpha,beta,gamma = torch.sigmoid(self.temporal(x_conv)).chunk(3,-1)
        delta = F.softplus(delta)

        A = -torch.exp(self.A_log)
        A_disc = torch.exp(delta.unsqueeze(-1)*A.unsqueeze(0).unsqueeze(0))
        B_low = self.B1 @ self.B2
        B_disc = delta.unsqueeze(-1) * (B_sel.unsqueeze(-1) @ B_low.unsqueeze(0).unsqueeze(0))

        if h_prev is None:
            h_prev = torch.zeros(B,self.d_state,self.d_inner,device=x.device)
        h = h_prev
        h_states=[]
        for t in range(L):
            mem = alpha[:,t:t+1,:].unsqueeze(-1)*A_disc[:,t:t+1,:].unsqueeze(-1)*h
            inp = beta[:,t:t+1,:].unsqueeze(-1)*B_disc[:,t,:,:].unsqueeze(1)*x_conv[:,t:t+1,:].unsqueeze(1)
            cross = torch.matmul(h.sum(1,keepdim=True), self.cross_w)
            cross_contrib = gamma[:,t:t+1,:].unsqueeze(-1)*self.cross_channel*cross.unsqueeze(1)
            h = mem + inp + cross_contrib
            h_states.append(h)
        h_all = torch.stack(h_states,1)
        y = torch.einsum('blsd,ld->bld',h_all,C_sel.mean(1))
        for lvl,layer in enumerate(self.hierarchy):
            y = y + layer(F.silu(y))*(self.hier_levels-lvl)/self.hier_levels
        y = y + x_conv*self.D
        y = y * F.silu(z)
        return self.out_proj(y), h_all[:,-1]

Итоги

  • Генетический алгоритм эффективно искал новые формулы и показал улучшение на 24%.

  • ESSS — адаптивная, быстрая и мощная модель с многомасштабной памятью.


P.S Мои эксперименты и всякие штуки здесь: https://t.me/RevolutionTimellesAI.