Как стать автором
Обновить

Улучшения для генеративно-состязательных сетей (GAN)

Время на прочтение8 мин
Количество просмотров3.8K
Для прикладных задач, редко когда требуется искусственная генерация данных. Тем не менее алгоритм состязательной-генеративной модели (GAN) поражает и даёт возможность создавать сервисы рисования и даже фотографию не существующего человека.
На Хабре есть несколько статей разбора алгоритма с теоретической точки зрения. Здесь я бы хотел сконцентрироваться на коде, а именно заострить внимание на улучшениях и трюках, которые сделают процесс обучения быстрее, более контролируемым и улучшают качество генерируемых примеров.

За основу взят пример из документации tensorflow и код на keras. Они оба отлично работают, но как убедитесь ниже, не идеальны. Идеи для улучшения в основном взяты из статьи Tips for GAN и из исследования по улучшению GAN.

Подготовка


Для работы потребуется всего 2 библиотеки: numpy и tensorflow и несколько библиотек для отрисовки GIF изображения прямо в jupyter notebook.

Зависимости


#python 3.6
!pip install tensorflow==2.0.0
!pip install numpy==1.18.1
!pip install imageio==2.6.1
!pip install matplotlib==3.1.2
!pip install tqdm==4.41.1


Импорты


import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, BatchNormalization, LeakyReLU, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Activation, Concatenate
from tensorflow.keras.layers import Dense, Flatten, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from tqdm.notebook import tqdm
#библиотеки нужные для отображения GIF
import imageio
from IPython.display import Image
import matplotlib.pyplot as plt
import os

Загрузка данных


DEPTH = 10
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = train_images / 255 # нормализуем в [0, 1]
train_labels = np.array(to_categorical(train_labels, DEPTH))

Универсальный генератор данных


Улучшения:

  • в источниках говориться, что случайный вектор лучше генерировать с поверхности сферы (т.е. нормализовать), но опыт показал, что это увеличивает продложительность обучения;
  • ярлыки изображений также потребуются. Это полезная информация и её ни в коем случае нельзя выбрасывать;
  • сделаем регулируемым размер батчей в эпохе;
  • с помощью tqdm добавим визуализацию прогресса эпохи;
  • изображения из массива будут браться случайно;

RANDOM_DIM = 100

class DataGenerator():
    def __init__(self, train_images, train_labels, batches_per_epoch, batch_size):
        self.train_images = train_images
        self.train_labels = train_labels
        self.batches_per_epoch = batches_per_epoch
        self.batch_size = batch_size
    
    @staticmethod
    def rand_norm(npoints=1, ndim=RANDOM_DIM):
        rand_vec = np.random.normal(0, 1, size=[npoints, ndim])
        # проекция на поверхность сферы
        # rand_vec = rand_vec / np.sqrt(np.sum(np.square(rand_vec), axis=1))[:, np.newaxis]
        return rand_vec

    def __len__(self):
        return self.batches_per_epoch
    
    def batch(self):
        rand_images_indexes = np.random.randint(0, train_images.shape[0], size=self.batch_size)
        image_batch = train_images[rand_images_indexes]
        labels_batch = train_labels[rand_images_indexes]
        return image_batch, labels_batch
    
    def __iter__(self):
        for b in tqdm(range(self.batches_per_epoch), leave=False):
            yield self.batch()
            
    def rand_batch(self):
        rand_vec = self.rand_norm(self.batch_size)
        rand_labels = np.random.randint(0, 10, size=[self.batch_size])
        rand_labels = np.array(to_categorical(rand_labels, DEPTH))
        return [rand_vec, rand_labels]

np.random.seed(42)
images_indexes = np.random.randint(0, len(train_images), size=16)
PICS_FROM_DATASET = [train_images[images_indexes], train_labels[images_indexes]]
FIXED_NOISE = [DataGenerator.rand_norm(16), train_labels[images_indexes]]

Генератор


Улучшения:

  • одна переменная complexity определяющая количество обучаемых переменных, и как следствие, упрощение балансировки генератора и дискриминатора;
  • для первого слоя инициализация весов производится с меньшим значением дисперсии;
  • для генератора в каждом слое применяется нормализация батча;
  • вместо классических свёрточных слоёв используются Conv2DTranspose;
  • опытным путём было определено, что активация ‘sigmoid’ в выходном слое работает, лучше чем ‘tanh', но возможно это зависит от типа данных или размера сети;
  • оптимайзер для генератора sgd;

def make_generator():
    complexity = 80
    alpha = 0.2
    
    random_vector = Input(shape=[RANDOM_DIM])
    labels = Input(shape=[DEPTH])
    
    X = Concatenate()([random_vector, labels])
    # 1 слой
    X = Dense(7 * 7 * int(complexity/2),
                          input_dim=RANDOM_DIM,
                          kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02))(X)
    X = BatchNormalization()(X)
    X = LeakyReLU(alpha)(X)
    
    X = Reshape((7, 7, int(complexity/2)))(X)
    
    # 2 слой
    X = Conv2DTranspose(int(complexity), 
                        kernel_size=(5, 5),
                        strides=(2, 2),
                        padding='same')(X)
    X = BatchNormalization()(X)
    X = LeakyReLU(alpha)(X)

    # 3 слой
    X = Conv2DTranspose(int(complexity),
                        kernel_size=(5, 5),
                        strides=(1, 1),
                        padding='same')(X)
    X = BatchNormalization()(X)
    X = LeakyReLU(alpha)(X)
    
    # 4 слой
    X = Conv2DTranspose(int(complexity),
                        kernel_size=(5, 5),
                        strides=(1, 1),
                        padding='same')(X)
    X = BatchNormalization()(X)
    X = LeakyReLU(alpha)(X)
    
    # 5 слой
    X = Conv2DTranspose(int(complexity),
                        kernel_size=(5, 5),
                        strides=(1, 1),
                        padding='same')(X)
    X = BatchNormalization()(X)
    X = LeakyReLU(alpha)(X)
    
    # 6 слой
    X = Conv2DTranspose(1,
                        kernel_size=(5, 5),
                        strides=(2, 2),
                        padding='same')(X)
    X = Activation('sigmoid')(X)
    
    model = Model(inputs=[random_vector, labels], outputs=X, name='generator')
    model.compile(loss='binary_crossentropy',
                  metrics=['acc'],
                  optimizer='sgd')
    model.summary()
    return model
generator = make_generator()

Тест генератора


Можно не вдаваться в подробности как устроена эта функция. Главное, что она рисует и сохраняет в отдельную папку 16 сгенерированных изображений и 16 изображений из набора данных.


def mkdir_p(path):
    try:
        os.makedirs(path)
    except OSError:
        print('Folder already exists')

def plot_and_save_images(generated_images, folder_name, title, cols, rows, figsize, show=False, subtitles=False):
    fig, axs = plt.subplots(rows, cols, constrained_layout=True)
    fig.set_figheight(figsize[0])
    fig.set_figwidth(figsize[1])
    for i in range(rows):
        for j in range(cols):
            axs[i][j].imshow(generated_images[i][j], interpolation='nearest', cmap='gray_r')
            axs[i][j].axis('off')
            if not subtitles is None:
                if i == 0 and j == 0:
                    axs[i][j].set_title(subtitles[0], fontsize=10)
                if i == int(rows/2) and j == 0:
                    axs[i][j].set_title(subtitles[1], fontsize=10)
    
    fig.suptitle(title, fontsize=12)
    plt.savefig(os.path.join(folder_name, title + '.png'))
    if show:
        plt.show()
    plt.close()


def plot_generated_images(noise, folder_name, title, cols=8, rows=4, figsize=(4, 8), show=False, subtitles=('generated', 'dataset')):
    generated_images = generator.predict(noise)
    images = np.concatenate([generated_images, PICS_FROM_DATASET[0]])
    images = images.reshape(rows, cols, 28, 28)
    plot_and_save_images(images, folder_name, title, cols, rows, figsize, show, subtitles)
    

folder_name = 'GAN_pics_for_gif'
mkdir_p(os.path.join(folder_name))

title = f"DCGAN training process {0} epochs"
plot_generated_images(FIXED_NOISE, folder_name, title, show=True)

Дискриминатор


Для дискриминатора потребуется оптимайзер adam с уменьшенными значениями для параметров learning_rate и beta_1.


def optimizer():
    return tf.keras.optimizers.Adam(lr=0.0002, beta_1=0.5)

Улучшения:
  • одна переменная complexity определяющая количество обучаемых переменных => упрощение балансировки генератора и дискриминатора;
  • дискриминатор угадывает не только является ли цифра сгенерированной, но и метку изображения;
  • для дискриминатора не применяется нормализация батча;
  • для промежуточных слоёв функция активации LeakyReLU c alpha=0.2 (LeakyReLU показывается себя лучше чем 'ReLU' и 'elu');

Для баланса дискриминатор (скорее всего) должен быть меньше генератора.


def make_discriminator():
    complexity = 70
    drop_rate = 0.2
    alpha = 0.2
    
    inp = Input(shape=(28, 28, 1))
    X = inp
    # 1 слой
    X = Conv2D(int(complexity/4),
               kernel_size=(5, 5),
               strides=(2, 2),
               kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
               padding='same')(X)
    X = LeakyReLU(alpha)(X)
    X = Dropout(drop_rate)(X)
    
    # 2 слой
    X = Conv2D(int(complexity/2),
               kernel_size=(5, 5),
               strides=(2, 2),
               #kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
               padding='same')(X)
    X = LeakyReLU(alpha)(X)
    X = Dropout(drop_rate)(X)
    
    # 3 слой
    X = Conv2D(int(complexity),
               kernel_size=(5, 5),
               strides=(2, 2),
               #kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
               padding='same')(X)
    X = LeakyReLU(alpha)(X)
    X = Dropout(drop_rate)(X)
    
    # 4 слой
    X = Conv2D(int(2*complexity),
               kernel_size=(5, 5),
               strides=(1, 1),
               #kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            padding='same')(X)
    X = LeakyReLU(alpha)(X)
    X = Dropout(drop_rate)(X)
    
    # 5 слой
    X = Conv2D(int(complexity),
               kernel_size=(5, 5),
               strides=(1, 1),
               #kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
               padding='same')(X)
    X = LeakyReLU(alpha)(X)
    X = Dropout(drop_rate)(X)

    # 6 слой
    X = Flatten()(X)
    X_real_fake = Dense(1, activation='sigmoid', name='real_fake')(X)
    # 6 слой
    X_labels = Dense(DEPTH, activation='softmax', name='labels')(X)
    
    model = Model(inputs=inp, outputs=[X_real_fake, X_labels], name='discriminator')
    model.compile(loss={'real_fake': 'binary_crossentropy',
                        'labels': 'categorical_crossentropy'},
                  loss_weights={'real_fake':1,
                                'labels': 1},
                  optimizer=optimizer(),
                  metrics={'real_fake':'acc'})
    
    model.summary()
    return model
discriminator = make_discriminator()

Тест дискриминатора


discriminator.predict(generator.predict(FIXED_NOISE))


Сборка GAN


Обратите внимание, что надо обязательно отключить обучаемость дискриминатора перед компиляцией GAN.



def make_gan(discriminator, generator):   
    noise = generator.inputs
    image = generator(noise)
    real_vs_fake_and_label = discriminator(image)
    gan = Model(inputs=noise, outputs=real_vs_fake_and_label)
    # надо выключить обучаемость дискриминатора перед компиляцией
    # т.к. обучением GAN является обучение генератора
    discriminator.trainable = False  
    gan.compile(loss={'discriminator':'binary_crossentropy',
                      'discriminator_1':'categorical_crossentropy'},
                
                optimizer=optimizer(),
                metrics={'discriminator':'acc'})
    discriminator.trainable = True
    gan.summary()
    return gan
gan = make_gan(discriminator, generator)


Обучение


Обучение дискриминатора:

  1. Создаём батч сгенерированных изображений и изображений из датасета
  2. Делаем сглаживание целевых значений для избежания экстремальных значений градиента
  3. Создаём целевые значения меток
  4. Обучаем дискриминатор

Обучение генератора:

  1. Задаём целевые значения единицы для GAN
  2. Обучаем GAN(генератор) на сгенерированном случайном векторе

Улучшения:

  • сглаживание целевых значений как для изображений из датасета, так и для сгенерированных;
  • т.к. обучающие данные генерятся случайно, есть возможность свободно регулировать количество батчей в эпоху;
  • есть возможность отслеживать как хорошо генератор «обманывает» дискриминатор. Оптимально, что бы точность генератора была ровна 50%;

EPOCHS = 500
BATCH_SIZE = 128
data = DataGenerator(train_images, train_labels, batches_per_epoch=40, batch_size=BATCH_SIZE)
#массив нужен для отслеживания стабильности обучения
generator_acc_follow = []
for epoch in tqdm(range(1, EPOCHS+1)):
    for images_batch, labels_batch in data:
        # Обучение дискриминатора
        noise = data.rand_batch()
        generated_images = generator.predict(noise)
        X = np.concatenate([images_batch, generated_images])
        # сглаживание целевых значений
        y_real = np.random.uniform(0.8, 1, size=[BATCH_SIZE])
        y_fake = np.random.uniform(0, 0.2, size=[BATCH_SIZE])
        y_real_fake = np.concatenate([y_real, y_fake])
        # целевые значения меток
        y_labels = np.concatenate([labels_batch, np.zeros((BATCH_SIZE, DEPTH))])
        # обучаем дискриминатор
        discriminator.train_on_batch(X, [y_real_fake, y_labels])
        
        # Обучение генератора
        noise = data.rand_batch()
        # задаём целевые значения
        y_real_fake = np.ones(BATCH_SIZE)
        # обучаем GAN и сохраняем точность в массив
        generator_acc_follow.append(gan.train_on_batch(noise, [y_real_fake, noise[1]])[-1])

    # построим изображение и сохраним его для создания .gif
    title = f"DCGAN training process {epoch} over {EPOCHS} epochs"
    if epoch%int(EPOCHS/40)==0:
        plot_generated_images(FIXED_NOISE, folder_name, title, show=True)
        print('если генератор и дискриминатор сбалансированы, то точность генератора не будет уходить в 0 или 1:')
        print(np.mean(generator_acc_follow))
        generator_acc_follow = []
    else:
        plot_generated_images(FIXED_NOISE, folder_name, title)
        
generator.save('generator_trained.h5')

Создание .gif для визуализации процесса обучения


for folder_data in os.walk(os.path.join(folder_name)):
    all_pics_filenames = sorted(folder_data[2], key=lambda x: int(x.split()[3]))

with imageio.get_writer('DCGAN training.gif', mode='I', fps=60) as writer:
    for filename in tqdm(all_pics_filenames):
        image = imageio.imread(os.path.join(folder_name, filename))
        writer.append_data(image)
        
with open('DCGAN training.gif','rb') as f:
    display(Image(data=f.read(), format='png'))

Визуализация процесса обучения:

image

Итоги


Имея всего 1.3М параметров, GAN через 5000 батчей начинает генерировать приемлемые изображения, а через 20000 изображения, не отличимые от рукописных. Дополнительный приятный бонус в том, что мы можем генерировать нужную нам цифру.


Скорее всего, это не полный список улучшений, который можно применить для GAN. Если у вас есть предложения, то не стесняйтесь писать о них в комментариях.

Теги:
Хабы:
Всего голосов 3: ↑3 и ↓0+3
Комментарии0

Публикации

Истории

Работа

Python разработчик
119 вакансий
Data Scientist
78 вакансий

Ближайшие события

7 – 8 ноября
Конференция byteoilgas_conf 2024
МоскваОнлайн
7 – 8 ноября
Конференция «Матемаркетинг»
МоскваОнлайн
15 – 16 ноября
IT-конференция Merge Skolkovo
Москва
22 – 24 ноября
Хакатон «AgroCode Hack Genetics'24»
Онлайн
28 ноября
Конференция «TechRec: ITHR CAMPUS»
МоскваОнлайн
25 – 26 апреля
IT-конференция Merge Tatarstan 2025
Казань