Привет, Хабр! Меня зовут Саша Рогачёв, я старший программист-исследователь в команде компьютерного зрения в VK. Перенос стиля — одна из самых интересных задач в генеративном компьютерном зрении. Не каждый может создавать изображения в определённом стиле, как это реализовано во множестве фоторедакторов с открытым и закрытым кодом, которые позволяют сделать картинку в жанре импрессионизма, ретро, кубизма и т. д. Самая частая проблема, с которой можно столкнуться при реализации таких приложений – это дообучение больших моделей. Решить её можно при помощи разных методов: например, DreamBooth, LoRA и т. д.
С этой задачей я и моя команда столкнулись в школе по практическому программированию и анализу данных от Питерской Вышки, генеральным партнером которого выступила компания VK. В рамках образовательной программы от экспертов VK Education мы решали задачу по генерации стикеров с использованием диффузионных моделей. В этой статье мы расскажем о нашем подходе к её решению, с какими трудностями встретились и к каким выводам пришли.
Файнтюн Stable Diffusion
Stable Diffusion — диффузионная модель, позволяющая генерировать картинки по введённому пользователем тексту (и не только). Обучение заключается в пошаговом зашумлении картинки, а затем её восстановлении. В отличие от похожих реализаций (Imagen, MidJourney) у Stable Diffusion есть несколько отличительных черт, благодаря которым мы её и выбрали. Во-первых, веса модели лежат в открытом доступе, что позволяет любому человеку использовать их бесплатно. Во-вторых, её можно легко настроить, используя дополнительные архитектуры, придавая изображению желаемый стиль.
Изначально мы рассматривали три способа файнтюна Stable Diffusion: DreamBooth, LoRA и Textual Inversion, они являются наиболее популярными на сегодняшний день алгоритмами для дообучения больших text2img-моделей.
DreamBooth — популярный способ затюнить большую диффузионную модель (Stable Diffusion, Imagen) на маленьком наборе картинок.
LoRA — подход к обучению моделей с использованием низкоразмерных адаптеров, при котором не требуется большой объём памяти для хранения весов: обучаются только адаптеры под нужную задачу, а веса основной модели замораживаются.
Textual Inversion — особая методика генерации изображений в определённом стиле, при которой текстовый энкодер «ищет» схожие с пользовательским изображением эмбеддинги, а затем добавляет их к промпту.
Впоследствии мы оставили только первые два: в оригинальной статье авторы DreamBooth указывают, что качество Textual Inversion хуже, а ещё мы реализуем генерацию img2img, для которой этот подход не предназначен.
Основными преимуществами DreamBooth являются:
Возможность дообучения на маленьких наборах данных, что удовлетворяет нашей задаче: обычно размер набора стикеров не превышает 30-40 картинок.
Хорошее качество результата.
Однако, главный недостаток этого подхода заключается в том, что модель потребляет очень много памяти, что усложняет продуктовизацию и взаимодействие с пользователем. В качестве альтернативы решили протестировать метод LoRA: дообучение довольно быстрое, а веса занимают меньше места по сравнению с DB.
Эмпирически выяснили, что DreamBooth работает лучше: картинки получаются более прорисованными и похожими на оригинальные, к тому же второй подход при инференсе по большей части опирается на текстовый промпт, а не на картинку. Поэтому решили далее работать с DreamBooth.
В оригинальной статье авторы приводят два ключевых метода: Class-specific Prior Preservation Loss и Rare-token Identifiers. Немного расскажем о каждом из них.
Class-specific Prior Preservation Loss — функция ошибки, которая позволяет обучить модель с приоритетом на класс, указанный пользователем. Благодаря ей модель «не забывает» основные признаки нужного класса.
Rare-token Identifiers — специальные токены (обычно используется sks), благодаря которым модель «обращает внимание» на важные признаки в тексте и изображении, и генерации не теряют исходный смысл.
Итоговый пайплайн
Мы используем довольно сложную архитектуру, состоящую из нескольких моделей. Схематически она выглядит так:
Файнтюн StableDiffusion с помощью DreamBooth на пользовательском наборе стикеров. Это основной этап работы, после которого картинки на выходе принимают схожий с референсом стиль.
Генерация текстового промпта с помощью BLIP . Этот подход используется для создания промпта для будущей генерации, так как от этого сильно зависит качество итогового стикера.
Генерация стикера по пользовательскому изображению и промпту из пункта 2.
Сегментация фона (дообученная модель SegFormer). Она нужна для избавления от различных шумов на фоне и придания ему прозрачности.
Ресурсы
Изначально эксперименты проводились в бесплатном Colab Notebook (одно ядро GPU T4), но через некоторое время появились неприятности: большие модели требуют больше видеопамяти, и даже запуски скриптов с использованием ускорителей (библиотеки xformers и accelerate) не сильно помогли: дообучение не прерывалось только при batch_size=1. После этого решили перейти на сервер VK Cloud (одно ядро GPU A100). Использовали наборы стикеров из Telegram, которые находятся в открытом доступе (всего около 7,5 тысяч изображений).
Файнтюн Stable Diffusion с помощью DreamBooth
1. text2img
Мы решили сперва дообучить DreamBooth на всём наборе собранных стикеров. В этом заключалась наша главная ошибка, ведь к обучению даже на таком относительно небольшом датасете (7,5 тыс. изображений) DreamBooth оказался совершенно не готов. Даже при batch_size=16 и lr=1е-7 лосс скакал с 0,001 до 0,.5, что крайне негативно сказывалось на качестве генерируемых картинок.
После этого решили дообучать модель на отдельных наборах стикеров. Это улучшило ситуацию: большинство сгенерированных изображений имело схожий с референсом стиль.
2. img2img
Попробовали альтернативное решение: зачем с помощью BLIP преобразовывать эмбеддинги картинки в текст, если у диффузионки можно поменять TextEncoder на ImageEncoder и не добавлять ещё одну модель в и без того тяжёлый пайплайн. Так и сделали: переписали код c Hugging Face для тренировочного цикла и инференса. Однако использовать это не удалось, потому что пришлось бы дообучать модель на наборе из пар «оригинальная и сгенерированная картинка», а подходящего для нашей задачи не было.
Сегментация
Во время обучения стало понятно, что придётся использовать сегментацию для выделения фона и его удаления, так как большинство стикеров не квадратной формы (у них прозрачный фон). В первой версии мы использовали предобученный UNet как одно из самых распространённых решений для задачи сегментации и удаления фона на изображении. Увы, качество сегментации предобученным UNet не дотягивало до удовлетворительного, поэтому впоследствии мы решили дообучить на всём нашем наборе модели UNet и SegFormer. Оценили их по метрике IoU, часто используемой для оценки качества работы сегментационных моделей:
UNet | SegFormer | |
IoU | 0,79 | 0,98 |
Про метрики
Для оценки качества работы нашего решения использовалось сразу два подхода: метрики для оценки качества генерируемых картинок и метрики для подбора промпта. Собрали небольшой набор из четырёх классов: trash — случайный шум и мусорные картинки, bad — примеры плохой генерации (сюда преимущественно вошли результаты работы сильно переученного Stable Diffusion), good — удачные примеры генерации, и painted — примеры стикеров, нарисованных вручную, которые мы берём за эталон по стилистике и качеству.
Цель ㅡ максимизировать разницу в показаниях метрик на каждом из наборов: для нарисованных стикеров метрика должна быть наивысшей; похуже для набора с хорошими примерами генерации; ещё хуже для набора с плохими; и совсем плохой результат — для набора с мусором. Так можно будет убедиться в релевантности показаний.
Метрики для подбора промптов
Для подбора промптов использовались метрики CLIP- и Pick-score.
CLIP-score основана на сравнении эмбеддингов изображений (визуальные признаки) и промптов (текстовые признаки) в общем пространстве. Нужно было найти такой промпт, эмбеддинг которого был бы максимально близок к хорошим примерам генерации и рисованным картинкам, но при этом максимально далёк от эмбедингов мусорных картинок и плохих примеров генерации.
Метрика Pick-score также основана на сравнении эмбеддингов картинок и промптов, но для обучения их извлечения используется другой набор данных. В нём для каждой картинки человеком был написан комментарий, отражающий, насколько картинка красива, качественна и т. д. В случае с имеющимся набором данных уже понятно, какие картинки хорошие, а какие плохие, и нужно найти промпт, который отображал бы эту разницу наиболее явно. То есть задача стояла та же, что и в случае с clip-ом: поиск промпта, максимально близкого к хорошим картинкам и максимально далёкого от плохих.
Было протестировано 200 уникальных промптов. Основываясь на показаниях обеих метрик, наиболее подходящим, на удивление, оказался лаконичный промпт a telegram sticker:
dataset | CLIP-score | Pick-score |
painted | 27,796 | 19,864 |
good | 27,659 | 19,784 |
bad | 25,269 | 19,059 |
trash | 17,872 | 17,647 |
А самым неподходящим ― a painted beautiful pretty sticker on monotonous black background:
dataset | CLIP-score | Pick-score |
painted | 18,107 | 18,4 |
good | 21,597 | 18,607 |
bad | 19,014 | 17,906 |
trash | 21,009 | 17,783 |
Метрики для оценки качества генерируемых картинок
Для оценки качества генерируемых картинок использовали сразу три метрики: FID, Inception Score и LPIPS. Они, как и метрики, описанные выше, создают из входных данных эмбеддинги и сравнивают их близость, но в этом случае мы работаем только с фичами изображений. Для FID и IS использовали дообученную на наборах стикеров модель ViT-B-16, а LPIPS работала на предобученном VGG. Результаты:
dataset | Inception Score | FID | ERGAS | LPIPS |
bad | 1,528 | 181,138 | 386,717 | 0,559 |
good | 1,219 | 207,136 | 274,662 | 0,491 |
painted | 1,434 | 101,073 | 149,079 | 0,344 |
trash | 1,619 | 132,298 | 249,227 | 0,776 |
Метрика LPIPS оказалась наиболее релевантной нашим данным: для painted она достигла 0,344, что близко к 0 (метрика LPIPS равна 0 для одинаковых изображений, так как при обработке их одной и той же сетью — в нашем случае предобученным VGG, — их выходные векторные представления будут идентичны, то есть. и ошибка между ними нулевая. А для trash ошибка, наоборот, близится к 1, что говорит о большой разнице в восприятии сетью vgg мусорных и эталонных изображений. Значения метрики для наборов good и bad также оказались в ожидаемом диапазоне.
Продуктовое решение
Для удобства взаимодействия с сервисом решили сделать Telegram—бот. С ним тоже было несколько проблем.
Из-за того, что суммарно наша архитектура весит около 5-7 Гб, решение получается довольно громоздким, то есть потребляет много видеопамяти.
Так как вся архитектура работает на GPU во время инференса, бот мог работать лишь с одним пользователем. В противном случае вылезала ошибка CUDA out of memory. В этом и раскрывается главный недостаток DreamBooth, так как для нормального использования потребуется придумать различные ухищрения: использовать несколько видеокарт, распараллелить обучение. Конечно, мы быстренько написали обработчик, чтобы бот не падал, однако проблема всё ещё не решна.
Не получилось реализовать многопоточность.
Выводы
Наиболее подходящим промптом генерации оказался a telegram sticker (по метрикам pick- и clip-score).
Наиболее релевантной для оценки качества генерируемых изображений на имеющихся данных оказалась метрика LPIPS (в сравнении с FID, IS, ERGAS).
Самой подходящей моделью для дообучения Stable Diffusion оказалась DreamBooth, благодаря своей способности дообучаться на крошечных наборах от 3-5 изображений. Тем не менее, к обучению на всём нашем массиве из нескольких тысяч стикеров, чтобы сперва показать Stable Diffusion основные паттерны Telegram-стикеров, DreamBooth оказался не готов, и модель быстро переобучилась.
Лучшим методом генерации стикеров с точки зрения качества ожидаемо оказался Stable Diffusion. Тем не менее, этот подход очень ресурсоёмкий: генерация одного стикера на GPU A100 занимает около минуты, а дообучение с DreamBooth даже на небольшом наборе стикеров — 5 минут. Помимо этого, дообученная модель занимает много места в памяти (до 7 Гб). Всё это сильно мешало бы масштабировать решение на Stable Diffusion.
Авторы: