Это продолжение цикла статей о масштабировании тренировки и инференса LLM.
Считаем количество операций
А теперь перейдем к чему-то более практическому, а именно к тому, сколько нужно FLOPs и байт для работы трансформера, той самой архитектуры, которая лежит в основе практически всех современных LLM. Подразумевается, что у вас уже есть представление о том, что такое архитектура трансформера, как работает механизм внимания и т.д.
Давайте начнем с векторов x, y и матриц A, B, имеющих вот такие размеры, допустим один элемент занимает при этом один байт.

Скалярное произведение двух векторов x * y требует P операций сложения и умножения, итого 2P FLOPs и 2P байт.
Произведение матрицы А на вектор x требует N скалярных произведений строк А и вектора х, итого 2NP FLOPs и NP + P байт
Произведение двух матриц А и В требует перемножить каждую из N строк матрицы А с каждым из M столбцов матрицы B, итого 2NPM FLOPs и NP + PM байт
А теперь перейдем к общему случаю с тензорами. Пусть тензор C имеет размерность [G, M, K], а тензор D - [G, N, K]. То есть по сути это два батча из G матриц. Ось G обоих тензоров можно назвать осью батча, ось K - смежной осью. То есть при перемножении этих двух тензоров мы берем первую пару матриц из каждого батча, размера [M, K] и [N, K], и перемножаем их, затем делаем так же для второй пары и т.д. В сумме получаем 2GNMK FLOPs и GMK + GNK байт.
В общем виде это выглядит так.

Здесь BATCH - множество осей батча, CONTRACT - множество смежных осей
Также заметим, что для матричного умножения объем вычислений растет пропорционально O(N^3), а объем данных - O(N^2). Соответственно при росте размерностей матриц арифметическая интенсивность данной операции растет, т.е. при умножении больших матриц нам легче нагрузить наш ускоритель. Собственно это одна из причин того, почему современные нейронки это в основном последовательность матричных умножений - их очень легко масштабировать.

Туда и обратно
Вообще говоря, когда мы тренируем нейросеть, нас заботит не столько результат матричного умножения, сколько его градиент, т.к. мы делаем гораздо больше вычислений во время процесса обратного распространения.
Представим, что В это одна из матриц нашей нейросети, А это входные активации, С = АВ. Тогда градиент функции потерь для В с учетом chain rule будет:

В формуле выше учтен тот факт, что если С = АВ, то:

Теперь, если посчитаем градиент функции потерь для А, то получим:

Т.к. dLoss/dC это матрица размера [N, M], размер А - [N, P], B - [P, M], получим в итоге, что для обоих градиентов потребуется по 2NPM FLOPs, что в сумме нам дает 4NPM.
Теперь складываем это с затратами на прямой проход и получаем в итоге 6NPM FLOPs. А поскольку PM - количество параметров в матрице весом, то получаем первое приближение для количества операций, необходимых для тренировки трансформера:
6 х кол-во токенов х кол-во параметров
И это довольно близкое к реальности приближение, т.к. большую часть параметров и вычислений трансформера приходится именно на большие MLP слои, которые и представляют из себя те самые обычные матрицы без наворотов. О более точной оценке поговорим ниже.
Бухгалтерия трансформера
Чаще всего для создания LLM используют не ванильный трансформер, а его декодерную часть, поэтому будем рассматривать только ее. Вот так выглядит один ее слой и все операции, которые приходится над ним производить:

Поначалу картинка может показаться запутанной, но после пары просмотров вы уже будете свободно в ней ориентироваться, особенно если знакомы с трансформерами. Справа можно заметить кучу всяких обозначений размерностей, запоминать их необязательно, я продублирую их далее в тексте в нужных моментах.
Примечание 1: На диаграмме выше, в нижней ее части используется т.н. «gating einsum», в которой мы разбиваем матрицу up-проекции на две матрицы W_In1 и W_In2, выходы которых перемножаются поэлементно — это своего рода «функция гейтинга». Не все LLM используют этот подход, поэтому иногда можно встретить просто одну матрицу W_In, и следовательно общее количество параметров MLP будет равно 2DF вместо 3DF (D - размерность входных эмбеддингов, F - размерность скрытого слоя MLP). Как правило, в таких случаях D и F увеличиваются, чтобы сохранить то же количество параметров, что и в варианте с тремя матрицами. Тем не менее, та или иная форма gating einsum используется в LLaMA, DeepSeek и т.д.
Примечание 2: В классическом Multi-Head Attention (MHA) длины последовательностей для query, key и value одинаковы, следовательно значение T и S тоже будет одинаково. Но во всяких там Multi-Query Attention (MQA, K = 1) или Grouped MQA (GMQA, K = N // G) это соотношение будет другим.
Ниже будем считать количество операций для разных элементов трансформера.
MLP
Пусть на вход MLP приходит В последовательностей токенов длиной T, размерность каждого токена D, размерность скрытого слоя MLP - F. Тогда в итоге получается:

Attention
Пусть у нас слой внимания с N блоками query и K блоками key и value, будем считать, что размерность всех этих блоков одинакова и равна H. Тогда вот что получим для входных и выходных матриц W_Q, W_K, W_V и W_O:

Красным обозначены смежные оси, синим - оси батчей.
Пусть T - длина последовательности для query, S - длина последовательности для key и value. Пусть у нас MHA-слой, следовательно T = S. Тогда для непосредственно самой операции self-attention получаем:

Собственно видим ту самую квадратичную зависимость сложности внимания от длины последовательности, которая делает такой дорогой тренировку трансформеров на длинных контекстах.
Примечание: тут не учли causal masking, операцию которая маскирует некоторые токены в матрице query-key для того, чтобы модель училась только на данных из прошлого. В таком случае количество операций сокращается вдвое.
Другие операции
Есть еще операции вроде Layer Norm или Unembed (возращает эмбеддинг обратно к размерности словаря (на картинке ниже обозначено как V), производится один раз на выходе трансформера), но они дешевы или редки.

Отношение вычислительной сложности внимания ко всему трансформеру в зависимости от контекста
Напомню, что F - размерность скрытого слоя MLP, D - размерность каждого токена, N и K - кол-во блоков query и key/value, T - длина последовательности токенов, H - размерность блоков внимания, В - размер батча.
Предположим дополнительно, что F = 4D, D = NH и N = K, т.к. такие равенства типичны для трансформера. Тогда:

Таким образом, внимание начинает доминировать в плане вычислений, если T > 8D. Для D ~ 8k это будет примерно 64к токенов. Соответственно чем больше модель, тем менее критична для нее длина контекста в плане тренировки. Также есть способы ускорить вычисление внимания, например Flash Attention.
MoE
MoE или Mixture of Experts это обычная модель с E MLP блоками (экспертами) на слой вместо одного. Каждый токен активирует k из этих экспертов, обычно k << E. Отношение E/k называют разреженностью, оно равно обычно от 8 до 64.

По сравнению с обычной моделью MoE дополнительно использует две операции AllToAll, которые отправляют токен к нужному эксперту, а затем возвращают его обратно. Но, как мы знаем из предыдущей главы, AllToAll довольно дешевая операция.
В общем в первом приближении MoE это обычная модель с E блоками MLP в каждом слое вместо одного.
Считаем память
Довольно проблематично работать с моделью, если она не влезает в оперативную память целиком. Поэтому было бы неплохо, помимо оценок потребляемых моделью вычислений, также иметь оценку потребляемой моделью памяти.
А что именно в модели кушает память? Давайте сначала посчитаем для режима тренировки.
Тренировка
Веса модели
Если веса хранятся в fp32 - получаем 4 байта * кол-во параметров модели.
Если тренируем в mixed precision (когда помимо основной модели с весами fp32 имеем еще одну с весами в fp16) - получаем 6 байт * кол-во параметров.
Состояния оптимизатора
AdamW: 8 байт * кол-во параметров (по 4 байта на каждый из двух моментов градиентов)
SGD with momentum, LION, Adafactor: 4 байт * кол-во параметров (один момент вместо двух)
Можно дополнительно уменьшить, если хранить состояния с меньшей точностью, например bf16.
Градиенты
Если веса хранятся в fp32 - 4 байта * кол-во параметров.
Если веса хранятся в fp16 - 2 байта * кол-во параметров.
Активации
Потребление памяти зависит от длины последовательности, размера батча, размера скрытого слоя и т.д. Давайте разберемся поподробнее.
Пусть B - длина батча, T - длина последовательности токенов в батче, D - размерность каждого токена, V - размер словаря, L - количество слоев.
Тогда количество памяти, необходимой для активаций:
активация текущего слоя: dtype_size * num_of_hidden_states_copies * B * T * D
чейкпойнты активаций всех слоев: dtype_size * gradient_checkpoint_num * B * T * D * L
выходные логиты LLM: 4 * B * T * V (логиты обычно хранятся в формате fp32 - 4 байта)
dtype_size - размер типа данных, который используем для хранения активаций
num_of_hidden_states_copies отражает, сколько промежуточных тензоров одновременно живёт внутри одного слоя во время forward/backward и варьируется от 20 до 50 в зависимости от модели. Для диаграммы трансформера, приведенной выше, num_of_hidden_states_copies ~ 20.
Назначение gradient_checkpoint_num будет описано ниже.
Gradient checkpointing
Обычный алгоритм обратного распространения ошибки жертвует памятью в обмен на вычисления. Вместо того чтобы требовать O(n_layers^2) FLOPs, обратный проход требует O(n_layers) памяти, поскольку сохраняет все промежуточные активации, полученные во время прямого прохода.
Например для прямого и обратного прохода:

Если мы не хотим заново пересчитывать g(x) и exp(g(x)), нам лучше их сохранить.
Собственно gradient_checkpoint_num и означает количество чекпойнтов этих активаций на слой. Если сохранять все промежуточные активации, gradient_checkpoint_num = num_of_hidden_states_copies.
И хотя это лучше, чем вычисления с квадратичной сложностью, с точки зрения памяти это дороговато: модель с B * T = 4M (4M токенов на батч), L (число слоев)=64 и D=8192, избегающая всех лишних вычислений в обратном проходе, должна сохранить примерно 2*20*B*T*D*L = 84TB активаций в bfloat16.
Поэтому мы используем промежуточные стратегии:
Block remat: сохранять только вход в каждый слой. Это наиболее агрессивный метод - он сохраняет лишь 1 чекпоинт на слой, то есть gradient_checkpoint_num = 1, а в приведённом примере мы сохранили бы только 4.2TB. Однако это вынуждает нас повторять практически все вычисления прямого прохода в обратном, увеличивая множитель количества FLOPs с примерно 6 до 8.
Big matmuls only: еще одна простая стратегия - сохранять только выходы крупных матричных умножений. Это позволяет избежать повторного вычисления больших matmul в обратном проходе, но всё равно требует пересчёта остальных функций активации и частей attention. Это снижает gradient_checkpoint_num до примерно 7 на слой.
Итого общее приблизительное потребление памяти LLM во время тренировки:

K - зависит от размера типов данных, используемых для хранения весов модели, состояний оптимизиатора, градиентов, типа оптимизатора и т.д. и варьируется в среднем от 10 до 20.
Nm - кол-во параметров модели
Nb - кол-во токенов в батче
D - размерность каждого токена
L - кол-во слоев LLM
V - размер словаря
Инференс
В режиме инференса не надо хранить состояния оптимизатора, градиенты и их чекпойнты. В этом режиме потребление памяти считается по формуле:

I - зависит от размера типов данных, используемых для хранения весов модели и равно 6 байт для mixed precision (4+2), 4 байта для fp32 и т.д. Но помимо весов модели в режиме инференса есть и дополнительный потребитель памяти.
KV Cache
Инференс состоит из двух фаз:
Prefill. Обрабатывает исходный + системный промпт LLM
Generation. Авторегресионно генерирует выходные токены один за другим.
Для того, что генерация очередного токена не занимала квадратичное время, значения key и value сохраняются в Key-Value Cache (KV Cache). Следовательно каждый слой LLM имеет свой KV Cache.
Каждый KV Cache имеет размерность [2,S,L,K,H], 2 берется из того, что у нас отдельный кэш для key и value. И это, мягко говоря, немало. Для небольшой модели с длиной контекста 8к, 64 слоями, K * H = N * H = D = 8192, и при условии, что KV Cache хранится в int8, его размер будет 2 * S * L * K * H = 2 * 8192 * 64 * 8192 = 8GB. Именно поэтому сейчас так популярны всякие модификации механизма внимания, вроде GMQA, у которого K << N.
Заключение
В этой статье мы научились считать количество вычислений и памяти, необходимых для тренировки и инференса нашей модели. Более подробно оба эти режима мы рассмотрим в следующих статьях цикла.
Ну а на этом все, подписывайтесь, чтобы не пропустить следующие статьи, до скорого!
