SBER-MoVQGAN или новый эффективный Image Encoder для генеративных моделей
Вариационные автоэнкодеры в квантованном векторном пространстве стали довольно популярными в последние несколько лет и успешно применяются в широком спектре генеративных задач (Stable Diffusion, VQ Diffusion, VideoGPT и др.). VQVAE позволяет сжимать картинку в латентное пространство меньшей размерности, а затем восстанавливать это латентное представление изображения в RGB-состояние. Операции в латентном пространстве выполняются быстро, поэтому VQVAE получил широкое применение как в авторегрессионных мультимодальных архитектурах (DALLE, ruDALL-E, RUDOLPH), так и в диффузионных моделях (DALL-E 2, Kandinsky 2.1, Latent Diffusion). В первом случае вариационный автоэнкодер позволяет закодировать картинку в последовательность визуальных токенов, которые вместе с текстовыми токенами используются в обучении трансформера. Во втором случае VQVAE кодирует картинку в квантованное пространство малой размерности, позволяя выполнять диффузионный процесс в латентном пространстве (ввиду того, что диффузионный процесс является итеративным и скорость генерации напрямую зависит от числа шагов диффузии, вычислительная сложность каждого шага очень важна), который в сравнении с пиксельной диффузией выполняется быстрее и потребляет меньше памяти.
Во всех перечисленных задачах качество генерации напрямую зависит от качества восстановления исходных картинок с помощью VQVAE. Пару лет назад мы уже проводили эксперименты и обучали SBER-VQGAN, который на тот момент давал лучшие результаты в сравнении c dVAE и ванильным VQGAN. Подробнее об этих экспериментах можно прочитать в статье на Хабре. Однако по-прежнему нам не хватало качества восстановления в сложных доменах, таких как текст и лица, поэтому мы попытались модифицировать и улучшить SBER-VQGAN, в результате получив SoTA среди моделей по кодированию изображений.
Топ решений на сегодня
В 2022 году вышло несколько статей по улучшению VQGAN, среди которых ViT-VQGAN, RQ-VAE и MoVQ.
ViT-VQGAN использует в энкодере/декодере Vision Transformer (ViT) вместо CNN. Энкодер ViT-VQGAN переводит непересекающиеся патчи картинок 8x8 в визуальные токены, которые затем подаются в трансформерные блоки. Декодер ViT-VQGAN переводит обратно визуальные токены из латентного пространства в патчи 8x8, формируя итоговую картинку. В статье описываются два нововведения: факторизация кодов (линейная проекция выхода энкодера размерностью 768 в 32-мерное латентное представление кодов) и L2 нормализация кодов (нормализация закодированных скрытых переменных z_e(x) и скрытых переменных кодовой книги e). В результате экспериментов ViT-VQGAN показал лучшие FID и IS в сравнении с стандартным VQGAN.
RQ-VAE вместо обычной квантизации использует residual квантизацию, которая позволяет улучшить качество восстановления картинки без увеличения размера кодовой книги. Суть её заключается в том, что квантизация латентного пространства происходит рекурсивно с заданной глубиной D. На первом шаге квантизуется латентное пространство и вычитается из исходного. На следующих шагах рекурсивно квантизуется уже остаток от латентного пространства. Таким образом, residual квантизация позволяет закодировать как более высокоуровневые представления, и так и более низкоуровневые детали. На каждой глубине квантизации используется одна и та же кодовая книга. Эксперименты показали, что RQ-VAE значительно превосходит VQGAN по FID в задаче восстановления изображений. Причём качество восстановления сильно зависит от глубины D: чем она больше, тем лучше метрики.
В статье MoVQ предлагается использовать spatially conditional нормализацию в блоках декодера, которая позволяет повысить реалистичность восстанавливаемых изображений. Spatially conditional нормализация помогает избежать возникновения артефактов, которые появляются из-за процесса квантования близких эмбеддингов в одинаковые индексы кодовой книги. Её использование добавляет степеней свободы квантованным представлениям после энкодера и позволяет распространять по слоям декодера более вариативные представления эмбеддингов. Также авторы предлагают использовать многоканальное представление латентного пространства, разделяя его на несколько частей по каналам и кодируя каждую часть по отдельности. Эти нововведения позволяют сократить размер кодовой книги с 16384 до 1024. MoVQGAN значительно превосходит описанные ранее подходы в задаче восстановления изображений по метрикам FID, SSIM, PSNR и LPIPS.
Проанализировав новые подходы к реконструкции изображений, мы решили реализовать MoVQGAN с некоторыми модификациями. Во-первых, MoVQGAN оказался SoTA архитектурой среди аналогов, при этом является самой легковесной и быстрой моделью и требует небольших изменений исходного кода. Во-вторых, трансформерная архитектура ViT-VQGAN (самая близкая по метрикам) требовала намного больше ресурсов и времени на обучение, при этом привязана к размеру подаваемой картинки, что сильно ограничивает применимость такого визуального энкодера в других задачах.
Архитектура SBER-MoVQGAN
Архитектура SBER-MoVQGAN основана на архитектуре VQGAN с добавлением spatially conditional нормализации из статьи MoVQ. Среди других важных особенностей:
встроены EMA веса (exponential moving average);
изменены лоссы на этапе обучения (подробно расскажем ниже).
За основу SBER-MoVQGAN взяли код из оригинального репозитория VQGAN. Код для обучения и инференса SBER-MoVQGAN можно найти на Github.
Spatially conditional нормализация реализована подобно AdaIN слоям (Adaptive Instance Normalization), которые применяются в архитектуре StyleGAN. AdaIN слои позволяют выполнять style transfer с помощью нормализации исходного вектора и добавления scale и bias вектора стиля. Однако в сравнении с этим методом spatially conditional карта признаков представлена как квантованная карта, которая, в свою очередь, содержит выученные компактные представления данных. Spatially conditional нормализация рассчитывается по формуле (1):
где Fi-1 — промежуточная карта признаков, μ(∙) и σ(∙) — функции расчёта среднего и стандартного отклонения (в качестве нормализации используется GroupNorm), φγ(∙) и φβ(∙) — обучаемые аффинные преобразования, которые реализованы как свёртки 1x1, предназначенные для преобразования zq в значения scale и bias. Таким образом, эта нормализация позволяет встраивать пространственно-вариативную информацию, при этом одинаковые квантованные представления генерируют правдоподобные и разнообразные результаты. На рисунке 1 представлена схема встраивания spatially conditional нормализации в блоки декодера.
Стандартная функция потерь (лосс) для обучения VQGAN, используемая в официальном репозитории, сходится плохо, что не позволяет восстанавливать изображения в хорошем качестве. В ходе экспериментов было замечено, что функция потерь, описанная в статье ViT-VQGAN, даёт лучшие результаты, поэтому его модификация использовалась при обучении SBER-MoVQGAN. Функция потерь рассчитывается по формуле (2):
где LVQ — codebook loss, LAdv — adversarial loss, Lperc — perceptual loss, — reconstruction loss, LNLL — усреднённое значение (0,1Lperc+1,0L2). В оригинальном репозитории использовался адаптивный вес для LAdv и LNLL, однако в экспериментах этот подход показал плохую сходимость, и в результате мы от него отказались.
Добавление EMA весов при обучении SBER-MoVQGAN значительно повысило качество генерации: происходит расчёт экспоненциального скользящего среднего по всем обучаемым параметрам модели. SBER-MoVQGAN обучался с параметром ema_decay 0,9999. Это означает, что во время обучения на каждой итерации при обновлении весов модели сохраняется 99,99 % от предыдущих весов и обновляется только 0,01 % весов. Таким образом, EMA веса обладают лучшей обобщающей способностью, позволяют генерировать более стабильные результаты и борются с потенциальным переобучением модели. Однако их использование приводит к увеличению времени обучения. EMA веса успешно применяются при обучении диффузионных моделей и трансформеров, и в итоге также показали высокую эффективность при обучении SBER-MoVQGAN.
Выше были описаны основные изменения относительно исходного кода VQGAN, которые позволили повысить качество восстановления и обучить SoTA модель. Однако помимо этих изменений были проведены эксперименты с изменением дискриминатора, добавлением L2 нормализации кодов в квантизаторе, изменением текущего квантизатора на Gumbel квантизатор, уменьшением размера словаря кодовой книги и добавлением свёрточного слоя в spatially conditional нормализацию, но в результате они показали качество восстановления изображений чуть хуже, поэтому не будем на них останавливаться.
Версии SBER-MoVQGAN и подробности обучения
Мы обучили три версии SBER-MoVQGAN — 67M, 102M, 270M. Маленькая версия 67M по размерам получилась такой же, как и стандартный VQGAN. Разница между ними лишь в архитектурных изменениях, описанных выше. А вот с моделями 102M и 270M мы решили поэкспериментировать и проверить, как увеличение параметров влияет на качество восстановления изображений. Модель 102M использует в два раза больше Residual блоков в сравнении с 67M, а модель 270M работает с количеством каналов в два раза больше исходного. Маленькую модель мы обучали на 1xA100 в течение двух недель, и за это время она вышла на плато. Большие модели мы обучали на 8xA100. Версия 102M вышла на плато за четыре недели, а версия 270M — в течение двух недель на 8xA100, и затем в течение 10 дней — на 32xA100. Более подробная информация представлена в таблице 1. Полученные результаты и метрики трёх версий SBER-MoVQGAN будут описаны ниже. Все модели обучались на наборе LAION HighRes.
Таблица 1. Информация о моделях SBER-MoVQGAN.
Модель | Архитектура | GPU | Train steps |
SBER-MoVQGAN 67M | 2 Residual-блока x channels | 1xA100 | 2M |
SBER-MoVQGAN 102M | 4 Residual-блока | 8xA100 | 2,3M |
SBER-MoVQGAN 270M | x2 channels | 8xA100 и 32xA100 | 920K и 410K |
Эксперименты
SBER-MoVQGAN тестировался двумя способами. В первом случае рассчитывалась метрика FID по доменам на наборе MS COCO и сравнивалось с ранее обученным SBER-VQGAN. Подробное описание распределения COCO набора по доменам и прошлые эксперименты описаны в нашей статье. В таблице 2 представлены метрики трёх версий SBER-MoVQGAN, и в результате показано, что все три версии значительно обходят предыдущую нашу модель.
Во втором случае оценивалось качество модели по метрикам FID, SSIM, PSNR и L1 и выполнялось сравнение с аналогичными моделями на наборе Imagenet. Подробнее про метрики SSIM и PSNR можно прочитать здесь, про FID — в этой статье. Отметим только, что FID сравнивает распределения реальных и сгенерированных изображений. Если картинки полностью совпадают, то FID равен 0. SSIM — это метрика оценки изображений по трём параметрам: яркости, контрастности и структуре. Она принимает значения от -1 до 1, при этом чем выше значение, тем ниже искажения изображения и выше качество. PSNR — метрика, показывающая пиковое отношение сигнал/шум. PSNR определяет уровень искажений при сжатии и включает в себя подсчёт среднеквадратичной ошибки (MSE); диапазон принимаемых значений от 0 до 100. Метрика L1 сравнивает изображения попиксельно и показывает среднюю абсолютную ошибку. Если картинки полностью идентичны, она будет равна 0.
В таблицах 3 и 4 приведено сравнение моделей ViT-VQGAN, RQVAE, MoVQGAN, VQGAN, KL-VAE, SBER-VQGAN и SBER-MoVQGAN по метрикам FID, SSIM, PSNR и L1 на наборах Imagenet и FFHQ. Для первых трёх моделей значения метрик взяты напрямую из статей, для остальных мы сами посчитали (для моделей VQGAN и KL-VAE чекпоинты были взяты из репозитория latent-diffusion). В результате все три версии SBER-MoVQGAN превосходят описанные модели по качеству реконструкции изображений.
Таблица 2. Оценка моделей по метрике FID на наборе MS COCO, распределённом по доменам (чем меньше значение, тем лучше).
Домен/модель | SBER-VQGAN | SBER-MOVQGAN 67M | SBER-MOVQGAN 102M | SBER-MOVQGAN 270M |
all | 30.136 | 18.4724 | 17.081 | 16.448 |
indoor | 44.686 | 26.509 | 24.274 | 23.164 |
kitchen | 36.579 | 22.508 | 21.019 | 20.164 |
appliance | 52.064 | 27.145 | 25.167 | 23.686 |
electronic | 50.447 | 29.417 | 27.08 | 25.743 |
furniture | 29.569 | 18.211 | 16.909 | 16.271 |
outdoor | 45.287 | 28.438 | 26.155 | 24.959 |
sports | 31.756 | 19.04 | 17.727 | 17.256 |
food | 41.413 | 25.001 | 23.221 | 22.754 |
vehicle | 26.463 | 17.095 | 15.616 | 15.101 |
animal | 32.078 | 21.217 | 19.777 | 18.745 |
accessory | 44.454 | 27.467 | 25.342 | 24.524 |
person | 15.484 | 9.48 | 8.753 | 8.528 |
face | 26.750 | 16.746 | 15.545 | 15.076 |
text | 21.148 | 12.883 | 11.834 | 11.392 |
Таблица 3. Сравнение аналогичных моделей и SBER-MoVQGAN по метрикам FID, SSIM, PSNR, L1 на наборе Imagenet. * означает, что метрики для этих моделей взяты напрямую из статей.
Модель | Latent size | Num Z | Compression | Train steps | FID ↓ | SSIM ↑ | PSNR ↑ | L1 ↓ |
ViT-VQGAN* 128 TPUv4 | 32x32 | 8192 | 8 | 500000 | 1.28 | - | - | - |
8x8x16 | 16384 | 32 | 10 epochs | 1.83 | - | - | - | |
Mo-VQGAN* 4 Tesla V100 | 16x16x4 | 1024 | 16 | 40 epochs | 1.12 | 0.6731 | 22.42 | - |
VQ CompVis | 32x32 | 16384 | 8 | 971043 | 1.34 | 0.6499 | 23.8469 | 0.0533 |
KL CompVis | 32x32 | - | 8 | 246803 | 0.9682 | 0.6918 | 25.1121 | 0.0474 |
SBER-VQGAN (from pretrain vqgan_gumbel_f8) | 32x32 | 8192 | 8 | 1 epoch | 1.4378 | 0.6816 | 24.3135 | 0.0503 |
SBER-MoVQGAN 67M | 32x32 | 1024 | 8 | 5M | 1.34 | 0.7041 | 25.6778 | 0.0451 |
SBER-MoVQGAN 67M | 32x32 | 16384 | 8 | 2M | 0.9647 | 0.7249 | 26.4485 | 0.0415 |
SBER-MoVQGAN 102M | 32x32 | 16384 | 8 | 2360k | 0.7764 | 0.7373 | 26.8887 | 0.0398 |
SBER-MoVQGAN 270M | 32x32 | 16384 | 8 | 1330k | 0.6858 | 0.7411 | 27.0370 | 0.0393 |
Таблица 4. Сравнение аналогичных моделей и SBER-MoVQGAN по метрикам FID, SSIM, PSNR и L1 на наборе FFHQ. * означает, что метрики для этих моделей взяты напрямую из статей.
Модель | Latent size | Num Z | Compression | Train steps | FID ↓ | SSIM ↑ | PSNR ↑ | L1 ↓ |
ViT-VQGAN* 128 TPUv4 | 32x32 | 8192 | 8 | 500000 | 3.13 | - | - | - |
16x16x4 | 2048 | 16 | 10 epochs | 3.88 | 0.7602 | 24.53 | - | |
Mo-VQGAN* 4 Tesla V100 | 16x16x4 | 1024 | 16 | 40 epochs | 2.26 | 0.8212 | 26.72 | - |
VQ CompVis | 32x32 | 16384 | 8 | 971043 | 2.9858 | 0.7648 | 27.2309 | 0.0330 |
KL CompVis | 32x32 | - | 8 | 246803 | 2.0428 | 0.8071 | 28.7890 | 0.0285 |
SBER-VQGAN (from pretrain vqgan_gumbel_f8) | 32x32 | 8192 | 8 | 1 epoch | 3.3081 | 0.7912 | 27.5646 | 0.0323 |
SBER-MoVQGAN 67M | 32x32 | 1024 | 8 | 5M | 2.5829 | 0.8063 | 28.9378 | 0.0284 |
SBER-MoVQGAN 67M | 32x32 | 16384 | 8 | 2M | 1.9980 | 0.8248 | 29.8859 | 0.0253 |
SBER-MoVQGAN 102M | 32x32 | 16384 | 8 | 2360k | 1.8245 | 0.8340 | 30.3245 | 0.0241 |
SBER-MoVQGAN 270M | 32x32 | 16384 | 8 | 1330k | 1.7592 | 0.8365 | 30.4576 | 0.0238 |
Также мы оценили время инференса всех трёх версий SBER-MoVQGAN, и заметили, что большая модель 270М практически не уступает маленькой 67M, поэтому именно её можно использовать в дальнейших проектах при обучении латентной диффузии. Результаты замеров времени инференса на батче 1 приведены в таблице 5.
Таблица 5. Время инференса SBER-MoVQGAN.
Модель | Inference time |
f=8, SBER-MoVQGAN 67M | 23,1 мс |
f=8, SBER-MoVQGAN 102M | 33,6 мс |
f=8, SBER-MoVQGAN 270M | 23,6 мс |
Таким образом, все три версии модели SBER-MoVQGAN от 67М до 270М являются новой SoTA в задаче реконструкции изображений (предыдущей SoTA моделью был MoVQ).
Примеры генераций SBER-MoVQGAN
В этой главе представлены примеры реконструкции изображений всеми версиями SBER-MoVQGAN, обученным на разрешении 256x256, в таких трудновосстанавливаемых доменах, как лица, текст и другие сложные сцены. Для каждой реконструкции были рассчитаны метрики MSE и MAE и построена разностная картинка между оригинальным изображением и восстановленным. И в результате видно, что значительно выросло качество реконструкции SBER-MoVQGAN относительно нашей прошлой модели SBER-VQGAN. Также несмотря на то, что SBER-MoVQGAN обучался на разрешении 256x256, эта модель также может работать и с другими разрешениями изображений. При этом чем оно выше, тем лучше получается качество восстановления картинок.
Домены лицо, человек, толпа:
Домен текст:
Домен сложные сцены:
Выводы и планы
В результате нам удалось в задаче реконструкции изображений обучить SoTA модель, которая теперь может успешно применяться в различных проектах: при обучении как диффузионных, так и мультимодальных моделей. SBER-MoVQGAN 67M был успешно внедрён в Kandinsky 2.1 и стал одним из блоков архитектуры, которые позволили существенным образом повысить качество генераций изображений по тексту. В будущем также планируем использовать лучшую версию SBER-MoVQGAN при обучении разрабатываемых генеративных моделей. Код и веса можно найти на Github и Huggingface.
Коллектив авторов: Арсений Шахматов, Анастасия Мальцева, Андрей Кузнецов, Денис Димитров.