Pull to refresh

Mamba. От начала до конца

Level of difficultyMedium
Reading time5 min
Views26K
Во времена повсеместного заполонения трансформерами, которые пожирали в себя все больше и больше кремниевых чипов; когда казалось, что лучше уже не будет и за каждый новый токен нужно платить в квадрате от предыдущих, в эту холодную зимнюю пору появилась она - Мамба.
Во времена повсеместного заполонения трансформерами, которые пожирали в себя все больше и больше кремниевых чипов; когда казалось, что лучше уже не будет и за каждый новый токен нужно платить в квадрате от предыдущих, в эту холодную зимнюю пору появилась она - Мамба.

Трансформеры произвели настоящий фурор в области Deep Learning и демонстрируют выдающуюся эффективность. Однако у них существует серьезное ограничение по длине входной последовательности (контекста) из-за квадратичной вычислительной сложности. Большинство моделей работают с контекстом длиной менее 10 000, что делает их малоприменимыми в задачах с большими объемами входных данных. И хотя ходили различные слухи, было бы странно увидеть сильный искусственный интеллект, который можно за пару минут заболтать до беспамятства.

Вычислительная сложность

Длина контекста L

Пропускная способность

Трансформер

L^2

\sim 10^3

x

Мамба

L

\sim 10^6

5x

Мамба основывается на принципиально другом подходе - SSM, который, хоть и сильно старше трансформера, в контексте глубокого обучения не показывал достаточной эффективности, особенно в качестве языковой модели. Мамба имеет линейную вычислительную зависимость и в 5 раз выше пропускную способность, чем у трансформеров. Авторы проверили свое детище на серии моделей только до 2.8 млрд. параметров, что еще мало похоже на Chatgpt, но уже утерли нос текущим топам языковых моделей в своей весовой категории. Длина контекста при этом была выбрана как у соответствующего трансформера, так что контекст размером в миллион был проверен только на простых синтетических тестах, что, однако, тоже немаловажно, так как ни трансформеры, ни свертки с этими тестами не справились. В этой статье мы детально рассмотрим всю математику новой архитектуры, заметая под ковер преимущества и недостатки.

Линейная модель пространства состояний (SSM)

Непрерывный случай

Модель пространства состояний, на которой построена вся идея, в непрерывном виде выглядит так:

\begin{array}{lcl} \boldsymbol{\dot h}(t) = \boldsymbol{Ah}(t)+\boldsymbol{Bx}(t)\\ \boldsymbol{y}(t) = \boldsymbol{Ch}(t)+\boldsymbol{Dx}(t) \end{array}

С помощью такой модели можно записать дифференциальное уравнение N-го порядка как N уравнений первого порядка в матричном виде, где \boldsymbol{h}(t) - вектор состояния, содержащий производные по возрастанию порядка от 0 до N-1, x(t) - входной сигнал, y(t) - выходной сигнал. Таким образом, сложность описываемой системы нелинейно растет от N.

По-другому на модель можно смотреть так - одномерный сигнал x(t) отображается в N- мерное латентное состояние \boldsymbol{h}(t), а затем проецируется в выходной сигнал y(t).

Чтобы взаимодействовать с моделью в векторном виде конечной размерности, нужно дискретизировать ее.

Дискретизация

Второе уравнение дискретизируется просто: \boldsymbol{y_k} = \boldsymbol{y}(k\Delta) \rightarrow \boldsymbol{y_k} = \boldsymbol{Ch_k}+\boldsymbol{Dx_k}. Рассмотрим подробнее для первого:

Решение в общем виде

Умножим исходное уравнение на e^{-\boldsymbol{A}t}:

e^{-\boldsymbol{A}t} \boldsymbol{\dot h}(t) - e^{-\boldsymbol{A}t} \boldsymbol{Ah}(t) = e^{-\boldsymbol{A}t} \boldsymbol{Bx}(t)\\ \frac{d}{dt} (e^{-\boldsymbol{A}t} \boldsymbol{h}(t)) = e^{-\boldsymbol{A}t} \boldsymbol{Bx}(t)

Тогда общее решение для непрерывной модели будет:

\boldsymbol{h}(t) = e^{\boldsymbol{A}t} \boldsymbol{h}(0) + \int_0^t{e^{\boldsymbol{A}(t-\tau)}} \boldsymbol{Bx}(\tau) d\tau

Шаг дискретизации \Delta \Rightarrow \boldsymbol{x_k} = \boldsymbol{x}(k \Delta), \boldsymbol{h_k} = \boldsymbol{h}(k \Delta):

\boldsymbol{h_k} = e^{\boldsymbol{A}k\Delta} \boldsymbol{h}(0) + \int_0^{k\Delta}{e^{\boldsymbol{A}(k\Delta-\tau)}} \boldsymbol{Bx}(\tau) d\tau \boldsymbol{h_{k+1}} = e^{\boldsymbol{A}\Delta}  \left[ e^{\boldsymbol{A}k\Delta} \boldsymbol{h}(0)+ \int_0^{k\Delta}{e^{\boldsymbol{A}(k\Delta-\tau)}} \boldsymbol{Bx}(\tau) d\tau\right] + \int_{k\Delta}^{(k+1)\Delta} e^{\boldsymbol{A}\left[(k+1)\Delta-\tau\right]} \boldsymbol{Bx}(\tau) d\tau =

Подставляем выражение для \boldsymbol{h_k} и учитываем, что \boldsymbol{x}=const внутри интервала Δ:

= e^{\boldsymbol{A}\Delta} \boldsymbol{h_k} + \left[ \int_0^{\Delta} e^{\boldsymbol{A}\nu} d\nu \right] \boldsymbol{Bx_k} =  e^{\boldsymbol{A}\Delta} \boldsymbol{h_k} + \frac{1}{\boldsymbol{A}} (e^{\boldsymbol{A}\Delta}-\boldsymbol{I})\boldsymbol{Bx_k}

Альтернативный способ

Запишем уравнение сразу в дискретном виде \Delta \Rightarrow \boldsymbol{x_k} = \boldsymbol{x}(k \Delta), \boldsymbol{h_k} = \boldsymbol{h}(k \Delta):

\dfrac{\boldsymbol{h_{k+1}}-\boldsymbol{h_k}}{\Delta} = \boldsymbol{Ah_k}+\boldsymbol{Bx_k}\\[3ex] \boldsymbol{h_{k+1}} = (\boldsymbol{I}+\boldsymbol{A}\Delta )\boldsymbol{h_k}+\Delta\boldsymbol{Bx_k}

В первом приближении e^{\boldsymbol{A}\Delta } \approx \boldsymbol{I} +  \boldsymbol{A}\Delta или \dfrac{1}{\boldsymbol{A}}(e^{\boldsymbol{A}\Delta }-\boldsymbol{I}) \approx \Delta, тогда:

\boldsymbol{h_{k+1} = }e^{\boldsymbol{A}\Delta} \boldsymbol{h_k} + \frac{1}{\boldsymbol{A}} (e^{\boldsymbol{A}\Delta}-\boldsymbol{I})\boldsymbol{Bx_k}

Таким образом получаем дискретную SSM модель:

\begin{array}{lcl} \boldsymbol{h_k} = \overline{\boldsymbol{A}} \boldsymbol{h_{k-1}} + \overline{\boldsymbol{B}} \boldsymbol{x_k}\\  \boldsymbol{y_k} = \boldsymbol{C h_k} + \boldsymbol{D} \boldsymbol{x_k}\\  \\ \boldsymbol{\overline{A}} = e^{\boldsymbol{A}\Delta}\\ \boldsymbol{\overline{B}} = \frac{1}{\boldsymbol{A}} (e^{\boldsymbol{A}\Delta}-\boldsymbol{I})\boldsymbol{B} \approx \Delta \boldsymbol{B} \end{array}

Если в параметре \boldsymbol{\overline{B}} разложить экспоненту до первого порядка, происходит очень удачное упрощение, поэтому авторы пренебрегают точностью этого, не самого важного, параметра в пользу уменьшения вычислений:

\boldsymbol{x_k}- вход модели, \boldsymbol{y_k}- выход модели, \boldsymbol{h_k}- скрытое состояние или память модели,  \boldsymbol{\overline{A}} - главный из параметров, отвечает за то, как мы преобразуем память с течением времени - или параметр запоминания, \boldsymbol{\overline{B}} - параметр преобразования входа, \boldsymbol{\overline{C}} - параметр преобразования выхода, \boldsymbol{\overline{D}} - своего рода skip connection или skip параметр, \Delta - шаг дискретизации.

В простейшем случае имеем такие размерности:

\boldsymbol{\overline{A}} (N, N), \; \boldsymbol{\overline{B}} (N, 1), \; \boldsymbol{\overline{C}} (1, N), \; \boldsymbol{\overline{D}} (1, 1), \; \boldsymbol{x_k} (1, 1), \; \boldsymbol{y_k} (1, 1), \; \boldsymbol{h_k}(N, 1), \; \Delta = const

Таким образом, мы получили простую рекуррентную систему, сохранив при этом всю математическую силу пространства состояний. Интуицию по стандартной SSM можно получить здесь.

Селективная SSM

Отличительная особенность Мамбы от предыдущих глубоких SSM в этой ветке эволюции состоит в добавлении селективности. Иначе говоря, мы хотим, чтобы в скрытое состояние \boldsymbol{h_k} попадали только значимые из всех \boldsymbol{h_i} [i \lt k], а остальные отсеивались.

Обозначения

N - размерность скрытого состояния
L - длина входной последовательности
b - размер батча
d - глубина модели
E=2 - коэффициент расширения
d_{in} = Ed - глубина модели в Mamba блоке
A,B,C,D - параметры SSM
\Delta - размер шага дискретизации
\Delta_R = \frac{d}{16} - размерность проекции

Параметризация

Итак, чтобы модель могла акцентировать внимание на определенных элементах входной последовательности, сделаем три параметра зависимыми от входа:

\boldsymbol{B} = \boldsymbol{xW_B}, \; \boldsymbol{C} = \boldsymbol{xW_C}, \; \Delta = Softplus[\boldsymbol{xW_{\Delta1} W_{\Delta2}}+\Delta_{bias}]

Параметр \Delta управляет балансом между тем, насколько сильно фокусироваться или игнорировать текущий входной сигнал. Большой \Delta сбрасывает состояние \boldsymbol{h_k} и фокусируется на текущий вход \boldsymbol{x_k}, в то время как маленький \Delta сохраняет состояние и игнорирует текущий вход. Параметры \boldsymbol{B} и \boldsymbol{C} позволяют более тонко контролировать, вводить ли вход \boldsymbol{x_k} в состояние \boldsymbol{h_k} или состояние в выход \boldsymbol{y_k}.

\boldsymbol{A} и \boldsymbol{D} остаются независимыми от входа, но сами становятся параметрами.\boldsymbol{A} будем хранить в логарифмической форме \boldsymbol{A_{log}}(см. S4D инициализацию):

\boldsymbol{A} = -\exp^{\boldsymbol{A_{log}}}

Здесь и далее все экспоненты и логарифмы поэлементные. Таким образом, обучаемые параметры для селективного блока:

\boldsymbol{A_{log}}(d_{in}, N), \boldsymbol{W_{B}}(d_{in}, N), \boldsymbol{W_{C}}(d_{in}, N), \boldsymbol{D}(d_{in}), \boldsymbol{W_{\Delta1}}(d_{in}, \Delta_R), \boldsymbol{W_{\Delta2}}(\Delta_R, d_{in}), \Delta_{bias}(d_{in})

Введем сразу остальные параметры, которые будут использоваться в архитектуре:

\boldsymbol{W_{in}}(d, 2d_{in}), \boldsymbol{W_{out}}(d_{in}, d), \boldsymbol{W_{emb}}(vocab\:size, d)=\boldsymbol{W_{vocab}}^T, \boldsymbol{W_{conv1d}}(d_{in}, 1, K)

Инициализация параметров

Каждый из вышеописанных параметров инициализируется по своему:

\boldsymbol{A_{log}} = \ln\begin{pmatrix}1 & 2 & 3 & ... & N\\1 & 2 & 3 & ... & N\\  &   &...\end{pmatrix}, \; \boldsymbol{D} = \overline{\boldsymbol{1}}\Delta_{bias} = Softplus^{-1}\left[Uniform(10^{-3}, 10^{-1}) \right]

Параметр \boldsymbol{W_{conv1d}} задается стандартной инициализацией conv1d слоя с bias=True, тогда как все оставшиеся веса задаются Linear слоем с bias=False.

Инференс селективной SSM с аппаратным ускорением

Рисунок 1: Устройство Selective SSM блока (Mamba).
Рисунок 1: Устройство Selective SSM блока (Mamba).
\boldsymbol{x}(b, L, d_{in}), \boldsymbol{h_t}(b, d_{in}, N) \rightarrow \boldsymbol{y}(b, L, d_{in})

Для ускорения вычислений авторы разделили инференс selective SSM блока на два этапа - сначала подготовка (трехмерных массивов) на обычной (медленной) памяти видеокарты, затем дискретизация и вычисление рекурсии (четырехмерных массивов) в быстрой памяти видеокарты:

1) Подготовка (GPU HBM):

Возвращение \boldsymbol{A} в человеческий вид и проекция входа:

\begin{array}{ccc}\boldsymbol{A}(d_{in}, N) = -\exp^{\boldsymbol{A_{log}}}\\\boldsymbol{B}(b, L, N) = \boldsymbol{xW_B}\\\boldsymbol{C}(b, L, N) = \boldsymbol{xW_C}\\\Delta(b, L, d_{in}) = Softplus[\boldsymbol{xW_{\Delta1} W_{\Delta2}}+\Delta_{bias}]\end{array}

2) Selective scan (GPU SRAM):

Инициализация скрытого состояния:

\boldsymbol{h_{-1}} = \overline{\boldsymbol{0}}

Дискретизация:

\begin{array}{ccc}\boldsymbol{\overline{A}}(b, L, d_{in}, N) = e^{\Delta \boldsymbol{A}}\\\boldsymbol{\overline{B}x}(b, L, d_{in}, N) = \Delta \boldsymbol{Bx}\end{array}

В цикле по t вдоль оси L (по каждому токену) пересчет всех скрытых состояний \boldsymbol{h} и соответствующих им выходов \boldsymbol{y}:

\begin{array}{lcl}\boldsymbol{h_t} = \overline{\boldsymbol{A_t}} \boldsymbol{h_{t-1}} + \boldsymbol{(\overline{B} x)_t}\\ \boldsymbol{y_t} = \boldsymbol{C_t h_t} + \boldsymbol{Dx_t}\\ \end{array}

Архитектура Mamba

Рисунок 2: Устройство архитектуры Mamba
Рисунок 2: Устройство архитектуры Mamba

Mamba

Устройство архитектуры не сильно отличается от трансформерной:

  1. На входе имеем последовательность длиной L, которая может представлять из себя хоть текстовые токены, хоть элементы изображения.

  2. Векторизуем элементы последовательности матрицой эмбеддингов \boldsymbol{W_{emb}}, получая тот самый \boldsymbol{x}(b, L, d).

  3. Прогоняем его через n_{layers} мамба-слоев, сохраняя при этом размерность.

  4. Возвращаем размерность (b, L, vocab\;size) матричным умножением на \boldsymbol{W_{vocab}=W_{emb}^T} - той же матрицей, что и на входе.

  5. И, наконец, получаем вероятности для каждого токена по словарю.

Mamba Layer

Слой Мамба представляет из себя:

  1. Нормализацию по слою

  2. Непосредственно сам Мамба блок

  3. Skip connection

Mamba Block

Принцип блока основан на gated MLP, который при помощи дополнительной ветки с линейным слоем, активацией и последующим Element-wise умножением может управлять потоком информации основной ветки, определяя какая часть должна быть сохранена, а какая подавлена.

По основной же ветке идет, так называемый, inverted bottleneck:

  • Расширение (\boldsymbol{W_{in}}) \rightarrow depthwise convolution (в данном случае одномерная) \rightarrow проекция (\boldsymbol{W_{out}}),

с добавлением активации и основного блока - selective SSM из предыдущего раздела.

Заключение

Модель Мамба успешно унаследовала ключевые характеристики от трансформеров, такие как внимание к контексту и мультимодальность, открывая при этом новые перспективы для будущего развития. Способность Мамба эффективно работать в различных доменах, особенно в модальностях, где требуется учет большого объема контекста, таких как геномика, аудио и видео, выделяет ее среди передовых разработок.

Хотя данный обзор сосредоточен исключительно на математических аспектах нового подхода, результаты показывают, что Мамба может стать мощным кандидатом на роль нового общего мультимодального бэкбона. Подробности про синтетические тесты, результаты и сравнения в областях LLM, аудио и геномики доступны в оригинальной статье (ссылка).

Интересного нам 2024 года!

Материалы

Tags:
Hubs:
Total votes 23: ↑22 and ↓1+27
Comments25

Articles