Всем привет! Меня зовут Айбек Аланов. Я — аспирант факультета компьютерных наук ВШЭ, а также научный сотрудник группы «Вероятностные методы машинного обучения» AIRI. Сегодня мне хотелось бы поделиться с вами успехами, которых добилась наша научная группа в вопросе адаптации генеративно-состязательных сетей на новые домены.
Генеративно-состязательные сети
Генеративно-состязательной сетью (Generative adversarial network, GAN) называют алгоритм машинного обучения без учителя, который основан на комбинации двух нейронных сетей — генератора и дискриминатора, — настроенных на работу друг против друга. Первая генерирует новые образцы на основе исходных, задача второй — распознать, что это подделка. После каждого цикла генерации и распознавания, происходит обновление весов каждой сети на основе общей функции потерь, которая минимизируется генератором и максимизируется дискриминатором. Такая антагонистическая игра позволяет генератору все лучше и лучше подделывать образцы до такой степени, что к концу обучения они становятся неотличимы от реальных образцов.
Впервые такая модель была предложена легендой машинного обучения Яном Гудфеллоу в 2014 году. Менее, чем за декаду она показала себя эффективным инструментом для генерации и улучшения изображений, благодаря чему GAN активно используются для решения самых разных задач в компьютерном зрении: улучшения и изменения изображений, междоменных преобразований «image-to-image» и многого другого.
Проблема малой выборки
Сегодня, чтобы обучить генеративную модель на каком-то домене (то есть наборе схожих по признакам изображений), нам нужен доступ к большой выборке высокого качества из него. Под «большой» я понимаю «очень большой». Например, если мы хотим научиться генерировать реалистичные лица высоком разрешении, то нам нужно иметь датасет типа FFHQ, содержащий около 70 000 изображений, которые были специально для этого отобраны и имеют разрешение как минимум 1024×1024 пикселей.
Но такие большие датасеты встречаются не всегда. Скажем, для генерации мордочек котиков может помочь датасет AFHQ, но там их всего пять тысяч. Еще хуже обстоят дела с другими наборами данных. Например, в датасете MetFaces содержится всего чуть более тысячи лиц людей с картин разных художников, а в датасете FaceSketches, собранном из карандашных скетчей, всего около 300 объектов. Если мы попытаемся обучить наш GAN на таких маленьких выборках, результат будет неудовлетворительный.
Трансферное обучение
Но все не так плохо. В какой-то момент специалисты по машинному обучению поняли, что эту проблему можно решить, адаптируя нейросеть, обученную на домене, где есть много картинок, к узкому домену. Такой подход получил название трансферного обучения (transfer learning, TL).
В случае с лицами первоначальное обучение проводят на все том же датасете FFHQ. Обученный таким образом GAN можно дообучить на небольшом числе картин людей, и он начнет рисовать лица людей в стиле известных художников. Существуют также методы доменной адаптации, которые позволяют делать это на основе текстовых описаний или одного стилевого изображения.
Но проблема всех этих методов в том, что, чтобы адаптировать генератор, обычно обучают все веса модели. В своей работе мы рассматривали в основном модель StyleGAN2 — так вот для нее необходимо настраивать порядка 30 миллионов весов! На самом деле, это довольно избыточный труд в случае, если новый домен довольно близок к основному. Например, если вы всего лишь хотите превратить ваших друзей в персонажей аниме.
Мы с моими коллегами, Дмитрием Ветровым и Вадимом Титовым, поставили себе цель облегчить эту процедуру и уменьшить число обучаемых параметров. Для этого нам потребовалось глубже разобраться, как устроена архитектура генератора StyleGAN2.
Модифицируем схему работы StyleGAN2
Если кратко, эта модель устроена следующим образом. Случайный шум, который определяет наш выходной объект, подается с помощью специального стилевого вектора. Он получается из исходного шума путем нескольких преобразований и подается в каждый сверточный слой генератора. В слоях происходит модуляция входных каналов свертки путем умножения на компоненты этого вектора. Таким образом осуществляется контроль дисперсии каждого входного канала. Другими словами, стилевой вектор контролирует все семантические признаки выходного изображения: пол, возраст и другие свойства. Наконец, в слое происходит демодуляция, которая нормирует наши выходные каналы.
Мы решили внести в эту схему дополнительную операцию, а именно доменную модуляцию. Она заключается в введении дополнительного доменного вектора той же размерности, что и стилевой вектор. Его компоненты также домножаются на входные каналы, прежде чем происходит демодуляция. Доменная адаптация в данном случае заключается в том, что для нового домена нам достаточно дообучать только этот новый вектор, а не все веса модели, а его размерность — это всего лишь 6 тысяч, то есть в пять тысяч раз меньше, чем в традиционном подходе.
В своей недавней статье, опубликованной в сборнике трудов конференции NeurIPS 2022, мы показали, что этого действительно достаточно, чтобы адаптировать предобученную StyleGAN2 на самые разные домены. Вы можете сами убедиться, что оба подхода визуально показывают одинаковый результат для адаптации по текстовому описанию:
Для one-shot-адаптации — то есть адаптации по одному изображению — результат так же хороший (на примере модели MindTheGap):
Мульти-доменная адаптация с помощью гиперсети
Вдохновившись таким радикальным снижением пространства параметров, мы задались вопросом, а нельзя ли обучить дополнительную сеть, которая автоматически выдавала бы нам доменный вектор под каждый домен? Другими словами, можно ли решить таким способом задачу мульти-доменной адаптации GAN?
Оказалось, что можно. Для этого мы разработали гиперсеть, которая получила название HyperDomainNet. На ее вход мы подаем эмбеддинг текста, который описывает целевой домен —его можно получить с помощью дополнительного предобученного энкодера. На выходе гиперсеть выдает соответствующие доменные векторы. В результате GAN рисует изображения в каждом домене, который был использован в обучении гиперсети.
Такой подход дает результаты не хуже, чем если бы мы дообучали модель на каждый домен по-отдельности.
В ходе работы с гиперсетью мы обнаружили еще один необычный эффект: когда число доменов, на которых она обучается, становится достаточно велико, алгоритм смог адаптировать GAN к новым доменам (unseen domains), то есть к доменам, которые не были представлены в обучении.
Если вам интересно позапускать нашу модель, можете зайти на наш репозиторий. Там есть ссылка на наш Colab, где вы можете использовать наши предобученные модели на своих фотографиях, либо самостоятельно дообучить StyleGAN2 на новых доменах с помощью малого числа параметров.
На этом все. Если есть вопросы, с удовольствием отвечу на них в комментариях!