
В прошлой статье мы научились классифицировать данные без разметки с помощью понижения размерности и методов кластеризации. По итогам получили первичную разметку данных и узнали, что это картинки. С такими начальными условиями можно придумать что-то более серьёзное, например, дообучить существующую нейросеть на наши классы, даже если до этого она их никогда не видела. В iFunny на первом уровне модерации мы выделяем три основных класса:
approved — картинки идут в раздел collective (развлекательный контент и мемы);
not suitable — не попадают в общую ленту, но остаются в ленте пользователя (селфи, пейзажи и другие);
risked — получают бан и удаляются из приложения (расизм, порнография, расчленёнка и всё, что попадает под определение «противоправный контент»).
Сегодня на наглядных примерах расскажу, как мы перестраивали модель под наши классы, обучали её и выделяли паттерны распознавания картинок. Технические подробности — под катом.
Для начала возьмём небольшую сеть VGG-11, которая уже реализована во фреймворке pytorch. Сейчас есть много других сетей с результатами получше, но данная модель достаточно легкая, чтобы получить заметный результат за короткое время.
Зададим несколько преобразований, чтобы унифицировать данные и привести их к привычному для сети виду. Из документации следует, что модель была предобучена на изображениях размера 224×224 пикселя, приведённых к такому формату с помощью преобразований:
from torchvision import transforms transform = transforms.Compose( [ transforms.ToTensor(), transforms.Resize(256), transforms.CenterCrop(224), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ] )
Сначала уменьшаем изображение вдоль его наименьшей стороны до 256 пикселей, а затем вырезаем из центра квадрат со сторонами 224×224 пикселя. После этого независимо производится нормировка вдоль каждого канала в RGB-пространстве. Именно поэтому в преобразовании задано три значения среднего, и столько же дисперсий.
Зачем вообще нужно делать нормализацию? Представьте, что вы непутёвая Золушка и рассыпали на пол несколько видов круп, которые нужно собрать и разделить. Можно собирать по зёрнышку и сразу откладывать в нужную кучку, но удобнее собрать всё в одну кучку, а уже потом разделять по сортам. То же самое мы делаем для сети — собираем все значения в одну кучу, а потом заставляем сеть делить их на выделенные классы.

Теперь загрузим саму модель и её веса, полученные при обучении на датасете ImageNet:
From torchvision import models model = models.vgg11(pretrained=True).eval()
Метод eval класса VGG позволяет отключить обучение и расчёт градиента, что ускоряет предсказание модели там, где нужно узнать только ответ без дополнительного обучения, а также фиксирует все веса, что позволяет получать один и тот же ответ независимо от количества запусков. Картинки, которые находятся в наборе данных ImageNet, выглядят примерно так:

Посмотрим, на какие объекты изображений реагирует сеть при выборе того или иного класса — так лучше поймём, как она работает. Для этого воспользуемся вектором Шепли (Shapley Value), который отражает то, как меняется решение сети с параметром и без него (в нашем случае параметром будет значение пикселя).
Код для получения рисунка ниже можно взять из документации. Мы лишь добавили ещё одно изображение, заменили модель (в примере используется VGG-16), а также слой, с которого берётся градиент (с седьмого, как в примере, на десятый).
Результат следующий:

Первый столбец — исходная картинка, второй столбец — первый предсказанный класс с наибольшим значением вероятности, третий столбец — второй по важности класс. Красным выделено то, на что сеть реагировала больше всего, а синим — то, что её склоняло в сторону другого класса при выборе текущего.
Если погуглить названия классов, то будет видно, что даже если сеть ошиблась, как в случае совы (она не верно указала вид, однако это может быть связано с наличием только такого лейбла в датасете), то её ответ очень близок к истине.
Теперь вспомним, что у нас есть свой датасет, который мы разметили. Для примера возьмём реальные картинки из трёх классов нашего приложения iFunny. Напомню их:
approved — картинки идут в раздел приложения collective;
not suitable — не попадают в общую ленту, но остаются в ленте пользователя. К этому относятся девушки в купальниках и мужчины в плавках, селфи и всё, что не является мемами и не несет в себе развлекательную функцию;
risked — сюда относится расизм, порно, расчленёнка и всё, что законадательно запрещено к размещению и может навредить имиджу компании. Такой контент получает бан и перестает быть доступным всем пользователям iFunny.
В процессе обучения придётся часто обращаться к данным и производить с ними математические операции на ЦПУ или ГПУ (в зависимости от железа). В фреймворке pytorch уже есть реализованный класс ImageFolder, открывающий изображение по заданному пути и присваивающий ему класс в соответствии с папкой, в которой находится. Чтобы им воспользоваться, необходимо сгруппировать изображения определённым образом. В нашем случае все тренировочные изображения лежат в train/interim и разбиты по папкам с названием класса, как это показано ниже. Название объекта должно быть уникальным (у нас это ID контента).
interim/ approved/ H6f8XI2i8.jpg XFkQE1Zi8.jpg ... not_suitable/ DCS2iR3i8.jpg KmyGT7Yi8.jpg ... risked/ KRXZUuci8.jpg m6CH7yxh8.jpg ...
Создадим словарь с датасетами тренировочной и валидационной выборок:
from torchvision import datasets datasets = { 'train': ImageFolder( root='train/interim/', transform=transform ), 'valid': ImageFolder( root='test/interim/', transform=transform ), } datasets['train'].class_to_idx
Как было сказано ранее, класс ImageFolder производит лейблирование — каждой названной папке сопоставляет определённую цифру. По факту выдает порядковый номер отсортированного по алфавиту списка классов:
{'approved': 0, 'not_suitable’: 1, 'risked’: 2}
Это необходимо, поскольку производить математические операции с числами в процессе оптимизации сети проще, чем обращаться к строкам. Но загружать по одному изображению долго, поэтому используем DataLoader, который делает это сразу пачкой (в DS среде она называется батчом):
torchfrom torch.utils.data import DataLoader batch_size = 100 num_workers = 20 dataloaders = { 'train': DataLoader( datasets['train'], batch_size=batch_size, shuffle=True, num_workers=num_workers ), 'valid': DataLoader( datasets['valid'], batch_size=batch_size, shuffle=True, num_workers=num_workers ) }
Передаём в DataLoader наш датасет и указываем размер батча, который говорит о том, сколько картинок загрузить в сеть одновременно. С точки зрения машины все изображения являются матрицами, а все операции внутри сети — матричными. Поэтому ничего не мешает производить их сразу с N изображениями, собрав их вместе.
Также DataLoader может загружать данные параллельно, за счёт чего процесс загрузки происходит в разы быстрее. Количество одновременных процессов на загрузку задаётся параметром num_workers.
Отобразим наш тренировочный датасет:
import numpy as np real_batch = next(iter(dataloaders['train'])) plt.figure(figsize=(8,8)) plt.axis("off") plt.title("Training Images") plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=2, normalize=True),(1,2,0)));
Примерно так выглядит контент, который загружают пользователи в iFunny, с единственным отличием, предполагающим сбалансированность классов в нашем датасете, в отличие от реальности, где risked контент составляет меньше 10%:

Перестройка модели под наши цели
Внутри сети происходит выделение разных паттернов (уши и хвосты на картинках с животными, как в примере выше), а самый последний слой на их основании делает предположение, чем является данный объект. Именно последний слой нам и нужно поменять, потому что на выходе он по умолчанию имеет тысячу классов, а нам нужно всего три, и совсем других, которых не было в этой тысяче.
В следующей строке повторим загрузку модели с заданной архитектурой VGG-11. Флаг pretrained в положении True позволяет загрузить предобученные на ImageNet веса, любезно предоставленные pytorch. Модель переводим в память ГПУ методом to, аргументом которого является название необходимого девайса, так как все дальнейшие вычисления будут производиться на ней.
device = 'cuda' model = models.vgg11(pretrained=True, progress=False).to(device)
Затем в цикле присваиваем атрибуту requires_grad всех слоев сети значение False, чтобы в процессе обучения они не изменялись:
for param in model.parameters(): param.requires_grad = False
После чего меняем слой, отвечающий за классификацию:
model.classifier[6] = torch.nn.Linear(4096, 3).to(device)
Важный момент: мы не переобучаем все предыдущие слои, так как считаем, что они уже научились выделять общие признаки, в отличие от последнего слоя, который только что поменяли и тем самым переключили его атрибут requires_grad в положение True. Этот слой необходимо обучить, так как теперь он имеет случайные веса, а значит не может корректно отличать что-либо. Также нужно задать оптимизатор, который будет подбирать наиболее подходящие веса и фактически обучать сеть.
Для этого возьмём один из методов градиентного спуска Adam. Ещё необходимо указать критерий, по которому будет идти оптимизация — для этого используем кросс-энтропию.
params_to_update = model.parameters() print("Params to learn:") params_to_update = [] for name,param in model.named_parameters(): if param.requires_grad == True: params_to_update.append(param) optimizer = torch.optim.Adam(params_to_update, lr=0.0001) criterion = torch.nn.CrossEntropyLoss()
В цикле выше собираем слои, у которых флаг requires_grad в положении True. До этого мы их переводили в положении False, поэтому в оптимизатор передаётся информация только о последнем слое.
Обучение модели
Для обучения воспользуемся функцией train_model из официального туториала pytorch:
import time import copy from tqdm import tqdm def train_model(model, dataloaders, criterion, optimizer, num_epochs=10): since = time.time() val_acc_history = [] best_model_wts = copy.deepcopy(model.state_dict()) best_acc = 0.0 for epoch in range(num_epochs): print('Epoch {}/{}'.format(epoch, num_epochs - 1)) print('-' * 10) # Each epoch has a training and validation phase for phase in ['train', 'valid']: if phase == 'train': model.train() # Set model to training mode else: model.eval() # Set model to evaluate mode running_loss = 0.0 running_corrects = 0 # Iterate over data. for inputs, labels in tqdm(dataloaders[phase]): inputs = inputs.to(device) labels = labels.to(device) # zero the parameter gradients optimizer.zero_grad() # forward # track history if only in train with torch.set_grad_enabled(phase == 'train'): outputs = model(inputs) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) # backward + optimize only if in training phase if phase == 'train': loss.backward() optimizer.step() # statistics running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / len(dataloaders[phase].dataset) epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) # deep copy the model if phase == 'valid' and epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) if phase == 'valid': val_acc_history.append(epoch_acc) print() time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) print('Best val Acc: {:4f}'.format(best_acc)) # load best model weights model.load_state_dict(best_model_wts) return model, val_acc_history
На данный момент уже существует множество библиотек для автоматизации процесса обучения. Они отличаются между собой различными логгерами, возможностью подключения дополнительных модулей и другими удобствами. Но есть неизменная основа.
Нужно обязательно обнулять градиент командой optimizer.zero_grad() перед каждым запуском обратного распространения ошибки внутри сети:
loss.backward() optimizer.step()
Первая строка запускает операцию обратного распространения ошибки внутри сети из переменной потери (loss). Вторая — выполняет градиентный шаг на основе вычисленных градиентов.
Ответы нашей модели считаются в следующей строке:
outputs = model(inputs)
А ошибку, которую используем в дальнейшем для расчёта градиента, получаем одной строчкой:
loss = criterion(outputs, labels)
В процессе обучения происходит не только тренировка, но и проверка качества сети на объектах, не участвующих в обучении. Так можно выявить наличие переобучения, когда сеть начинает запоминать образцы из тренировочной выборки, и теряет обобщающую способность — то есть показывает результаты на новых данных заметно хуже, чем на тренировочных.
Чтобы этого избежать, нужно регулярно производить проверку на отложенной выборке. В данном процессе не стоит вычислять градиент, чтобы сеть не запомнила и эти примеры. Поэтому переключаем сеть в другой режим методом eval:
if phase == 'train': model.train() # Set model to training mode else: model.eval() # Set model to evaluate mode
Для иллюстрации процесса обучения и его результатов, мы провели обучение модели на 10 эпохах. Эпохой в DS среде называется полный проход по всем примерам выборки (обычно сеть не способна за один раз усвоить все правила, поэтому количество эпох почти всегда больше 1). Ниже приведены выводы функции обучения:
model, val_acc_history = train_model(model, dataloaders, criterion, optimizer, num_epochs=10)
Логи обучения
0%| | 0/245 [00:00<?, ?it/s] Epoch 0/9 ---------- 100%|██████████| 245/245 [00:51<00:00, 4.76it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.8573 Acc: 0.6141 100%|██████████| 30/30 [00:08<00:00, 3.51it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8701 Acc: 0.5867 Epoch 1/9 ---------- 100%|██████████| 245/245 [00:51<00:00, 4.73it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7892 Acc: 0.6526 100%|██████████| 30/30 [00:08<00:00, 3.57it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8458 Acc: 0.6012 Epoch 2/9 ---------- 100%|██████████| 245/245 [00:51<00:00, 4.75it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7716 Acc: 0.6601 100%|██████████| 30/30 [00:08<00:00, 3.50it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8380 Acc: 0.6049 Epoch 3/9 ---------- 100%|██████████| 245/245 [00:52<00:00, 4.70it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7551 Acc: 0.6658 100%|██████████| 30/30 [00:08<00:00, 3.59it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8374 Acc: 0.6012 Epoch 4/9 ---------- 100%|██████████| 245/245 [00:52<00:00, 4.71it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7464 Acc: 0.6703 100%|██████████| 30/30 [00:08<00:00, 3.47it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8166 Acc: 0.6157 Epoch 5/9 ---------- 100%|██████████| 245/245 [00:52<00:00, 4.71it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7423 Acc: 0.6731 100%|██████████| 30/30 [00:08<00:00, 3.57it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8155 Acc: 0.6174 Epoch 6/9 ---------- 100%|██████████| 245/245 [00:52<00:00, 4.69it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7379 Acc: 0.6764 100%|██████████| 30/30 [00:08<00:00, 3.54it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8117 Acc: 0.6221 Epoch 7/9 ---------- 100%|██████████| 245/245 [00:52<00:00, 4.69it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7329 Acc: 0.6780 100%|██████████| 30/30 [00:08<00:00, 3.55it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8113 Acc: 0.6201 Epoch 8/9 ---------- 100%|██████████| 245/245 [00:51<00:00, 4.77it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7314 Acc: 0.6802 100%|██████████| 30/30 [00:08<00:00, 3.50it/s] 0%| | 0/245 [00:00<?, ?it/s] valid Loss: 0.8106 Acc: 0.6221 Epoch 9/9 ---------- 100%|██████████| 245/245 [00:51<00:00, 4.73it/s] 0%| | 0/30 [00:00<?, ?it/s] train Loss: 0.7243 Acc: 0.6787 100%|██████████| 30/30 [00:08<00:00, 3.65it/s] valid Loss: 0.8184 Acc: 0.6123 Training complete in 10m 4s Best val Acc: 0.622095
По динамике видно, что метрики со временем улучшаются как на тренировочной, так и на валидационной выборках. При этом значения loss-функции падают. Это говорит о положительном течении процесса обучения, а также об отсутствии переобучения модели.
Результаты обучения
Также как и в случае с классификацией чисел, построим матрицу ошибок, чтобы увидеть, как наша модель справляется с поставленной задачей.

Сеть хорошо научилась разделять approved и not suitable контент, а вот с изображениями класса risked не всё так гладко. Но напомню, что использовалась очень простенькая сеть без каких-либо дополнительных методов улучшения, поэтому её ещё есть куда развивать.
Паттерны новой сети
Отразим основные паттерны нашей сети с помощью вектора Шепли.

На первой картинке с котом видно, что сеть при выборе approved класса в основном смотрела на мордочку и уши. У Гитлера определённые паттерны при выборе класса risked выделить сложнее — сеть немного отреагировала на нос и усы, но скорее всего решение принималось по цветовой гамме, так как нацистские фото почти всегда чёрно-белые, и она цепляется за этот признак.
Самое интересное можно наблюдать на третьей картинке с девушкой. Грудь была важным параметром для выбора обоих классов (и risked, и not suitable), но оказала негативное влияние на класс approved, отклонив его. В пользу not suitable сеть склонил купальник, контур которого у этого класса выделен красным, а у risked — синим. В данном случае все логично, поскольку без купальника данный контент стал бы эротикой и считался недопустимым. Также наличие лица в кадре отрицательно сказалось при проверке сетью принадлежности объекта к risked классу — мы не запрещаем пользователям загружать селфи, но и не продвигаем его через общую ленту, так как такой контент больше подходит для Instagram.
Вместо заключения
Мы научились классифицировать данные с разметкой с помощью дообучения предобученной сети. А также разобрались, на что реагируют модели и как происходит обучение сети. Эти способы скорее базовые и не претендуют на совершенство, но являются отличной отправной точкой.
С помощью метода из первой статьи можно получить первичную разметку, которую придётся перепроверить вручную для лучшего качества, но это будет проще, чем размечать все объекты с нуля. На основе последовательности действий из статьи можно дообучить не только представленную сеть, но и что-то более сложное.
На данный момент в открытом доступе лежит большое количество готовых реализаций архитектур от Google, Facebook, а также других компаний и университетов с предобученными весами, показывающими отличные результаты, а порой даже относящихся к State of the Art (SOTA). Для поиска подходящей архитектуры есть удобный сайт Papers with Code, где в открытом доступе на GitHub есть готовые реализации.
Напоследок — мем, который был одобрен нашей обученной сетью с вероятностью 85%.

Вероятности классов:
approved — 0.8528
not suitable — 0.0996
risked — 0.0475
