Как стать автором
Обновить

Сегментация экземпляров с помощью Mask R-CNN

Время на прочтение4 мин
Количество просмотров4.9K

Задача сегментации изображений может решаться в нескольких постановках. Самая распространённая - semantic segmentation с одним классом и фоном, необходимо просто отделить объекты от фона, не различая их между собой. Но часто просто отделения от фона недостаточно, необходимо отделять отдельные образцы друг от друга, например, чтобы оценить размер или расположение каждого отдельного объекта. Как это можно сделать?

Это задача instance segmentation. Её можно решать, как семантическую сегментацию, просто выделяя границы как отдельный класс и присваивая ему больший вес в лоссе. Такое решение подходит для многих ситуаций с простой границей и работает довольно быстро, но может оказаться нестабильным, завестись далеко не “из коробки”. Другой традиционный способ - нахождение объекта через object detection и сегментация внутри найденных ограничивающих рамок. Такой подход, например, используют Mask R-CNN и, с некоторыми модификациями, YOLACT.

Mask R-CNN является наследницей Faster R-CNN, которая используется для выделения ограничивающих рамок. Обе сети дают высокое качество при большом количестве обучающих данных и являются хорошим бейзлайном для большинства задач. Сегодня мы рассмотрим, как можно дообучить модель с torchvision на своих данных и как сделать её инференс.

 Полный код можно найти по ссылке, здесь будут короткие выдержки.

Для начала нужно подготовить датасет.  Разметить изображения можно с помощью VGG, он может работать как локально, так и из браузера, отмечать объекты нужно полигонами, а выгрузить результаты в csv, VGG json или COCO json. Затем перегнать аннотации из полигонов в маски с помощью подобного скрипта, ссылка. Итоговая разметка должна быть в формате Penn-Fudan Database. Поскольку процесс разметки трудоёмкий, можно использовать упрощающий трюк - разметить несколько изображений, обучить модель и сделать предразметку ей на необработанных изображения, затем только поправив руками. Пример такого разметчика на основе Mask R-CNN для аннотаций в формате VGG json есть по ссылке. Кроме того, некоторые платформы для разметки предоставляют такую услугу.

Можно сразу сохранить маски отдельными файлами, но это может занимать много места, и их загрузка во время обучения может отнимать время, лучше создавать маски прямо во время обучения. Это можно делать так:

# маска каждого экземпляра отдельным цветом    	
mask = get_mask(annotation)

# убираем фон и выделяем номера масок
obj_ids = np.unique(mask)[1:]

# разбиваем общую маску на бинарные
masks = mask == obj_ids[:, None, None]

# переводим в тензор
masks = torch.as_tensor(masks, dtype=torch.uint8)

Затем надо сделать стандартные вещи - выбрать оптимизатор и шедулер, изменить количество выходных каналов и backbone, можно настроить и разбиение для anchor boxes. Например, можно настроить модель так:

# загружаем предобученную модель
# можно загрузить другой backbone с torchvision, например MobileNetV2
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

# размер входа для детекторной части
in_features_box = model.roi_heads.box_predictor.cls_score.in_features
# размер входа для сегментационной части
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
# выбираем размер для внутреннего представления 
# в предикторе для сегментации
hidden_layer = 256

# заменяем последние слои детекции с учётом количества наших классов
#  и размера выхода backbone
model.roi_heads.box_predictor = FastRCNNPredictor(in_features_box, num_classes)
# заменяем последние слои для сегментации
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

Как применить аугментации? Можно применять их к изначальной маске, полученной из аннотаций, либо использовать аугментации из torchvision когда уже будут созданы bbox’ы и прочее, вот доступные для второго варианта. Первый вариант больше подходит для albumentations. В полной версии мы используем случайное отражение по горизонтали (RandomHorizontalFlip) и RandomPhotometricDistort, который включает в себя изменение яркости, контрастности и цветовой гаммы, добавление шума.

Сама тренировка проводится с помощью функции train_one_epoch из torchvision.references.detection.engine, там же лежит evaluate для проверки на валидационной выборке. Выводятся стандартные метрики для сегментации и детекции отдельно.

Наконец, перейдём к инференсу. Если запустить сразу, то получится вот такая каша:

Это связано с тем, что сеть для object detection выделяет много лишнего и сегментационная часть обрабатывает пустые регионы. Нужно настроить порог по уверенности для bbox’ов и порог для non maximum suppression.

model.roi_heads.score_thresh=0.4
model.roi_heads.nms_thresh=0.3

И можно предсказывать!

img = transforms.ToTensor()(img)
img = img.to(device)
prediction = model([img])
masks = torch.squeeze(prediction[0]['masks'])
masks = masks > segmentation_th

На последнем шаге мы оставляем только те пиксели, уверенность в которых выше порога.

С внесёнными изменениями выглядит намного лучше:

Итак, сегодня мы освоили сегментацию экземпляров с Mask R-CNN. Это не самый быстрый в инференсе и обучении, но стабильный алгоритм, который не боится сливающихся границ, легко настраивается и поддерживается.

Теги:
Хабы:
Всего голосов 2: ↑2 и ↓0+2
Комментарии0

Публикации

Истории

Работа

Data Scientist
46 вакансий

Ближайшие события