Pull to refresh
840.25
Сбер
Больше чем банк

Как мы тестировали и дообучали одну из самых хайповых разработок года

Reading time5 min
Views5.6K

 

Всем привет!

Всё началось с того, что мы в Sber AI решили немного поизучать/почитать подробнее про хайповую нейронную сеть DALL·E и понять её потенциал возможностей, а также в чём заключается боттлнек – что же мешает генерить картинки хорошего качества и как можно попытаться улучшить работу модели?

Все автоэнкодеры важны, все автоэнкодеры нужны

О том, что такое автоэнкодеры (AE, VAE, CVAE и др.), уже много писали на Хабре и не только (можно почитать здесь, здесь и здесь), так что углубляться не будем. В двух словах: это нейронная сеть, состоящая из двух частей, – энкодера и декодера. Энкодер сжимает входной сигнал (например картинку) в какое-то скрытое состояние меньшей размерности. Декодер переводит скрытое состояние в исходный сигнал.

Таким образом, есть возможность «обучить» свой архиватор. Сами автоэнкодеры уже существуют давно, ими никого не удивишь. Однако фантазия исследователей позволяет и по сей день ставить занимательные эксперименты. Так, например, идея работы DALL·E пришла после создания так называемого VQ-VAE и последующего его улучшения в VQ-VAE 2, а потом ещё и в VQ-GAN).

Причём тут вообще DALL-E?

Краткая вводная для тех, кто ещё не в курсе. DALL·E – это нейронная сеть, которая создаёт изображения, основываясь на текстовом описании на естественном языке. Например, вот что она может создать по тексту «a snail made of harp» (улитка, сделанная из арфы).

Принцип работы DALL·E устроен довольно просто, его можно сравнить с хорошо известной генерацией текста «Sequence to Sequence» с помощью Transformer. Для генерации вместо привычных текстовых токенов используются токены пространства codebook, которые отображают скрытое дискретное состояние изображений после процедуры энкодинга VAE. В качестве трансформера используется одна из версий GPT-3 на 12 миллиардов параметров, а подают в него эмбеддинги текстового запроса на естественном языке и эмбеддинги codebook текущего состояния генерации изображения, конкатенированные между собой. Затем сгенерённая последовательность codebook подаётся в декодинг VAE, который восстанавливает изображение.

Таким образом, очевидным bottleneck’ом для данной архитектуры является автоэнкодер если он не способен восстановить исходное изображение процедурой encode-decode с приемлемым качеством, то и DALL·E никогда не сможет сгенерить изображение хорошего качества.

Мы решили проверить качество на различных доменах некоторых опубликованных дискретных VAE и убедиться, что они действительно справляются со своей задачей.

А чем оригинальный энкодер не угодил?

Как уже было сказано выше, есть как хорошо генерируемые домены, так и плохо генерируемые. Так вот, мы решили сделать подборку изображений из датасета COCO по некоторым субъективно выбранным проблемным доменам. Затем с помощью этой подборки оценили качество на моделях:

  • “16384”: VQGAN ImageNet (f=16), 16384

  • “VAE”: DALL-E dVAE (f=8), 8192, GumbelQuantization

  • “gumbelf8”: VQGAN OpenImages (f=8), 8192, GumbelQuantization

  • “SBER-gumbelf8”: VQGAN SberData (f=8), 8192, GumbelQuantization

Последний из списка мы дообучали на собственных собранных данных. Стоит отметить, что все выбранные модели не видели датасет COCO в процессе обучения, и оценка на этом датасете достаточно честная и независимая. Естественно, для улучшения качества наши данные были дополнены примерами из этих доменов. Мы лишь дофайнтюнили предобученные модели из оригинального репозитория на одну эпоху с warmup learning rate до 1.5e-06.

Вы всё ещё обучаете с нуля? Тогда мы идём к вам!

Для оценки качества восстановленных изображений мы использовали метрики Inception Score (IS) и Fréchet inception distance (FID), к сожалению, первая статья по FID куда-то пропала, поэтому предлагаем посмотреть такой вариант.

Что такое IS и как его посчитать, можно почитать тут или тут. Если вкратце, то для расчёта используется всем известная модель нейронной сети Inception v3. С помощью неё оценивается вероятность принадлежности изображений каждому из 1000 классов. Затем все распределения вероятностей по одинаковым классам суммируются, считается KL-дивергенция между этими распределениями каждого индивидуального класса и средним общим распределением. Оценка принимает значения от 1 до N, где N – количество классов. Чем больше значения вероятностей принадлежности к каким-то классам и чем больше разнообразие этих классов, тем больше IS.

Похожая ситуация и с FID, подробное описание и расчёт тут тоже есть. Для FID, однако, уже необходим датасет, с которым можно сравнить сгенерённые картинки, в нашем случае это просто набор исходных картинок. Снова используется та же модель Inception, но уже без головы, и сравниваются распределения признаков оригинальных и сгенерённых изображений. Чем ближе распределения, тем больше наши картинки похожи на настоящие.

Дообучение на одну эпоху дало нам следующие результаты, поле all рассчитано как среднее взвешенное всех категорий на их количество:

IS (больше – лучше):

domain/model

VAE

16384

gumbelf8

SBER-gumbelf8

Original

all

11.133

13.647

15.203

15.316

15.278

indoor

9.769

10.744

11.707

11.688

11.638

kitchen

9.726

11.354

12.333

12.152

11.813

appliance

5.705

6.024

6.154

6.199

5.890

electronic

7.830

9.509

9.712

9.606

9.497

furniture

10.861

13.346

14.500

14.531

14.592

outdoor

8.163

9.520

10.668

10.293

10.451

sports

7.467

8.544

8.814

8.841

8.962

food

7.954

8.725

9.390

9.434

9.191

vehicle

10.527

12.947

14.240

14.559

14.233

animal

11.933

14.249

15.999

15.879

15.857

accessory

9.399

11.687

13.117

13.388

13.228

person

13.752

17.794

20.048

20.420

20.600

face

11.903

14.987

16.986

17.489

17.584

text

14.902

18.457

21.396

21.292

21.131

FID (меньше – лучше):

domain/model

VAE

16384

gumbelf8

SBER-gumbelf8

all

59.753

38.912

30.304

30.136

indoor

74.734

57.925

45.432

44.686

kitchen

66.424

47.086

36.735

36.579

appliance

80.359

70.604

53.225

52.064

electronic

77.856

64.034

50.759

50.447

furniture

53.438

38.204

29.510

29.569

outdoor

91.932

58.877

46.309

45.287

sports

65.540

39.961

32.219

31.756

food

76.974

53.109

41.018

41.413

vehicle

60.318

34.259

26.721

26.463

animal

64.250

41.520

32.039

32.078

accessory

79.843

56.311

44.660

44.454

person

40.810

20.430

15.523

15.484

face

54.153

34.109

26.663

26.750

text

47.299

27.656

21.303

21.148

Также мы посмотрели, как Inception Score оценивает оригинальные изображения (без процедуры encode-decode). Интересно, что для данной метрики оригиналы выглядят менее естественными, чем то, что сгенерили некоторые автоэнкодеры (!). Вероятнее всего, мы уже упёрлись в возможности данной метрики, оценка с её помощью уже становится менее объективной. Предлагаем «дедовскими» методами (глазами) посмотреть на генерации. И вот что получилось с проблемными доменами:  визуально качество заметно улучшилось для SBER-gumbelf8:

«Текст»

«Несколько людей»

«Лица»

Не моё, а общее

Мы хотим поделиться с вами тем, что у нас получилось. А именно – дообученными чекпоинтами и примерами кода, который поможет вам использовать эти модели. Всё это есть на Google-диске и в колабе, так что приятного использования!  

Колаб с инференсом моделек.

Колаб с расчётом метрик.

Общая папка с моделями и ноутбуками.

Ссылка на гитхаб

На самом деле очень интересно, какие примеры у вас получатся. Спасибо, что дочитали до конца, и вперёд, к новым генерациям!

Команда: @denis_karachev, @shonenkov, @dendimitrov

Tags:
Hubs:
Total votes 4: ↑3 and ↓1+3
Comments6

Information

Website
www.sber.ru
Registered
Founded
Employees
over 10,000 employees
Location
Россия