Запуск современных Text-to-Video моделей локально — задача не для слабонервных. Когда китайские исследователи из PKU-YuanGroup выложили в open-source свою модель Open-Sora-Plan, энтузиасты бросились её тестировать. Но есть нюанс: оригинальный пайплайн рассчитан на кластеры уровня H100/A100. Веса модели в полном разрешении занимают десятки гигабайт.

Моя цель заключалась в том, чтобы запустить инференс Open-Sora-Plan (v1.3.0) в спартанских условиях — на абсолютно бесплатном инстансе Google Colab с видеокартой NVIDIA T4 (15 ГБ VRAM, архитектура Turing 2018 года) и 12.7 ГБ системной ОЗУ.

Спойлер: скрипт отработал от начала и до конца без OOM (Out of Memory). Но для этого нам пришлось вскрывать исходники, бороться с аппаратными лимитами GPU и в прямом смысле делать нейросети математическую «лоботомию».

Вызов 1: OOM Killer и Staged Execution

Архитектура Open-Sora-Plan состоит из двух тяжеловесных компонентов:

  1. Текстовый энкодер (T5-XXL) — переводит промпт в эмбеддинги. В fp32 он весит под 19 ГБ.

  2. Диффузионный трансформер (DiT) — генерирует кадры (ещё около 11 ГБ).

Суммарно это 30 ГБ. Попытка загрузить всё в 15 ГБ VRAM ожидаемо вызывала смерть ядра. Файл подкачки (SWAP) спасал системную ОЗУ при загрузке весов с диска, но не решал проблему видеопамяти.

Решение: 4-bit квантование и поэтапная выгрузка

Во-первых, мы принудительно квантовали T5-XXL до 4 бит «на лету» с помощью bitsandbytes (формат NF4). Это сжало модель до приемлемых 5-6 ГБ.

Python

import transformers
import torch

kwargs['quantization_config'] = transformers.BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_quant_type="nf4", 
    bnb_4bit_use_double_quant=True, 
    bnb_4bit_compute_dtype=torch.float16
)

Во-вторых, мы реализовали Staged Execution. Поскольку текстовый энкодер нужен только на первой секунде (чтобы получить prompt_embeds), нет смысла держать его в памяти при генерации видео.

Мы загружали T5 на GPU, получали векторы, а затем безжалостно убивали объекты с принудительным вызовом сборщика мусора и очисткой кэша CUDA, освобождая плацдарм для загрузки DiT:

Python

del pipe.text_encoder
gc.collect()
torch.cuda.empty_cache()

Вызов 2: Хардкод и следы Huawei Ascend

На этапе сборки пайплайна мы обнаружили, что в релизе v1.5.0 разработчики выложили VAE-декодер только в формате .ckpt для китайских NPU Huawei Ascend (папка MindSpeed). Привычного config.json для GPU просто не было.

Попытка разобраться в их .sh скриптах привела к забавной находке. В эталонном коде инференса разработчики забыли вычистить пути к своим локальным серверам:

Bash

--ae_path "/home/save_dir/lzj/Formal_8dim/latent8"

Скрипт упорно пытался найти модель на жестком диске китайского аспиранта. Чтобы обойти это, мы написали точечный загрузчик через huggingface_hub, который "пинцетно" вытащил нужную папку VAE из предыдущего релиза (v1.3.0), игнорируя десятки гигабайт ненужных DiT-весов.

Вызов 3: Архитектурный лимит NVIDIA T4 (FP16 vs BFloat16)

Загрузив DiT в видеокарту, мы запустили генерацию. Скрипт отработал 50 шагов, но на выходе выд��л абсолютно черный экран.

Проблема крылась в аппаратной несовместимости. Open-Sora обучалась в формате bfloat16, который обладает огромным динамическим диапазоном (до ~3.39e38). Но видеокарта T4 (Turing) физически не поддерживает вычисления в bfloat16. При попытке запуска PyTorch ругался: RuntimeError: No available kernel.

Нам пришлось опустить torch_dtype до float16. И вот тут начался настоящий ад.

Лимит float16 — это 65504. Во время диффузии, особенно при высоком параметре CFG (Guidance Scale) и в слоях LayerNorm / AdaLN, дисперсия значений стремительно росла. Как только одно из чисел превышало лимит, происходило переполнение (Overflow). Число превращалось в inf, а на следующем шаге умножения матриц — в NaN.

VAE-декодер, получая тензор из NaN, просто декодировал его как нули (черный цвет в RGB).

Решение: Хирургический Monkey-Patching (Апкаст Attention)

Чтобы спасти математику на старом железе, мы написали "Квантовый стабилизатор". Мы не могли перевести всю модель в FP32 (не хватило бы VRAM), поэтому мы перехватили самую уязвимую функцию — scaled_dot_product_attention — и заставили её конвертировать тензоры в float32 ровно на миллисекунду перед перемножением матриц, а затем возвращать обратно в float16.

Вот как выглядит этот хак (Monkey-patching):

Python

import torch.nn.functional as F

orig_sdpa = F.scaled_dot_product_attention

def sdpa_fp32(query, key, value, *args_s, **kwargs_s):
    # Принудительный апкаст в FP32 перед вычислением Attention
    q, k, v = query.float(), key.float(), value.float()
    
    new_args = tuple(a.float() if isinstance(a, torch.Tensor) and a.dtype.is_floating_point else a for a in args_s)
    new_kwargs = {kw: (val.float() if isinstance(val, torch.Tensor) and val.dtype.is_floating_point else val) for kw, val in kwargs_s.items()}
    
    # Считаем в FP32 и возвращаем в исходный формат (FP16)
    return orig_sdpa(q, k, v, *new_args, **new_kwargs).to(query.dtype)

# Переопределяем функцию глобально
F.scaled_dot_product_attention = sdpa_fp32

Аналогичный патч мы применили к слоям LayerNorm внутри пайплайна DiT.

Побочный эффект: "Лоботомия" и рождение НЛО

Дополнительно, чтобы гарантированно избежать NaN, мы внедрили жесткий clamp (ограничитель) латентов перед отправкой их в VAE:

Python

def sanitize(x, limit=15.0):
    return torch.nan_to_num(x.float(), nan=0.0, posinf=limit, neginf=-limit)

Результат: Математика перестала взрываться. Мы победили черный экран. Но ценой семантики.

Из-за того, что мы агрессивно срезали экстремальные значения латентов (которые отвечали за детализацию и форму объектов), нейросеть потеряла структуру. Вместо промпта "Величественный дракон летит над городом" мы получили светящееся, размытое неоновое НЛО (цветовой пиксельный шум).

Выводы

  1. Запустить 30-гигабайтную видео-модель на 15 ГБ VRAM реально, если грамотно жонглировать памятью, использовать 4-bit квантование для энкодеров и применять gc.collect().

  2. Главная проблема инференса современных моделей на картах Turing (T4/RTX 20xx) и Pascal — отсутствие поддержки BFloat16. Переход на FP16 неизбежно ведет к математическим взрывам в слоях Attention.

  3. Точечный апкаст критических слоев в FP32 через Monkey-Patching спасает ядро от краша, но требует ювелирной работы со сэмплером, чтобы не разрушить семантику итогового изображения.

Да, вместо дракона мы получили пиксельное НЛО. Но это успешный Proof-of-Concept, показывающий, что «железные» ограничения всегда можно обойти, если знать, где ковырять исходники PyTorch.