В этой статье, про ИИ, написанной не полностью ИИ, про генерацию изображений - не будет изображений.
В конце этой статьи мы будем запускать эту модель на указанном чипе, но начнем мы с чуть более мощного - он понадобиться чтобы разобраться с проблемой.
Первая проблема с которой я столкнулся - это потребление памяти. Поиски в интернете, описание самой модели говорили о том что она должна помещаться в ~10GB VRAM. Чего должно с запасом хватать для Apple M1 16GB. Однако фактическое зафиксированное потребление памяти составило 21 GB, не зафиксированное 28 GB (после чего я и начал исследование).
На этом моменте мы вынуждены переместиться на оборудование помощнее, например ноутбук с чипом Apple M3 Pro 36GB для исследования проблемы.
Теперь немного исходных технических данных об используемых версиях софта:
Пакет | Версия |
|---|---|
macOS | 26.3 (Tahoe) |
Python | 3.14 |
PyTorch | 2.11.0 |
diffusers | 0.38.0.dev0 |
accelerate | 1.13.0 |
transformers | 5.3.0 |
safetensors | 0.7.0 |
SD 3.5 Medium на Apple Silicon: с 21 ГБ до 11.6 ГБ пиковой памяти
SD 3.5 Medium позиционируется как модель на 2.5B параметров, которой нужно ~5 ГБ VRAM. На практике на Apple Silicon MPS она потребляет 20+ ГБ. Разбираемся почему.
Проблема
Запуск SD 3.5 Medium fp16 на M3 Pro (36 ГБ unified memory):
pipe = SD3Pipeline.from_pretrained(path, torch_dtype=torch.float16, variant="fp16") pipe.to("mps")
Модель Stable Diffusion Medium состоит из следующих компонентов:
Компонент | Параметры | Размер fp16 |
|---|---|---|
T5-XXL текстовый энкодер | 5.5B | 10.75 ГБ |
DiT-трансформер | 2.1B | 4.18 ГБ |
VAE | — | 0.16 ГБ |
Итого веса | 15.09 ГБ |
T5-XXL занимает 71% всей памяти — в 2.5 раза больше самой модели генерации.
Запускаем:
from diffusers import StableDiffusion3Pipeline import torch pipe = StableDiffusion3Pipeline.from_pretrained( "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16, variant="fp16") pipe.to("mps") image = pipe("a warrior with a sword", num_inference_steps=40).images[0] image.save("warrior.png")
Пиковое потребление: 21.4 ГБ. Без разницы, steps=1 или steps=40.
Пик в 21 ГБ складывается из двух проблем:
Дублирование CPU→MPS.
from_pretrainedгрузит все веса на CPU, затем.to("mps")создаёт MPS-копии. Во время переноса обе копии существуют одновременно: T5 на MPS (11 ГБ) + DiT/VAE на CPU (4.3 ГБ) + overhead MPS = 20+ ГБ пик.enable_model_cpu_offload()не помогает на Apple Silicon. CPU и MPS разделяют одну физическую RAM — перемещение тензоров между ними не освобождает память.
Решение: Загружаем Decoder и DiT раздельно
Загружаем Decoder T5 и DiT раздельно. Теперь пик потребления памяти - приходится на T5 так эта модель больше, чем сам генератор изображений. Но есть и тонкости:
1. Загрузка с device_map="balanced" — без CPU-копии
# Было: пик ~15 ГБ (CPU- и MPS-копии сосуществуют) pipe = SD3Pipeline.from_pretrained(path, torch_dtype=torch.float16) pipe.to("mps") # Стало: пик ~11 ГБ (загрузка напрямую в MPS через accelerate) pipe = SD3Pipeline.from_pretrained(path, torch_dtype=torch.float16, device_map="balanced")
2. Удаление хуков accelerate перед освобождением модели
Неочевидный момент. device_map использует dispatch-хуки accelerate, которые держат strong references на тензоры. del pipe + gc.collect() + empty_cache() сами по себе не освобождают в этой ситуации память.
from accelerate.hooks import remove_hook_from_submodules for attr in list(vars(pipe)): comp = getattr(pipe, attr, None) if isinstance(comp, torch.nn.Module): remove_hook_from_submodules(comp) setattr(pipe, attr, None) del pipe gc.collect() torch.mps.synchronize() torch.mps.empty_cache()
Без этого T5 остаётся в памяти (cur=10.75 ГБ) даже после «выгрузки», и DiT грузится поверх — воспроизводя пик в 21 ГБ.
3. Полная последовательность: encode → unload → generate
# Фаза 1: только T5-энкодер enc_pipe = SD3Pipeline.from_pretrained(path, transformer=None, vae=None, text_encoder=None, text_encoder_2=None, # без CLIP tokenizer=None, tokenizer_2=None, variant="fp16", torch_dtype=torch.float16, device_map="balanced") prompt_embeds = enc_pipe.text_encoder_3(tokens)[0].cpu() # Выгрузка T5 (с удалением хуков из шага 2) # ... # Фаза 2: только DiT + VAE — используем .to("mps"), не device_map gen_pipe = SD3Pipeline.from_pretrained(path, text_encoder=None, text_encoder_2=None, text_encoder_3=None, tokenizer=None, tokenizer_2=None, tokenizer_3=None, variant="fp16", torch_dtype=torch.float16) gen_pipe.to("mps") # заставляет MPS-драйвер освободить закешированные страницы T5 gen_pipe(prompt_embeds=prompt_embeds.to("mps"), ...)
Для DiT намеренно используется .to("mps") — это заставляет MPS-драйвер переиспользовать кеш T5, сбрасывая footprint с 11.5 до 6 ГБ.
С этого момента можно возвращаться на чип M1 16Gb.
Результат
M3 Max 36 ГБ, SD 3.5 Medium fp16, 512x512:
Этап | До | После |
|---|---|---|
Загрузка T5 | 15.3 ГБ | 11.5 ГБ |
T5 после выгрузки | cur=10.75 ГБ (утечка) | cur=0.01 ГБ |
DiT + диффузия | 8.9 ГБ | 8.9 ГБ |
Пик | 21.4 ГБ | 11.6 ГБ |
Минимальный рабочий пример — пик 11.6 ГБ:
from diffusers import StableDiffusion3Pipeline from accelerate.hooks import remove_hook_from_submodules import torch, gc path = "stabilityai/stable-diffusion-3.5-medium" # ── Фаза 1: кодирование промпта (только T5-XXL, ~11 ГБ) ── enc_pipe = StableDiffusion3Pipeline.from_pretrained(path, transformer=None, vae=None, # без DiT + VAE text_encoder=None, text_encoder_2=None, # без CLIP tokenizer=None, tokenizer_2=None, variant="fp16", torch_dtype=torch.float16, device_map="balanced") # напрямую в MPS prompt_embeds, neg_embeds, pooled, neg_pooled = enc_pipe.encode_prompt( prompt="a warrior with a sword", prompt_2="a warrior with a sword", prompt_3="a warrior with a sword", device="mps", num_images_per_prompt=1) # ── Выгрузка T5 (обязательно снять хуки accelerate!) ────── for attr in list(vars(enc_pipe)): comp = getattr(enc_pipe, attr, None) if isinstance(comp, torch.nn.Module): remove_hook_from_submodules(comp) setattr(enc_pipe, attr, None) del enc_pipe gc.collect() torch.mps.synchronize() torch.mps.empty_cache() # ── Фаза 2: генерация (только DiT + VAE, ~8 ГБ) ────────── gen_pipe = StableDiffusion3Pipeline.from_pretrained(path, text_encoder=None, text_encoder_2=None, text_encoder_3=None, tokenizer=None, tokenizer_2=None, tokenizer_3=None, variant="fp16", torch_dtype=torch.float16) gen_pipe.to("mps") # рекле́ймит кеш T5 из MPS-драйвера image = gen_pipe( prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_embeds, pooled_prompt_embeds=pooled, negative_pooled_prompt_embeds=neg_pooled, num_inference_steps=40).images[0] image.save("warrior.png")
Статистика и производительность
Генерация одной картинки:
Apple M1 16GB: ~5 минут | 11 GB - 13 GB VRAM
Elapsed: 4 minutes and 54 seconds
Pipeline timing:
Generate stabilityai/stable-diffusion-3.5-medium (4 minutes and 54 seconds)
+ T5-XXL encoder (sequential offload) (1 minute and 6 seconds)
+ DiT inference (3 minutes and 34 seconds)
Apple M3 Pro 36B: ~1 минута 40 секунд | 9 GB - 11 GB VRAM
Elapsed: 1 minute and 47 seconds
Pipeline timing:
Generate stabilityai/stable-diffusion-3.5-medium (1 minute and 47 seconds)
+ T5-XXL encoder (sequential offload) (14 seconds)
+ DiT inference (1 minute and 27 seconds)
