Это продолжение предыдущей публикации про реставрацию ruGPT3XL. Для тех кто не читал, кратенько, я конвертировал древний Megatron-LM чекпоинт в HuggingFace-формат, залил веса на HF, накатил поддержку GGUF в llama.cpp и подумал, что всё. Но нет.

По ходу тестов, проведённых разными людьми удалось выявить ряд недоработок, которые я по мере обнаружения правил, ну а после того, как удалось получить стабильную и рабочую версию мне захотелось решить одну старую проблему, которая меня в ruGPT3 моделях очень беспокоила, это проблема маленького контекста в смешные 2k токенов.

Решил поднять контекст до 8k.

PPL, Sparse Attention и Triton

После прошлой публикации на Хабре меня резонно спросили, а на каких метриках вообще проверялось качество конвертированной модели? Я честно не знал, что ответить, так как гонял MERA в отрыве от оригинала, потому что оригинальную модель через древние Megatron-LM, DeepSpeed и Apex мне запустить так и не удалось, очень старый стек.

Смеркалось, свербило.

Решил взять метрику Perplexity (PPL), она очень простая, плюс указана в карточках всех оригинальных моделей, понятно как считать и что ожидать. Единственная проблема в том, что нужен датасет, на котором тестировали оригиналы, а такого у меня нет, и у SberDevices скорее всего тоже, так как пять лет прошло с тех пор.

Взял датасет gazeta Ильи Гусева @Takagi, в нём около 60k русскоязычных новостных статей, все примеры умещаются в 2k токенов, датасет небольшой и всем известный. Написал скрипт расчёта примерно по методологии из оригинальной публикации про ruGPT3, заодно прогнал все четыре размера семейства: ruGPT3small, ruGPT3medium, ruGPT3large и мой ruGPT3XL с наивным dense attention.

Получилась такая вот табличка:

Циферка для ruGPT3small отсутствовала в карточке модели, поэтому там прочерк. Корреляция между замерами на gazeta и оригинально заявленными значениями получилась вполне приличной (R = 0.93):

PPL 50.1 WTF

PPL конвертированного ruGPT3XL первым прогоном показал 50.1, а оригинальная модель в своей карточке имеет 12.05. Ошибка в расчётах? Не похоже, ведь у остальных трёх моделей семейства цифры PPL более менее похожие. Значил дело в чём-то другом.

Начал копать. Оказалось, кодовый агент при конвертации решил схалтурить и выбросил механизм Sparse Attention, заменив его на обычный nn.MultiheadAttention из GPT-2. Это, конечно, “работает”, модель генерирует текст, вот только веса-то оптимизированы под разреженное внимание, а не под плотное, математика другая, поэтому результат на контексте больше 128 токенов ожидаемо слабый.

Благодаря тому, что я потратил время на детальное изучение исходников Megatron-LM при первой конвертации, понять где именно проблема было несложно. Объяснил агенту что не так, показал примеры кода с правильным механизмом, дал почитать оригинальную публикацию про ruGPT3, и спустя несколько итераций получил исправленный modeling_rugpt3xl.py с репликой Sparse Attention из Megatron-LM.

Sparse Attention, зачем нужен и чем от обычного отличается

Стандартный causal self-attention (то, что в оригинальном GPT-2) - это плотная матрица, где каждый токен смотрит на все предыдущие токены. Память и вычисления растут квадратично от длины последовательности, удвоили контекст, получили в четыре раза больше операций с матрицей внимания и в четыре раза больше памяти потребляем.

Sparse Attention делает то же самое, но с прорежённой маской, в ruGPT3XL используется alternating-паттерн из статьи “Generating Long Sequences with Sparse Transformers” (arxiv:1904.10509):

  • Чётные слои (0, 2, 4, …) - block-sparse attention, каждый токен видит только ограниченное локальное окно (128 токенов) плюс несколько “глобальных” блоков через регулярные интервалы. Разные головы внимания используют разные позиции глобальных блоков.

  • Нечётные слои (1, 3, 5, …) - обычный плотный causal attention.

Теоретически это даёт почти линейный рост памяти вместо квадратичного. На практике для ruGPT3XL при увеличении контекста в 4 раза память на KV+активации растёт примерно в 3-4 раза (а не в 16x), замеры чуть ниже.

Разница в PPL между sparse и dense режимом для ruGPT3XL на датасете gazeta уже видна на графике выше, но если совсем кратко:

Механизм внимания

PPL (test, gazeta)

Dense (как в GPT-2)

50.1

Sparse alternating (оригинал)

11.68

После исправления PPL понизился с 50.1 до 11.68, это уже похоже на правду и хорошо коррелирует с заявленными 12.05 у оригинала.

Параллельно выяснилось, что в GGUF-версии та же история - прошлый патч в llama.cpp (PR #21011) добавлял конвертацию весов через архитектуру LLM_ARCH_GPT2, но сама sparse attention там не была реализована. Значит, GGUF-модель тоже считала dense внимание. Пришлось делать новый патч (PR #21161) он добавляет полноценную поддержку ruGPT3XL как отдельной архитектуры со sparse attention.

После релиза фикса механизма внимания один хабровчанин в комментариях указал, что в реализации sparse attention была ошибка:

cuda\TensorCompare.cu:109: block: [0,0,0], thread: [0,0,0] Assertion `input[0] != 0` failed.
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
                  ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: device-side assert triggered

При обучении маска строилась некорректно для батчей длиннее одного примера, из-за чего обучение падало. Исправил эту проблему.

Triton

Ну и под конец добавил поддержку Triton для ускорения sparse-операций на GPU.

В преобразованной модели изначально внимание реализовывалось как явное matmul + softmax + matmul в режиме PyTorch, это математически корректно, но снижает производительность по сравнению с решениями на базе Triton доступными на современных графических процессорах NVIDIA.

На графике четыре режима работы механизма внимания при обучении (синтетический луп, AdamW, fp16, RTX 4090). Серый столбик - исходный eager режим (~6280 tok/s), это тот самый явный matmul+softmax+matmul. Синий - переключение на F.scaled_dot_product_attention (SDPA) при том же размере батча: +40% почти бесплатно, просто меняется путь исполнения внутри PyTorch. Голубой - SDPA с бо́льшим батчом (5×2048 вместо 2×2048), SDPA эффективнее использует память и позволяет запихнуть больше. Зелёный - SDPA плюс torch.compile с Inductor-бэкендом: итого ×1.85 к baseline, компилятор дополнительно сплавляет поэлементные операции и местами генерирует Triton-ядра. Числа внутри столбиков - кратность ускорения относительно eager.

Контекст 8k

Откуда идея

2048 токенов - это больная тема для всего семейства ruGPT3, в своё время на этом сгорело не мало моих нервов, пришлось изобретать sliding window в чатах, костыльные стратегии фильтрации датасетов чтобы не поймать OOM, чанковать документы. Всё это конечно же опыт, он мне позже пригодился и не раз, но осадочек остался.

Теперь, когда у меня есть рабочая современная версия модели, грех не попробовать исправить и этот фатальный недостаток.

Вопрос “насколько реально расширить контекст у ruGPT3XL” нетривиальный из-за двух особенностей архитектуры:

  1. у модели используется “Learned Absolute Positional Embeddings” (Learned APEs), таблица позиций embed_positions, по-простому nn.Embedding(2048, 2048), обученная вместе со всеми остальными весами. В отличие от Rotary Positional Embeddings (RoPE) таблица APE не умеет экстраполировать - если модель никогда не видела позицию с индексом 2049, она понятия не имеет что туда подставить.

  2. sparse attention (о которой было выше), сетка разреженного внимания строится из max_position_embeddings // sparse_block_size, то есть тоже зависит от лимита контекста.

На эту тему нашёл пару релевантных работ:

Короче, нельзя просто взять и увеличить max_position_embeddings в конфиге, ничего хорошего не выйдет, требуется дообучение, а вот после дообучения и с правильной инициализацией уже вполне реально.

Про память и вычисления

Важное следствие sparse attention для планирования экспериментов. Если бы была плотная матрица внимания, переход с контекста L на 4L дал бы примерно 16-кратный рост памяти на self-attention. У ruGPT3XL благодаря alternating sparse-паттерну это скорее 3-4x на практике. Это означает, что 8k контекст в принципе влезет на RTX 4090 с 48 ГБ, причём даже при full обучении (полной разморозкой всех весов модели).

Стратегия расширения

Без изобретения велосипеда ничего не выйдет, так что вот три принципа, которым я решил следовать:

1. Тайлинг позиционных эмбеддингов

Первым напрашивался вариант с линейной интерполяцией, тупо в лоб взять существующую матрицу 2048 x 2048 и заскейлить её до нужного размера - это сработало плохо, интерполяция меняет все 2048 строк, в том числе те, для которых у модели всё итак работало. PPL на коротком контексте сразу после такой операции улетел за сотку, не наш вариант ;)

Покумекав вспомнил про метод тайлинга (зацикливания, ну или проще дублирования), оригинальные позиции 0-2047 копируются буквально, а новые заполняются циклически:

позиция 2048 <- веса позиции 0
позиция 2049 <- веса позиции 1
...
позиция 4095 <- веса позиции 2047
позиция 4096 <- веса позиции 0  (второй цикл)
...
позиция 8191 <- веса позиции 4095

Смысл в том, что модель с первых шагов дообучения хотя бы не паникует на новых индексах, а короткий контекст работает точно так же, как и раньше.

2. Смешанный датасет

60% длинных примеров (несколько статей gazeta, склеенных через EOS до целевой длины) и 40% коротких чанков до половины целевой длины. Длинные обучают новые позиции, короткие не дают модели забыть как работать с привычными контекстами. Без коротких примеров PPL на 2k стремительно деградировал.

3. Ступенчатое расширение

Сначала 2k -> 4k, потом берём обученную 4k-модель и делаем 4k -> 8k. Сразу прыгнуть с 2k на 8k гипотетически можно, но это значит тайлить позиции 2048-8191 из диапазона 0-2047, что довольно грубо, да и модель за три эпохи на небольшом датасете может не успеть освоить такой диапазон, а большее количество эпох может привести к оверфиту (переобучению), чего я бы не хотел допустить.

Параметры обучения

Датасет IlyaGusev/gazeta, сплит train, 3 эпохи на каждый шаг, lr=5e-6 с cosine decay, gradient checkpointing, bfloat16, RTX 4090 (48 ГБ) и так далее.

По времени ушло:

  • ~2.6 часа на шаг 2k->4k

  • ~3.9 часа на 4k->8k

При обучении на 8k контексте CUDA фрагментировала память в процессе backpropagation и падала с OOM на середине - при этом на GPU формально оставался ~1 ГБ свободной памяти, но PyTorch не мог нарезать из него нужный смежный кусок. Решается одной строкой в переменных окружения: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. После этого пиковое потребление упало с 46.8 до 38.5 ГБ и обучение дошло до конца без приключений.

Полученная моделька вот тут: evilfreelancer/ruGPT3XL-8k

Perplexity 8k

Тест на сплите test датасета IlyaGusev/gazeta:

Модель

PPL @ 2048

PPL @ 4096

PPL @ 8192

ruGPT3XL (baseline)

11.68

-

-

ruGPT3XL-4k

11.75

12.04

-

ruGPT3XL-8k

11.77

11.99

13.00

Регрессия на исходном 2k контексте всего +0.09 к baseline - модель не разучилась работать с короткими последовательностями.

4k в финальной 8k-модели оказался даже лучше, чем у промежуточного чекпоинта (11.99 vs 12.04) - continued pretraining чуть подтянул общее качество.

На 8k получаем 13.00, что для четырёхкратного расширения контекста вполне достойно.

Память 8k

Длина контекста

VRAM peak

KV + активации

512

2.92 GiB

0.25 GiB

1024

3.16 GiB

0.49 GiB

2048

3.86 GiB

1.19 GiB

4096

6.57 GiB

3.90 GiB

8192

15.98 GiB

13.31 GiB

Веса модели - ~2.67 ГБ в bfloat16, до 2k overhead растёт почти линейно, что подтверждает работу sparse attention, дальше становится квадратичнее.

Скорость генерации

Длина промпта

tok/s

ms/tok

512

1444

0.7

1024

882

1.1

2048

378

2.6

4096

67

14.9

8192

38

26.6

На коротких промптах модель летает благодаря KV-кешу, с 2k на 4k скорость падает в 5.6 раза, даже с KV-кешом при каждом autoregressive шаге нужно протащить внимание через всю историю.

Зато переход 4k -> 8k (2x по длине) даёт только 1.8x замедление (67 -> 38 tok/s), хотя памяти надо уже гораздо больше.

Итого

Конвертированная ruGPT3XL теперь работает правильно, PPL соответствует оригиналу, sparse attention реализован и в transformers-версии, и в llama.cpp, контекст растянут с 2k до 8k с минимальной регрессией на коротких последовательностях.

На RTX 4090 ruGPT3XL-8k пригодна на любой длине контекста, на бюджетных 8-12 ГБ карточках комфортно до 4k, что уже в два раза лучше оригинала.

Следующий очевидный шаг - instruction tuning, но это уже другая история.

Ссылки

Послесловие

Вот такой вот занятный эксперимент у меня получился, надеюсь интерес к моему маленькому проекту у читателей и подписчиков сохранится, так как хочется попробовать ещё парочку занятных вещей типа квантизацию и обучения в mxfp4, а так же конвертацию модельки в MoE формат, плюс на очереди ещё пухляшка ruGPT 3.5 на 13B параметров, короче есть ещё чем заняться.

Ну я в свою очередь благодарю вас за прочтение, надеюсь мои наработки пригодятся, буду рад фидбеку в комментариях или в телеграме.