Как стать автором
Обновить
VK
Технологии, которые объединяют

Генерация персонализированных стикеров на основе DreamBooth

Уровень сложностиСредний
Время на прочтение8 мин
Количество просмотров4.2K

Привет, Хабр! Меня зовут Саша Рогачёв, я старший программист-исследователь в команде компьютерного зрения в VK. Перенос стиля — одна из самых интересных задач в генеративном компьютерном зрении. Не каждый может создавать изображения в определённом стиле, как это реализовано во множестве фоторедакторов с открытым и закрытым кодом, которые позволяют сделать картинку в жанре импрессионизма, ретро, кубизма и т. д. Самая частая проблема, с которой можно столкнуться при реализации таких приложений – это дообучение больших моделей. Решить её можно при помощи разных методов: например, DreamBooth, LoRA и т. д. 

С этой задачей я и моя команда столкнулись в школе по практическому программированию и анализу данных от Питерской Вышки, генеральным партнером которого выступила компания VK. В рамках образовательной программы от экспертов VK Education мы решали задачу по генерации стикеров с использованием диффузионных моделей. В этой статье мы расскажем о нашем подходе к её решению, с какими трудностями встретились и к каким выводам пришли.

Файнтюн Stable Diffusion

Stable Diffusion — диффузионная модель, позволяющая генерировать картинки по введённому пользователем тексту (и не только). Обучение заключается в пошаговом зашумлении картинки, а затем её восстановлении. В отличие от похожих реализаций (Imagen, MidJourney) у Stable Diffusion есть несколько отличительных черт, благодаря которым мы её и выбрали. Во-первых, веса модели лежат в открытом доступе, что позволяет любому человеку использовать их бесплатно. Во-вторых, её можно легко настроить, используя дополнительные архитектуры, придавая изображению желаемый стиль. 

Схема архитектуры Stable Diffusion (из оригинальной статьи).
Схема архитектуры Stable Diffusion (из оригинальной статьи).

Изначально мы рассматривали три способа файнтюна Stable Diffusion: DreamBooth, LoRA и Textual Inversion, они являются наиболее популярными на сегодняшний день алгоритмами для дообучения больших text2img-моделей. 

DreamBooth — популярный способ затюнить большую диффузионную модель (Stable Diffusion, Imagen) на маленьком наборе картинок.

LoRA — подход к обучению моделей с использованием низкоразмерных адаптеров, при котором не требуется большой объём памяти для хранения весов: обучаются только адаптеры под нужную задачу, а веса основной модели замораживаются.

Textual Inversion — особая методика генерации изображений в определённом стиле, при которой текстовый энкодер «ищет» схожие с пользовательским изображением эмбеддинги, а затем добавляет их к промпту.

Впоследствии мы оставили только первые два: в оригинальной статье авторы DreamBooth указывают, что качество Textual Inversion хуже, а ещё мы реализуем генерацию img2img, для которой этот подход не предназначен. 

Основными преимуществами DreamBooth являются:

  • Возможность дообучения на маленьких наборах данных, что удовлетворяет нашей задаче: обычно размер набора стикеров не превышает 30-40 картинок.

  • Хорошее качество результата.

Однако, главный недостаток этого подхода заключается в том, что модель потребляет очень много памяти, что усложняет продуктовизацию и взаимодействие с пользователем. В качестве альтернативы решили протестировать метод LoRA: дообучение довольно быстрое, а веса занимают меньше места по сравнению с DB.

Эмпирически выяснили, что DreamBooth работает лучше: картинки получаются более прорисованными и похожими на оригинальные, к тому же второй подход при инференсе по большей части опирается на текстовый промпт, а не на картинку. Поэтому решили далее работать с DreamBooth.

Сравнение различных подходов к дообучению моделей.
Сравнение различных подходов к дообучению моделей.

В оригинальной статье авторы приводят два ключевых метода: Class-specific Prior Preservation Loss и Rare-token Identifiers. Немного расскажем о каждом из них.

Class-specific Prior Preservation Loss — функция ошибки, которая позволяет обучить модель с приоритетом на класс, указанный пользователем. Благодаря ей модель «не забывает» основные признаки нужного класса.

Rare-token Identifiers — специальные токены (обычно используется sks), благодаря которым модель «обращает внимание» на важные признаки в тексте и изображении, и генерации не теряют исходный смысл.

Сравнение LoRA и DreamBooth.
Сравнение LoRA и DreamBooth.

Итоговый пайплайн

Мы используем довольно сложную архитектуру, состоящую из нескольких моделей. Схематически она выглядит так:

  1. Файнтюн StableDiffusion с помощью DreamBooth на пользовательском наборе стикеров. Это основной этап работы, после которого картинки на выходе принимают схожий с референсом стиль.

  2. Генерация текстового промпта с помощью BLIP . Этот подход используется для создания промпта для будущей генерации, так как от этого сильно зависит качество итогового стикера.

  3. Генерация стикера по пользовательскому изображению и промпту из пункта 2.

  4. Сегментация фона (дообученная модель SegFormer). Она нужна для избавления от различных шумов на фоне и придания ему прозрачности.

preencoded.png
Схема итогового пайплайна.

Ресурсы

Изначально эксперименты проводились в бесплатном Colab Notebook (одно ядро GPU T4), но через некоторое время появились неприятности: большие модели требуют больше видеопамяти, и даже запуски скриптов с использованием ускорителей (библиотеки xformers и accelerate) не сильно помогли: дообучение не прерывалось только при batch_size=1. После этого решили перейти на сервер VK Cloud (одно ядро GPU A100). Использовали наборы стикеров из Telegram, которые находятся в открытом доступе (всего около 7,5 тысяч изображений).

Файнтюн Stable Diffusion с помощью DreamBooth

1. text2img

Мы решили сперва дообучить DreamBooth на всём наборе собранных стикеров. В этом заключалась наша главная ошибка, ведь к обучению даже на таком относительно небольшом датасете (7,5 тыс. изображений) DreamBooth оказался совершенно не готов. Даже при batch_size=16 и lr=1е-7 лосс скакал с 0,001 до 0,.5, что крайне негативно сказывалось на качестве генерируемых картинок.

После этого решили дообучать модель на отдельных наборах стикеров. Это улучшило ситуацию: большинство сгенерированных изображений имело схожий с референсом стиль.

2. img2img

Попробовали альтернативное решение: зачем с помощью BLIP преобразовывать эмбеддинги картинки в текст, если у диффузионки можно поменять TextEncoder на ImageEncoder и не добавлять ещё одну модель в и без того тяжёлый пайплайн. Так и сделали: переписали код c Hugging Face для тренировочного цикла и инференса. Однако использовать это не удалось, потому что пришлось бы дообучать модель на наборе из пар «оригинальная и сгенерированная картинка», а подходящего для нашей задачи не было.

Сегментация

Во время обучения стало понятно, что придётся использовать сегментацию для выделения фона и его удаления, так как большинство стикеров не квадратной формы (у них прозрачный фон). В первой версии мы использовали предобученный UNet как одно из самых распространённых решений для задачи сегментации и удаления фона на изображении. Увы, качество сегментации предобученным UNet не дотягивало до удовлетворительного, поэтому впоследствии мы решили дообучить на всём нашем наборе модели UNet и SegFormer. Оценили их по метрике IoU, часто используемой для оценки качества работы сегментационных моделей:

UNet

SegFormer

IoU

0,79

0,98

Про метрики

Для оценки качества работы нашего решения использовалось сразу два подхода: метрики для оценки качества генерируемых картинок и метрики для подбора промпта. Собрали небольшой набор из четырёх классов: trash — случайный шум и мусорные картинки, bad — примеры плохой генерации (сюда преимущественно вошли результаты работы сильно переученного Stable Diffusion), good — удачные примеры генерации, и painted — примеры стикеров, нарисованных вручную, которые мы берём за эталон по стилистике и качеству. 

Примеры стикеров в наборах.
Примеры стикеров в наборах.

Цель ㅡ максимизировать разницу в показаниях метрик на каждом из наборов: для нарисованных стикеров метрика должна быть наивысшей; похуже для набора с хорошими примерами генерации; ещё хуже для набора с плохими; и совсем плохой результат — для набора с мусором. Так можно будет убедиться в релевантности показаний.

Метрики для подбора промптов 

Для подбора промптов использовались метрики CLIP- и Pick-score.

CLIP-score основана на сравнении эмбеддингов изображений (визуальные признаки) и промптов (текстовые признаки) в общем пространстве. Нужно было найти такой промпт, эмбеддинг которого был бы максимально близок к хорошим примерам генерации и рисованным картинкам, но при этом максимально далёк от эмбедингов мусорных картинок и плохих примеров генерации.

Метрика Pick-score также основана на сравнении эмбеддингов картинок и промптов, но для обучения их извлечения используется другой набор данных. В нём для каждой картинки человеком был написан комментарий, отражающий, насколько картинка красива, качественна и т. д. В случае с имеющимся набором данных уже понятно, какие картинки хорошие, а какие плохие, и нужно найти промпт, который отображал бы эту разницу наиболее явно. То есть задача стояла та же, что и в случае с clip-ом: поиск промпта, максимально близкого к хорошим картинкам и максимально далёкого от плохих.

Было протестировано 200 уникальных промптов. Основываясь на показаниях обеих метрик, наиболее подходящим, на удивление, оказался лаконичный промпт a telegram sticker:

dataset

CLIP-score

Pick-score

painted

27,796

19,864

good

27,659

19,784

bad

25,269

19,059

trash

17,872

17,647

А самым неподходящим ― a painted beautiful pretty sticker on monotonous black background:

dataset

CLIP-score

Pick-score

painted

18,107

18,4

good

21,597

18,607

bad

19,014

17,906

trash

21,009

17,783

Метрики для оценки качества генерируемых картинок

Для оценки качества генерируемых картинок использовали сразу три метрики: FID, Inception Score и LPIPS. Они, как и метрики, описанные выше, создают из входных данных эмбеддинги и сравнивают их близость, но в этом случае мы работаем только с фичами изображений. Для FID и IS использовали дообученную на наборах стикеров модель ViT-B-16, а LPIPS работала на предобученном VGG. Результаты:

dataset

Inception Score

FID

ERGAS

LPIPS

bad

1,528

181,138

386,717

0,559

good

1,219

207,136

274,662

0,491

painted

1,434

101,073

149,079

0,344

trash

1,619

132,298

249,227

0,776

Метрика LPIPS оказалась наиболее релевантной нашим данным: для painted она достигла 0,344, что близко к 0 (метрика LPIPS равна 0 для одинаковых изображений, так как при обработке их одной и той же сетью — в нашем случае предобученным VGG, — их выходные векторные представления будут идентичны, то есть. и ошибка между ними нулевая. А для trash ошибка, наоборот, близится к 1, что говорит о большой разнице в восприятии сетью vgg мусорных и эталонных изображений. Значения метрики для наборов good и bad также оказались в ожидаемом диапазоне.

Продуктовое решение

Для удобства взаимодействия с сервисом решили сделать Telegram—бот. С ним тоже было несколько проблем. 

  1. Из-за того, что суммарно наша архитектура весит около 5-7 Гб, решение получается довольно громоздким, то есть потребляет много видеопамяти. 

  2. Так как вся архитектура работает на GPU во время инференса, бот мог работать лишь с одним пользователем. В противном случае вылезала ошибка CUDA out of memory. В этом и раскрывается главный недостаток DreamBooth, так как для нормального использования потребуется придумать различные ухищрения: использовать несколько видеокарт, распараллелить обучение. Конечно, мы быстренько написали обработчик, чтобы бот не падал, однако проблема всё ещё не решна. 

  3. Не получилось реализовать многопоточность. 

Выводы

  1. Наиболее подходящим промптом генерации оказался a telegram sticker (по метрикам pick- и clip-score).

  2. Наиболее релевантной для оценки качества генерируемых изображений на имеющихся данных оказалась метрика LPIPS (в сравнении с FID, IS, ERGAS).

  3. Самой подходящей моделью для дообучения Stable Diffusion оказалась DreamBooth, благодаря своей способности дообучаться на крошечных наборах от 3-5 изображений. Тем не менее, к обучению на всём нашем массиве из нескольких тысяч стикеров, чтобы сперва показать Stable Diffusion основные паттерны Telegram-стикеров, DreamBooth оказался не готов, и модель быстро переобучилась.

  4. Лучшим методом генерации стикеров с точки зрения качества ожидаемо оказался Stable Diffusion. Тем не менее, этот подход очень ресурсоёмкий: генерация одного стикера на GPU A100 занимает около минуты, а дообучение с DreamBooth даже на небольшом наборе стикеров — 5 минут. Помимо этого, дообученная модель занимает много места в памяти (до 7 Гб). Всё это сильно мешало бы масштабировать решение на Stable Diffusion.

Примеры генерации.
Примеры генерации.

Авторы:

Таисия Чегодаева

Александр Смирнов

Игорь Карташов

Теги:
Хабы:
Всего голосов 8: ↑7 и ↓1+12
Комментарии1

Публикации

Информация

Сайт
team.vk.company
Дата регистрации
Дата основания
Численность
свыше 10 000 человек
Местоположение
Россия