Как я пытался понять на что именно тратится VRAM при генерации изображений.
В процессе моих изысканий о том как-же создавть изображения локально, я столкнулся с неочевидной, для себя, проблемой в виде колоссального потребления VRAM, которое не сходилось с тем что написано в карточках моделей и в интернете.
Разбираемся на примере современной FLUX.2-dev. Чтобы теоретически влезать в доступную мне VRAM на моем оборудовании я выбрал вариант GGUF Q4_K_M. И вот тут началось все самое интересное.
Модель: unsloth/FLUX.2-dev-GGUF-Q4_K_M — FLUX.2-dev (DiT-архитектура) с квантизацией Q4_K_M (4-бит, K-means block scaling).
Начнем с того, что для unsloth/FLUX.2-dev требуется текстовый энкодер, вместе с которым эта модель и создавалась авторами. В нашем случае это Mistral-Small-3.2-24B. Если "взять и запустить" то загрузится полная версия этой модели которая съедает VRAM ~45GB. У меня столько доступной памяти нет, по этому я начал искать.
После подбора различных моделей и вариантов рабочим оказался Mistral-Small-3.2-24B через MLX 4-bit. Всё, ничего другое у меня нормально (в ожидаемым потреблением VRAM) не завелось. После таких манипуляций этот этап стал занимать ~16GB VRAM.
Общий Flow такой - сначала грузим энкодер, выгружаем энкодер (результаты не выгружаем), загружаем основную модель. Я назвал это sequential offload.
Ожидалось, что GGUF-файл весит 19 GB на диске. Значит, в памяти он должен занимать примерно столько-же. Ну скажем хотябы не более 20-21 GB.
Реальность: 28.92 GB peak (Activity Monitor).
Куда уходят "лишние" ~9 GB? Пытаемся понять и сократить потребление.
Стенд
Компонент | Версия |
|---|---|
macOS | 26.3.1 (Tahoe) |
Чип | Apple M3 Pro, 36 GB unified memory |
Python | 3.14.3 |
PyTorch | 2.11.0 |
diffusers | 0.38.0.dev (Коммит: c02c17c6 от 2026-03-21 ) |
accelerate | 1.13.0 |
transformers | 5.3.0 |
safetensors | 0.7.0 |
mlx / mlx-lm | 0.31.1 |
Архитектура загрузки
FLUX.2-dev использует Mistral-Small-3.2-24B (~24B параметров) в качестве текстового энкодера. Это огромная модель, поэтому загрузка идёт через sequential offload:
Фаза 1: MLX encoder (Mistral 4-bit, ~13 GB) encode() → prompt_embeds на CPU unload() → mx.clear_cache() Фаза 2: DiT pipeline (GGUF Q4_K_M, ~19 GB) from_single_file() → CPU from_pretrained() → VAE, scheduler на CPU pipe.to("mps") → всё на MPS Фаза 3: Inference 1-28 шагов диффузии на MPS
Идея: peak = max(encoder, DiT), а не encoder + DiT. Энкодер выгружается перед загрузкой DiT. Теоретический пик — 19 GB (размер DiT).
Но на практике...
Инструменты профилирования
Для анализа написаны два скрипта:
**scripts/profile_memory.py— поэтапный профиль внутри процесса: измеряетphys_footprint(метрика Activity Monitor) иtorch.mps.driver_allocated_memory()после каждой фазы. Режимы:--stage gguf-only|all|inference, флаг--low-watermark.**scripts/measure_generate.py** — внешний мониторинг: запускаетgenerate.pyкак subprocess и семплируетphys_footprintкаждые 0.5s черезproc_pid_rusage. Строит таймлайн и находит пик.
Ключевой момент: psutil.memory_info().rss не включает MPS-аллокации на macOS. Для unified memory нужен phys_footprint из proc_pid_rusage (то же, что показывает Activity Monitor в колонке Memory).
Профилирование
Поэтапное измерение (profile_memory.py --stage inference)
Этап | phys_footprint | MPS driver | MPS current |
|---|---|---|---|
Baseline (imports) | 0.27 GB | 0 | 0 |
MLX encoder loaded | 13.22 GB | 0 | 0 |
MLX encode + unload | 0.51 GB | 0 | 0 |
GGUF transformer (CPU) | 19.23 GB | 0 | 0 |
from_pretrained (CPU) | 19.38 GB | 0 | 0 |
pipe.to(mps) | 23.79 GB | 20.06 GB | 18.75 GB |
MID-INFERENCE (step 1) | 26.50 GB | 25.62 GB | 18.78 GB |
After inference | 25.74 GB | 25.68 GB | 18.76 GB |
After empty_cache() | 24.46 GB | 21.08 GB | 18.76 GB |
Внешний мониторинг (measure_generate.py)
Запуск скрипта по генерации (один шаг и одно изображение за раз): generate.py --model unsloth/FLUX.2-dev-GGUF-Q4_K_M --batch 1 --steps 1
Memory (GB) 30 ┤ │ ▄▄▄█▄▄▄ 28 ┤ ▄█ ██▄▄▄ │ ▄█ █▄ 26 ┤ ▄▄▄██ █████████████ │ ▄█ 24 ┤ ▄█ │ ██ 22 ┤ █ │ ▄█ 20 ┤ ▄█ │██████████████████████████ 18 ┤ (GGUF loading) ▲ PEAK 28.92 GB │ 14 ┤ ▄▄▄▄▄ │ █ █ (MLX encoder) 12 ┤█ █ │ █ 0 ┤──────────────────────────────────────────────────────────────────→ t 0s 10s 20s 40s 50s 70s 90s 120s
Peak: 28.92 GB — наступает при pipe.to("mps") (CPU→MPS transfer), а не при inference.
Куда уходит память: анатомия 28 GB
Компоненты на MPS после загрузки
Компонент | Размер | Как измерено |
|---|---|---|
DiT Q4_K_M (19.96B uint8 элементов) | 18.59 GB |
|
VAE (bf16) | 0.16 GB |
|
MPS current | 18.75 GB |
|
MPS allocator pool overhead | +1.31 GB |
|
MPS driver | 20.06 GB |
|
GGUF веса не деквантизируются при загрузке на MPS — остаются в формате torch.uint8 как GGUFParameter. Деквантизация в bf16 происходит покомпонентно во время forward pass.
"Тёмная материя" — IOKit / MPS overhead
phys_footprint после pipe.to(mps): 23.79 GB MPS driver_allocated: 20.06 GB ───────────────────────────────────────────── Разница: 3.73 GB
Эти 3.73 GB — память, невидимая для PyTorch:
Metal graphCache (~1-2 GB) — скомпилированные compute shaders.
Для FLUX.2-dev с 28 transformer-блоками и множеством attention-конфигураций кеш шейдеров значителен.IOKit page tables — метаданные GPU-доступной памяти для 331 отдельных Metal-буферов (каждый
GGUFParameter— отдельная аллокация).Limbo buffers — буферы, ожидающие завершения GPU command buffer.
Это структурный overhead Apple Silicon MPS, задокументированный как PyTorch issue #164299.
Он не устраним средствами PyTorch, но это не точно.
Dequantization overhead при inference
phys_footprint MID-INFERENCE: 26.50 GB phys_footprint model loaded: 23.79 GB ───────────────────────────────────────── Дельта: +2.71 GB
При forward pass каждый GGUFLinear слой деквантизирует свои uint8-веса в bf16 для матричного умножения. Самый большой слой (double_stream_modulation_img.linear.weight: 36864 x 12288) создаёт временный bf16-тензор ~0.9 GB. MPS allocator кеширует эти буферы после использования вместо возврата OS:
MPS driver после inference: 25.68 GB MPS current после inference: 18.76 GB ───────────────────────────────────────── Кеш деквантизации: 6.92 GB
torch.mps.empty_cache() возвращает часть этого кеша, но не весь.
CPU→MPS transfer peak
Главный «виновник» пика 28.92 GB — pipe.to("mps"). Стандартный .to() переносит тензоры инкрементально, но macOS не успевает освободить
CPU-страницы:
CPU resident MPS driver phys_footprint ───────────────────────────────────────────────────────────────── До .to(mps): 19.38 GB 0 GB 19.38 GB В процессе .to(): ~8 GB ~16 GB ~28 GB ← PEAK После .to(mps): ~0.3 GB 20.06 GB 23.79 GB
Проблема: unified memory — CPU и MPS используют один пул физической RAM.
Пока CPU-копия не освобождена, а MPS-копия уже создана, phys_footprint включает обе.
Итоговый бюджет
Компонент | Размер | Устранимо? |
|---|---|---|
DiT Q4_K_M weights (uint8) | 18.59 GB | Нет (размер модели) |
VAE | 0.16 GB | Нет |
MPS allocator pool | ~1.5 GB | Нет (PyTorch internals) |
MPS/IOKit overhead | ~3.5 GB | Нет (PyTorch #164299) |
Python + app runtime | ~1.5 GB | Частично |
Steady state | ~24 GB | |
GGUF деквантизация (inference) | +3-5 GB | Частично (empty_cache) |
Peak (inference) | ~27 GB | |
CPU→MPS overlap (transient) | +1-3 GB | Частично (incremental GC) |
Peak (загрузка) | ~28-29 GB |
20 GB — это только размер GGUF-файла. Реальный минимум для работы на MPS — ~24 GB, пик ~28 GB.
Оптимизации
1. Incremental .to(mps) с периодическим GC
Вместо pipe.to("mps") переносим параметры по одному с вызовом gc.collect() каждые N параметров. Это даёт macOS время на возврат CPU-страниц:
GC_EVERY_N_PARAMS = 30 def _move_to_device_incremental(pipe, device): n = 0 for attr in list(vars(pipe)): comp = getattr(pipe, attr, None) if not isinstance(comp, torch.nn.Module): continue for param in comp.parameters(): param.data = param.data.to(device) n += 1 if n % GC_EVERY_N_PARAMS == 0: gc.collect() for buf in comp.buffers(): buf.data = buf.data.to(device) gc.collect() torch.mps.synchronize() torch.mps.empty_cache()
Применяется только для GGUF-моделей на MPS (где .to() переносит ~19 GB квантованных данных).
2. empty_cache() между шагами inference
GGUF деквантизация создаёт временные bf16-буферы на каждом шаге. MPS allocator кеширует их, и на 28-шаговом прогоне они накапливаются до 5+ GB. Flush между шагами предотвращает накопление:
def _on_step(pipe, step, timestep, kwargs): if step_callback: step_callback(step + 1, steps) if flush_mps: torch.mps.empty_cache() return kwargs pipe_kwargs["callback_on_step_end"] = _on_step
3. empty_cache() после загрузки модели
После pipe.to(device) и gc.collect() — явный flush MPS allocator cache:
gc.collect() if torch.backends.mps.is_available(): torch.mps.synchronize() torch.mps.empty_cache()
Результаты
Тест: generate.py --model unsloth/FLUX.2-dev-GGUF-Q4_K_M ---batch 1 --steps 1
До / после оптимизаций (один прогон)
Метрика | До | После | Дельта |
|---|---|---|---|
Peak phys_footprint | 28.92 GB | 27.36 GB | -1.56 GB |
Process (steady, после inference) | 26.25 GB | 21.76 GB | -4.49 GB |
Peak (self-report GUI) | 28.72 GB | 27.36 GB | -1.36 GB |
Время генерации | 2:04 | 2:11 | +7s |
Пик снижен на ~1.5 GB (incremental GC при загрузке). Steady-state после inference снижен на ~4.5 GB благодаря empty_cache().
Оверхед по времени: +7 секунд из-за gc.collect() в incremental .to(). На 28-шаговом прогоне empty_cache() между шагами добавляет ещё ~1-2 секунды суммарно.
Таймлайн после оптимизаций
Memory (GB) 28 ┤ ▄█▄ │ ██ █▄ 26 ┤ ▄▄▄█▄ ██████████████████ │ ▄█ 24 ┤ ██▀ │ ▄██ 22 ┤ ▄██ ▲ PEAK 27.36 GB │ ██ 20 ┤ ▄█ │████████████████████████ 18 ┤ (GGUF loading) │ 14 ┤ ▄▄▄▄▄ │ █ █ (MLX encoder) 12 ┤█ █ │ █ 0 ┤──────────────────────────────────────────────────────────────────→ t 0s 10s 20s 40s 50s 70s 90s 120s
Заметно: после inference (t=91s) phys_footprint падает до 23.38 GB благодаря empty_cache() в callback — было бы 26+ GB без flush.
Что нельзя оптимизировать
IOKit/MPS overhead (~3.5 GB) — Metal driver metadata, graphCache, page tables. PyTorch issue #164299, не решается на стороне приложения.
MPS allocator pool (~1.5 GB) — внутренний пулинг PyTorch MPS backend.
PYTORCH_MPS_LOW_WATERMARK_RATIO=0.0немного помогает с reclaim после inference (-2.7 GB приempty_cache()), но не снижает пик.Деквантизация при forward pass — GGUF uint8 → bf16 для каждого Linear слоя. Это фундаментальная стоимость runtime-деквантизации. Единственная альтернатива — нативная квантованная арифметика (пока не поддерживается MPS backend).
Размер модели — 18.59 GB в Q4_K_M. Дальнейшее сжатие (Q2_K, Q3_K_S) существенно снижает качество генерации.
Итоги
20 GB — не достижимая величина на PyTorch на Apple Silicon, в настоящий момент. Размер GGUF-файла на диске и потребление памяти при работе на Apple Silicon MPS — разные вещи. К весам модели прибавляются:
MPS allocator pool (~1.5 GB) — PyTorch кеширует аллокации
IOKit/Metal overhead (~3.5 GB) — driver metadata, shader cache
Runtime деквантизация (~3-5 GB) — временные bf16 буферы при inference
CPU→MPS overlap (~1-3 GB transient) — unified memory считает обе копии
Корректный бюджет для FLUX.2-dev Q4_K_M на Apple Silicon:
Steady state: ~24 GB (модель загружена, idle)
Inference peak: ~27 GB (с оптимизациями)
Load peak: ~27-28 GB (transient, при CPU→MPS transfer)
На M1/M2 с 16 GB — не влезет. M3 Pro / M2 Pro с 32 GB — впритык. M3 Max / M2 Ultra с 64+ GB — комфортно.
Скорость генерации на M3 Pro
Step=1
Done! 1 sprite(s) generated in 2 minutes and 10 seconds Started: 2026-03-26 10:20:47 Finished: 2026-03-26 10:22:58 Elapsed: 2 minutes and 10 seconds Pipeline timing: Generate unsloth/FLUX.2-dev-GGUF-Q4_K_M (2 minutes and 10 seconds) + Mistral3-VLM encoder (sequential offload) (11 seconds) + DiT inference (44 seconds) + GGUF: unsloth/FLUX.2-dev-GGUF/flux2-dev-Q4_K_M.gguf (37 seconds)
Step=28 (default):
Done! 1 sprite(s) generated in 34 minutes and 30 seconds Started: 2026-03-26 07:47:30 Finished: 2026-03-26 08:22:01 Elapsed: 34 minutes and 30 seconds Pipeline timing: Generate unsloth/FLUX.2-dev-GGUF-Q4_K_M (34 minutes and 28 seconds) + Mistral3-VLM encoder (sequential offload) (9 seconds) + DiT inference (33 minutes and 2 seconds) + GGUF: unsloth/FLUX.2-dev-GGUF/flux2-dev-Q4_K_M.gguf (36 seconds)
Разница между общим временем работы пайплайна и работой кадой модели в отдельности - это этап загрузки моделей с дика а память, в настоящий моент это я не отслеживаю инструментально.
