Генетический алгоритм против Mamba: новая формула скрытых состояний для нейросетей
Генетический алгоритм против Mamba: новая формула скрытых состояний для нейросетей
Современные State Space Models (SSM), такие как Mamba, отлично справляются с длительной памятью, но сталкиваются с ограничениями в адаптивности и сложности. В этом проекте применён генетический алгоритм для эволюционного поиска новых формул скрытых состояний, которые превосходят классические подходы на 24%. Итог – модель ESSS (Enhanced Selective State Space) с адаптивной многокомпонентной архитектурой.
1. Задача и подход
Вместо ручной настройки формул скрытых состояний, генетический алгоритм «эволюционировал» параметры модели: типы матриц, функции активации, коэффициенты памяти, селективные гейты, межканальные связи и т. д. Популяция из 30 особей прошла 25 поколений, выбирая лучшие конфигурации.
Формула ESSS
Основная формула скрытого состояния ESSS — сумма трёх компонент:
hₜ = α(xₜ, hₜ₋₁) ⊙ Aₑₙₕ ⊙ hₜ₋₁
β(xₜ, hₜ₋₁) ⊙ Bₗₒ𝓌 ⊙ xₜ
γ(h꜀ʀₒₛₛ) ⊙ C꜀ʀₒₛₛ ⊙ Σ꜀ʰₐₙₙₑₗₛ(hₜ₋₁)
где:
α, β, γ — адаптивные коэффициенты, зависящие от входа и предыдущего состояния;
Aenh — улучшенная HiPPO-матрица с обучаемыми дельтами;
Blow — низкоранговая факторизация для входных данных;
Ccross — вес межканальной агрегации для синхронизации состояния по каналам.
3. Оптимальные параметры ESSS
Параметр | Значение |
|---|---|
Тип матрицы A | HiPPO (обучаемые дельты) |
Матрица B | Низкоранговая факторизация |
Активация | GELU |
Затухание памяти | 0.9614 |
Селективный гейт | 0.5760 |
Межканальное взаимодействие | 0.2628 |
Иерархические уровни | 3 |
Временные веса | Адаптивные |
Использование свёртки | Да |
Дельта-адаптация | Да |
4. Генетический алгоритм для поиска параметров (Python)
import numpy as np, random
from typing import List, Tuple
class SSMGenome:
def __init__(self):
options = {
'a_matrix': ['diagonal','low_rank','hippo','learnable'],
'b_matrix': ['standard','low_rank','factorized'],
'activation': ['relu','gelu','swish','silu'],
'temporal': ['fixed','adaptive','learned']
}
self.a_matrix = random.choice(options['a_matrix'])
self.b_matrix = random.choice(options['b_matrix'])
self.activation = random.choice(options['activation'])
self.memory_decay = random.uniform(0.8,0.99)
self.selective_gate = random.uniform(0.3,0.9)
self.cross_channel = random.uniform(0.1,0.5)
self.h_levels = random.randint(2,5)
self.temporal = random.choice(options['temporal'])
self.use_conv = random.choice([True,False])
self.delta_adapt = random.choice([True,False])
self.fitness = 0.0
def mutate(self, rate=0.15):
if random.random()<rate:
self.a_matrix = random.choice(['diagonal','low_rank','hippo','learnable'])
if random.random()<rate:
self.b_matrix = random.choice(['standard','low_rank','factorized'])
if random.random()<rate:
self.activation = random.choice(['relu','gelu','swish','silu'])
if random.random()<rate:
self.memory_decay = np.clip(self.memory_decay+random.gauss(0,0.05),0.8,0.99)
if random.random()<rate:
self.selective_gate = np.clip(self.selective_gate+random.gauss(0,0.1),0.3,0.9)
if random.random()<rate:
self.cross_channel = np.clip(self.cross_channel+random.gauss(0,0.05),0.1,0.5)
if random.random()<rate:
self.h_levels = np.clip(self.h_levels+random.randint(-1,1),2,5)
if random.random()<rate:
self.temporal = random.choice(['fixed','adaptive','learned'])
if random.random()<rate:
self.use_conv = random.choice([True,False])
if random.random()<rate:
self.delta_adapt = random.choice([True,False])
def crossover(p1:SSMGenome,p2:SSMGenome)->Tuple[SSMGenome,SSMGenome]:
c1,c2=SSMGenome(),SSMGenome()
for attr in ['a_matrix','b_matrix','activation','temporal','use_conv','delta_adapt']:
setattr(c1,attr, random.choice([getattr(p1,attr),getattr(p2,attr)]))
setattr(c2,attr, random.choice([getattr(p1,attr),getattr(p2,attr)]))
α=random.random()
for num in ['memory_decay','selective_gate','cross_channel']:
v1,v2=getattr(p1,num),getattr(p2,num)
setattr(c1,num,α*v1+(1-α)*v2)
setattr(c2,num,(1-α)*v1+α*v2)
c1.h_levels=random.choice([p1.h_levels,p2.h_levels])
c2.h_levels=random.choice([p1.h_levels,p2.h_levels])
return c1,c2
def evaluate_fitness(g:SSMGenome)->float:
score=0.0
score += 15 if g.a_matrix=='hippo' else 10 if g.a_matrix=='learnable' else 0
score += 12 if g.b_matrix=='low_rank' else 8 if g.b_matrix=='factorized' else 0
score += 10 if g.activation in ['gelu','swish'] else 0
score += g.memory_decay*20 + g.selective_gate*15 + g.cross_channel*10
level_penalty=abs(g.h_levels-3.5)
score += max(0,5-2*level_penalty)
score += 12 if g.temporal=='adaptive' else 8 if g.temporal=='learned' else 0
score += 5 if g.use_conv else 0
score += 10 if g.delta_adapt else 0
score += random.gauss(0,3)
return max(0,score)
def tournament(pop:List[SSMGenome],k=3)->SSMGenome:
return max(random.sample(pop,k), key=lambda x:x.fitness)
def genetic_algorithm(
pop_size=30, gens=25, mut=0.15, elite=2
)->Tuple[SSMGenome,List[float]]:
pop=[SSMGenome() for _ in range(pop_size)]
for g in pop: g.fitness=evaluate_fitness(g)
history=[]
best=max(pop, key=lambda x:x.fitness)
for gen in range(gens):
pop.sort(key=lambda x:x.fitness, reverse=True)
new=pop[:elite]
while len(new)<pop_size:
p1,p2=tournament(pop),tournament(pop)
c1,c2=crossover(p1,p2)
c1.mutate(mut); c2.mutate(mut)
c1.fitness = evaluate_fitness(c1)
c2.fitness = evaluate_fitness(c2)
new += [c1,c2]
pop=new[:pop_size]
gen_best=max(pop,key=lambda x:x.fitness)
history.append(gen_best.fitness)
if gen_best.fitness>best.fitness: best=gen_best
return best,history
if __name__=="__main__":
best, hist = genetic_algorithm()
print("Лучшие параметры:", vars(best))
5. ESSS модель на PyTorch
import torch, torch.nn as nn, torch.nn.functional as F
class EnhancedSelectiveSSM(nn.Module):
def __init__(self, d_model=768, d_state=64, d_conv=4, expand=2,
memory_decay=0.9614, selective_gate=0.5760,
cross_channel=0.2628, hierarchical_levels=3):
super().__init__()
self.d_state, self.d_inner = d_state, d_model*expand
self.memory_decay, self.selective_gate = memory_decay, selective_gate
self.cross_channel, self.hier_levels = cross_channel, hierarchical_levels
self.in_proj = nn.Linear(d_model, self.d_inner*2, bias=False)
self.conv1d = nn.Conv1d(self.d_inner,self.d_inner,d_conv,
groups=self.d_inner, padding=d_conv-1)
self.x_proj = nn.Linear(self.d_inner, d_state*2 + self.d_inner, bias=False)
A = torch.zeros(d_state)
for n in range(d_state): A[n]=-(2*n+1)
self.A_log = nn.Parameter(torch.log(-A))
rank = max(1,d_state//4)
self.B1 = nn.Parameter(torch.randn(d_state,rank))
self.B2 = nn.Parameter(torch.randn(rank,self.d_inner))
self.cross_w = nn.Parameter(torch.randn(self.d_inner,self.d_inner)*0.02)
self.hierarchy = nn.ModuleList([nn.Linear(self.d_inner,self.d_inner,bias=False)
for _ in range(hierarchical_levels)])
self.temporal = nn.Linear(self.d_inner,3,bias=True)
self.out_proj = nn.Linear(self.d_inner,d_model,bias=False)
self.D = nn.Parameter(torch.ones(self.d_inner))
def forward(self,x, h_prev=None):
B,L,D = *x.shape, x.shape[-1]
xz = self.in_proj(x)
x_inner,z = xz.chunk(2,-1)
x_conv = F.silu(self.conv1d(x_inner.transpose(1,2))[:,:, :L].transpose(1,2))
x_dbl = self.x_proj(x_conv)
delta, B_sel, C_sel = torch.split(x_dbl,[self.d_state,self.d_state,self.d_inner],-1)
alpha,beta,gamma = torch.sigmoid(self.temporal(x_conv)).chunk(3,-1)
delta = F.softplus(delta)
A = -torch.exp(self.A_log)
A_disc = torch.exp(delta.unsqueeze(-1)*A.unsqueeze(0).unsqueeze(0))
B_low = self.B1 @ self.B2
B_disc = delta.unsqueeze(-1) * (B_sel.unsqueeze(-1) @ B_low.unsqueeze(0).unsqueeze(0))
if h_prev is None:
h_prev = torch.zeros(B,self.d_state,self.d_inner,device=x.device)
h = h_prev
h_states=[]
for t in range(L):
mem = alpha[:,t:t+1,:].unsqueeze(-1)*A_disc[:,t:t+1,:].unsqueeze(-1)*h
inp = beta[:,t:t+1,:].unsqueeze(-1)*B_disc[:,t,:,:].unsqueeze(1)*x_conv[:,t:t+1,:].unsqueeze(1)
cross = torch.matmul(h.sum(1,keepdim=True), self.cross_w)
cross_contrib = gamma[:,t:t+1,:].unsqueeze(-1)*self.cross_channel*cross.unsqueeze(1)
h = mem + inp + cross_contrib
h_states.append(h)
h_all = torch.stack(h_states,1)
y = torch.einsum('blsd,ld->bld',h_all,C_sel.mean(1))
for lvl,layer in enumerate(self.hierarchy):
y = y + layer(F.silu(y))*(self.hier_levels-lvl)/self.hier_levels
y = y + x_conv*self.D
y = y * F.silu(z)
return self.out_proj(y), h_all[:,-1]
Итоги
Генетический алгоритм эффективно искал новые формулы и показал улучшение на 24%.
ESSS — адаптивная, быстрая и мощная модель с многомасштабной памятью.
P.S Мои эксперименты и всякие штуки здесь: https://t.me/RevolutionTimellesAI.