Вариационные автоэнкодеры в квантованном векторном пространстве стали довольно популярными в последние несколько лет и успешно применяются в широком спектре генеративных задач (Stable Diffusion, VQ Diffusion, VideoGPT и др.). VQVAE позволяет сжимать картинку в латентное пространство меньшей размерности, а затем восстанавливать это латентное представление изображения в RGB-состояние. Операции в латентном пространстве выполняются быстро, поэтому VQVAE получил широкое применение как в авторегрессионных мультимодальных архитектурах (DALLE, ruDALL-E, RUDOLPH), так и в диффузионных моделях (DALL-E 2, Kandinsky 2.1, Latent Diffusion). В первом случае вариационный автоэнкодер позволяет закодировать картинку в последовательность визуальных токенов, которые вместе с текстовыми токенами используются в обучении трансформера. Во втором случае VQVAE кодирует картинку в квантованное пространство малой размерности, позволяя выполнять диффузионный процесс в латентном пространстве (ввиду того, что диффузионный процесс является итеративным и скорость генерации напрямую зависит от числа шагов диффузии, вычислительная сложность каждого шага очень важна), который в сравнении с пиксельной диффузией выполняется быстрее и потребляет меньше памяти. 

Во всех перечисленных задачах качество генерации напрямую зависит от качества восстановления исходных картинок с помощью VQVAE. Пару лет назад мы уже проводили эксперименты и обучали SBER-VQGAN, который на тот момент давал лучшие результаты в сравнении c dVAE и ванильным VQGAN. Подробнее об этих экспериментах можно прочитать в статье на Хабре. Однако по-прежнему нам не хватало качества восстановления в сложных доменах, таких как текст и лица, поэтому мы попытались модифицировать и улучшить SBER-VQGAN, в результате получив SoTA среди моделей по кодированию изображений.

Топ решений на сегодня

В 2022 году вышло несколько статей по улучшению VQGAN, среди которых ViT-VQGAN, RQ-VAE и MoVQ

ViT-VQGAN использует в энкодере/декодере Vision Transformer (ViT) вместо CNN. Энкодер ViT-VQGAN переводит непересекающиеся патчи картинок 8x8 в визуальные токены, которые затем подаются в трансформерные блоки. Декодер ViT-VQGAN переводит обратно визуальные токены из латентного пространства в патчи 8x8, формируя итоговую картинку. В статье описываются два нововведения: факторизация кодов (линейная проекция выхода энкодера размерностью 768 в 32-мерное латентное представление кодов) и L2 нормализация кодов (нормализация закодированных скрытых переменных z_e(x) и скрытых переменных кодовой книги e). В результате экспериментов ViT-VQGAN показал лучшие FID и IS в сравнении с стандартным VQGAN.

RQ-VAE вместо обычной квантизации использует residual квантизацию, которая позволяет улучшить качество восстановления картинки без увеличения размера кодовой книги. Суть её заключается в том, что квантизация латентного пространства происходит рекурсивно с заданной глубиной D. На первом шаге квантизуется латентное пространство и вычитается из исходного. На следующих шагах рекурсивно квантизуется уже остаток от латентного пространства. Таким образом, residual квантизация позволяет закодировать как более высокоуровневые представления, и так и более низкоуровневые детали. На каждой глубине квантизации используется одна и та же кодовая книга. Эксперименты показали, что RQ-VAE значительно превосходит VQGAN по FID в задаче восстановления изображений. Причём качество восстановления сильно зависит от глубины D: чем она больше, тем лучше метрики.

В статье MoVQ предлагается использовать spatially conditional нормализацию в блоках декодера, которая позволяет повысить реалистичность восстанавливаемых изображений. Spatially conditional нормализация помогает избежать возникновения артефактов, которые появляются из-за процесса квантования близких эмбеддингов в одинаковые индексы кодовой книги. Её использование добавляет степеней свободы квантованным представлениям после энкодера и позволяет распространять по слоям декодера более вариативные представления эмбеддингов. Также авторы предлагают использовать многоканальное представление латентного пространства, разделяя его на несколько частей по каналам и кодируя каждую часть по отдельности. Эти нововведения позволяют сократить размер кодовой книги с 16384 до 1024. MoVQGAN значительно превосходит описанные ранее подходы в задаче восстановления изображений по метрикам FID, SSIM, PSNR и LPIPS. 

Проанализировав новые подходы к реконструкции изображений, мы решили реализовать MoVQGAN с некоторыми модификациями. Во-первых, MoVQGAN оказался SoTA архитектурой среди аналогов, при этом является самой легковесной и быстрой моделью и требует небольших изменений исходного кода. Во-вторых, трансформерная архитектура ViT-VQGAN (самая близкая по метрикам) требовала намного больше ресурсов и времени на обучение, при этом привязана к размеру подаваемой картинки, что сильно ограничивает применимость такого визуального энкодера в других задачах. 

Архитектура SBER-MoVQGAN

Архитектура SBER-MoVQGAN основана на архитектуре VQGAN с добавлением spatially conditional нормализации из статьи MoVQ. Среди других важных особенностей:

  •  встроены EMA веса (exponential moving average); 

  •  изменены лоссы на этапе обучения (подробно расскажем ниже). 

За основу SBER-MoVQGAN взяли код из оригинального репозитория VQGAN. Код для обучения и инференса SBER-MoVQGAN можно найти на Github.

Spatially conditional нормализация реализована подобно AdaIN слоям (Adaptive Instance Normalization), которые применяются в архитектуре StyleGAN. AdaIN слои позволяют выполнять style transfer с помощью нормализации исходного вектора и добавления scale и bias вектора стиля. Однако в сравнении с этим методом spatially conditional карта признаков представлена как квантованная карта, которая, в свою очередь, содержит выученные компактные представления данных. Spatially conditional нормализация рассчитывается по формуле (1):

где Fi-1 — промежуточная карта признаков, μ(∙) и σ(∙) — функции расчёта среднего и стандартного отклонения (в качестве нормализации используется GroupNorm), φγ() и φβ() — обучаемые аффинные преобразования, которые реализованы как свёртки 1x1, предназначенные для преобразования zq в значения scale и bias. Таким образом, эта нормализация позволяет встраивать пространственно-вариативную информацию, при этом одинаковые квантованные представления генерируют правдоподобные и разнообразные результаты. На рисунке 1 представлена схема встраивания spatially conditional нормализации в блоки декодера.

Рис. 1. Spatially conditional-нормализация [1]

Стандартная функция потерь (лосс) для обучения VQGAN, используемая в официальном репозитории, сходится плохо, что не позволяет восстанавливать изображения в хорошем качестве. В ходе экспериментов было замечено, что функция потерь, описанная в статье ViT-VQGAN, даёт лучшие результаты, поэтому его модификация использовалась при обучении SBER-MoVQGAN. Функция потерь рассчитывается по формуле (2):

где LVQ — codebook loss, LAdv — adversarial loss, Lperc — perceptual loss, — reconstruction loss, LNLL — усреднённое значение (0,1Lperc+1,0L2). В оригинальном репозитории использовался адаптивный вес для LAdv и LNLL, однако в экспериментах этот подход показал плохую сходимость, и в результате мы от него отказались.

Добавление EMA весов при обучении SBER-MoVQGAN значительно повысило качество генерации: происходит расчёт экспоненциального скользящего среднего по всем обучаемым параметрам модели. SBER-MoVQGAN обучался с параметром ema_decay 0,9999. Это означает, что во время обучения на каждой итерации при обновлении весов модели сохраняется 99,99 % от предыдущих весов и обновляется только 0,01 % весов. Таким образом, EMA веса обладают лучшей обобщающей способностью, позволяют генерировать более стабильные результаты и борются с потенциальным переобучением модели. Однако их использование приводит к увеличению времени обучения. EMA веса успешно применяются при обучении диффузионных моделей и трансформеров, и в итоге также показали высокую эффективность при обучении SBER-MoVQGAN.

Выше были описаны основные изменения относительно исходного кода VQGAN, которые позволили повысить качество восстановления и обучить SoTA модель. Однако помимо этих изменений были проведены эксперименты с изменением дискриминатора, добавлением L2 нормализации кодов в квантизаторе, изменением текущего квантизатора на Gumbel квантизатор, уменьшением размера словаря кодовой книги и добавлением свёрточного слоя в spatially conditional нормализацию, но в результате они показали качество восстановления изображений чуть хуже, поэтому не будем на них останавливаться.

Версии SBER-MoVQGAN и подробности обучения

Мы обучили три версии SBER-MoVQGAN — 67M, 102M, 270M. Маленькая версия 67M по размерам получилась такой же, как и стандартный VQGAN. Разница между ними лишь в архитектурных изменениях, описанных выше. А вот с моделями 102M и 270M мы решили поэкспериментировать и проверить, как увеличение параметров влияет на качество восстановления изображений. Модель 102M использует в два раза больше Residual блоков в сравнении с 67M, а модель 270M работает с количеством каналов в два раза больше исходного. Маленькую модель мы обучали на 1xA100 в течение двух недель, и за это время она вышла на плато. Большие модели мы обучали на 8xA100. Версия 102M вышла на плато за четыре недели, а версия 270M — в течение двух недель на 8xA100, и затем в течение 10 дней — на 32xA100. Более подробная информация представлена в таблице 1. Полученные результаты и метрики трёх версий SBER-MoVQGAN будут описаны ниже. Все модели обучались на наборе LAION HighRes.

Таблица 1. Информация о моделях SBER-MoVQGAN.

Модель

Архитектура

GPU

Train steps

SBER-MoVQGAN 67M

2 Residual-блока x channels

1xA100

2M

SBER-MoVQGAN 102M

4 Residual-блока

8xA100

2,3M

SBER-MoVQGAN 270M

x2 channels

8xA100 и 32xA100

920K и 410K

Эксперименты

SBER-MoVQGAN тестировался двумя способами. В первом случае рассчитывалась метрика FID по доменам на наборе MS COCO и сравнивалось с ранее обученным SBER-VQGAN. Подробное описание распределения COCO набора по доменам и прошлые эксперименты описаны в нашей статье. В таблице 2 представлены метрики трёх версий SBER-MoVQGAN, и в результате показано, что все три версии значительно обходят предыдущую нашу модель.

Во втором случае оценивалось качество модели по метрикам FID, SSIM, PSNR и L1 и выполнялось сравнение с аналогичными моделями на наборе Imagenet. Подробнее про метрики SSIM и PSNR можно прочитать здесь, про FID — в этой статье. Отметим только, что FID сравнивает распределения реальных и сгенерированных изображений. Если картинки полностью совпадают, то FID равен 0. SSIM — это метрика оценки изображений по трём параметрам: яркости, контрастности и структуре. Она принимает значения от -1 до 1, при этом чем выше значение, тем ниже искажения изображения и выше качество. PSNR — метрика, показывающая пиковое отношение сигнал/шум. PSNR определяет уровень искажений при сжатии и включает в себя подсчёт среднеквадратичной ошибки (MSE); диапазон принимаемых значений от 0 до 100. Метрика L1 сравнивает изображения попиксельно и показывает среднюю абсолютную ошибку. Если картинки полностью идентичны, она будет равна 0. 

В таблицах 3 и 4 приведено сравнение моделей ViT-VQGAN, RQVAE, MoVQGAN, VQGAN, KL-VAE, SBER-VQGAN и SBER-MoVQGAN по метрикам FID, SSIM, PSNR и L1 на наборах Imagenet и FFHQ. Для первых трёх моделей значения метрик взяты напрямую из статей, для остальных мы сами посчитали (для моделей VQGAN и KL-VAE чекпоинты были взяты из репозитория latent-diffusion). В результате все три версии SBER-MoVQGAN превосходят описанные модели по качеству реконструкции изображений.

Таблица 2. Оценка моделей по метрике FID на наборе MS COCO, распределённом по доменам (чем меньше значение, тем лучше).

Домен/модель

SBER-VQGAN

SBER-MOVQGAN 67M

SBER-MOVQGAN 102M

SBER-MOVQGAN 270M

all

30.136

18.4724

17.081

16.448

indoor

44.686

26.509

24.274

23.164

kitchen

36.579

22.508

21.019

20.164

appliance

52.064

27.145

25.167

23.686

electronic

50.447

29.417

27.08

25.743

furniture

29.569

18.211

16.909

16.271

outdoor

45.287

28.438

26.155

24.959

sports

31.756

19.04

17.727

17.256

food

41.413

25.001

23.221

22.754

vehicle

26.463

17.095

15.616

15.101

animal

32.078

21.217

19.777

18.745

accessory

44.454

27.467

25.342

24.524

person

15.484

9.48

8.753

8.528

face

26.750

16.746

15.545

15.076

text

21.148

12.883

11.834

11.392

Таблица 3. Сравнение аналогичных моделей и SBER-MoVQGAN по метрикам FID, SSIM, PSNR, L1 на наборе Imagenet. * означает, что метрики для этих моделей взяты напрямую из статей.

Модель

Latent size

Num Z

Compression

Train steps

FID ↓

SSIM ↑

PSNR ↑

L1

ViT-VQGAN* 128 TPUv4

32x32

8192

8

500000

1.28

-

-

-

RQ-VAE*

8x8x16

16384

32

10 epochs

1.83

-

-

-

Mo-VQGAN* 4 Tesla V100

16x16x4

1024

16

40 epochs

1.12

0.6731

22.42

-

VQ CompVis

32x32

16384

8

971043

1.34

0.6499

23.8469

0.0533

KL CompVis

32x32

-

8

246803

0.9682

0.6918

25.1121 

0.0474

SBER-VQGAN (from pretrain vqgan_gumbel_f8)

32x32

8192

8

1 epoch

1.4378

0.6816

24.3135

0.0503

SBER-MoVQGAN 67M 

32x32

1024

8

5M

1.34

0.7041

25.6778

0.0451

SBER-MoVQGAN 67M

32x32

16384

8

2M

0.9647

0.7249

26.4485

0.0415

SBER-MoVQGAN 102M

32x32

16384

8

2360k

0.7764

0.7373

26.8887

0.0398

SBER-MoVQGAN 270M

32x32

16384

8

1330k

0.6858

0.7411

27.0370

0.0393

Таблица 4. Сравнение аналогичных моделей и SBER-MoVQGAN по метрикам FID, SSIM, PSNR и L1 на наборе FFHQ. * означает, что метрики для этих моделей взяты напрямую из статей.

Модель

Latent size

Num Z

Compression

Train steps

FID ↓

SSIM ↑

PSNR ↑

L1 ↓

ViT-VQGAN* 128 TPUv4

32x32

8192

8

500000

3.13

-

-

-

RQ-VAE*

16x16x4

2048

16

10 epochs

3.88

0.7602

24.53

-

Mo-VQGAN* 4 Tesla V100

16x16x4

1024

16

40 epochs

2.26

0.8212

26.72

-

VQ CompVis

32x32

16384

8

971043

2.9858

0.7648

27.2309

0.0330

KL CompVis

32x32

-

8

246803

2.0428

0.8071

28.7890

0.0285

SBER-VQGAN (from pretrain vqgan_gumbel_f8)

32x32

8192

8

1 epoch

3.3081

0.7912

27.5646

0.0323

SBER-MoVQGAN 67M 

32x32

1024

8

5M

2.5829

0.8063

28.9378

0.0284

SBER-MoVQGAN 67M

32x32

16384

8

2M

1.9980

0.8248

29.8859

0.0253

SBER-MoVQGAN 102M

32x32

16384

8

2360k

1.8245

0.8340

30.3245

0.0241

SBER-MoVQGAN 270M

32x32

16384

8

1330k

1.7592

0.8365

30.4576

0.0238

Также мы оценили время инференса всех трёх версий SBER-MoVQGAN, и заметили, что большая модель 270М практически не уступает маленькой 67M, поэтому именно её можно использовать в дальнейших проектах при обучении латентной диффузии. Результаты замеров времени инференса на батче 1 приведены в таблице 5. 

Таблица 5. Время инференса SBER-MoVQGAN.

Модель

Inference time

f=8, SBER-MoVQGAN 67M

23,1 мс

f=8, SBER-MoVQGAN 102M

33,6 мс

f=8, SBER-MoVQGAN 270M

23,6 мс

Таким образом, все три версии модели SBER-MoVQGAN от 67М до 270М являются новой SoTA в задаче реконструкции изображений (предыдущей SoTA моделью был MoVQ). 

Примеры генераций SBER-MoVQGAN

В этой главе представлены примеры реконструкции изображений всеми версиями SBER-MoVQGAN, обученным на разрешении 256x256, в таких трудновосстанавливаемых доменах, как лица, текст и другие сложные сцены. Для каждой реконструкции были рассчитаны метрики MSE и MAE и построена разностная картинка между оригинальным изображением и восстановленным. И в результате видно, что значительно выросло качество реконструкции SBER-MoVQGAN относительно нашей прошлой модели SBER-VQGAN. Также несмотря на то, что SBER-MoVQGAN обучался на разрешении 256x256, эта модель также может работать и с другими разрешениями изображений. При этом чем оно выше, тем лучше получается качество восстановления картинок. 

Домены лицо, человек, толпа:

Домен текст:

Домен сложные сцены:

Выводы и планы

В результате нам удалось в задаче реконструкции изображений обучить SoTA модель, которая теперь может успешно применяться в различных проектах: при обучении как диффузионных, так и мультимодальных моделей. SBER-MoVQGAN 67M был успешно внедрён в Kandinsky 2.1 и стал одним из блоков архитектуры, которые позволили существенным образом повысить качество генераций изображений по тексту. В будущем также планируем использовать лучшую версию SBER-MoVQGAN при обучении разрабатываемых генеративных моделей. Код и веса можно найти на Github и Huggingface.

Коллектив авторов: Арсений Шахматов, Анастасия Мальцева, Андрей Кузнецов, Денис Димитров.