Этот пост приурочен к недавнему релизу open-source проекта OpenMetricLearning (OML), одна из целей которого — максимально снизить порог вхождения в тему metric learning. Мы немного пройдёмся по теории, разберём примеры с кодом и покажем, как с помощью простых эвристик догнать текущие SotA модели. Проект новый, поэтому каждая звездочка на GitHub для нас на вес золота.
О задаче Metric learning
Задача metric learning состоит в том, чтобы построить функцию от двух объектов, которая будет оценивать расстояние (похожесть) между ними. Имея такую функцию, мы можем осуществлять поиск по объектам, кластеризацию, детектирование выбросов и т.д. Далее мы рассмотрим решение данной задачи с помощью нейронных сетей, то есть deep metric learning, где выделяются два основных подхода:
Siamese. Нейронная сеть принимает на вход два объекта и возвращает вероятность, что они совпадают или похожи (в зависимости от постановки задачи). Данную вероятность можно использовать как меру похожести или «расстояние».
Representation learning. Нейронная сесть принимает на вход один объект и возвращает вектор, представляющий этот объект в некотором векторном пространстве. Далее между векторами вычисляется классическое расстояние, например, евклидово.
Допустим, нам нужно оценить все возможные расстояния междуобъектами. Для первого подхода требуется инференсов модели, а для второго инференсов и расчëтов расстояний. На практике чаще используется второй подход, так как подсчет расстояний между векторами намного быстрее, чем инференс. Далее мы будем разбирать только этот подход и воспринимать его как синоним к metric learning.
Вот так выглядит векторное пространство модели, обученной на датасете Fashion MNIST:
В чём отличие от классификации?
Задачи deep metric learning и классификации могут перетекать друг в друга, что делает использование терминологии запутанным. С одной стороны, можно натренировать классификатор, а затем использовать выходы с его последнего или предпоследнего слоя как вектора, по которым оценивается расстояние. С другой стороны, можно обучить модель с не классификационной функцией потерь (например, triplet loss, о нём позже), но использовать полученные вектора для классификации, осуществляя поиск по ближайшим соседям и беря метки их классов. Вдобавок, в обеих задачах используются одни и те же архитектуры сетей.
Если всё-таки выделить характерное отличие, то я бы сказал, что в классификации классы на train и test выборках совпадают, а в metric learning — не обязательно. Кроме того, metric learning не всегда требует явной разметки на классы. Например, может использоваться разметка вида — пара объектов и индикатор похожести между ними.
Есть ли бенчмарки?
Да, для metric learning, как и для классификации, существует набор популярных датасетов, например, картиночных, на которых исследователи сравнивают свои наработки.
Как происходит обучение и валидация модели
Для примера рассмотрим датасет DeepFashion. Он содержит изображения 17 категорий одежды (куртки, джинсы, шорты и т.д.) и ~8 тысяч классов (артикулов конкретных товаров). Медианный размер класса — 5 изображений.
Классы разделены на два непересекающихся множества для тренировки и валидации. Обратите внимание, разделение сделано именно на уровне классов. В свою очередь, валидационную часть делят на запросы (query) и поисковый индекс (gallery), чтобы в дальнейшем сымитировать поиск и оценить его точность. Обратите внимание, что здесь разделение уже на уровне изображений: например, для куртки с артикулом 001 есть 7 изображений, 3 из них попадают в запросы, а остальные 4 — в индекс. Мы стремимся обучить модель так, чтобы для векторов, представляющих эти 3 запроса, ближайшими оказались данные 4 вектора из индекса.
Рассмотрим как происходит обучение модели с классическим triplet loss.
Коротко про triplet loss
где
— триплет, в который входят три объекта: якорный, позитивный (из того же класса, что и якорный), негативный (из класса, отличающегося от якорного);
— позитивное расстояние, которое мы хотим уменьшать, — негативное расстояние, которое мы хотим увеличивать;
— зазор (margin).
Есть и другие варианты triplet loss'a, которые иногда позволяют добиться большей стабильности обучения:
Пример триплета: на первых двух фотографиях одинаковые салатовые блузки, на последней — красная майка.
Тренировка
Сэмплер создает батч с условием, что в нем найдутся хотя бы 2 класса и 2 изображения на каждый (иначе мы не сможем составить триплеты). Часто батчи сбалансированы по классам.
Батч подаётся в модель и превращается в набор векторов.
Майнер собирает триплеты, используя вектора из предыдущего пункта. Можно составить все возможные триплеты или только самые сложные (когда позитивное расстояние максимально, а негативное минимально); можно плавно управлять сложностью триплетов; так же есть техники использования векторов из банка памяти, который представляет собой очередь, обновляющуюся после очередного батча.
Оптимизатор делает шаг по градиенту triplet loss'а, вычисленного для триплетов с предыдущего шага.
Схема тренировки может меняться, например:
Если вы работаете с классификационной функцией потерь (например, Log loss или ArcFace), то майнер не нужeн.
Если у вас нет разметки на классы, а вместо этого есть разметка на уровне пар или триплетов, то они сразу передаются в функцию потерь, без использования майнера.
Если вы работаете с quadruplet loss, то майнер собирает четверки, если с contrastive loss, то пары.
Валидация
Делаем инференс на всем валидационном наборе, накапливаем полученные вектора.
Считаем расстояния между всеми запросами и всеми изображениями в индексе. Получается матрица размером . Сортируем строки матрицы, чтобы в начало попали элементы индекса, наиболее близкие к запросам.
Вычисляем метрики. Логично использовать метрики из информационного поиска, например:
— , если есть хотя бы один правильный ответ в первых результах, иначе .
— доля правильных ответов в первых результах.
— аналог предыдущего, но учитываются позиции правильных ответов.
Рассмотрим пример ниже для трёх запросов (выделены синим), для которых мы вернули 5 изображений в порядке возрастания расстояния между ними и запросом; часть из результатов поиска имеют тот же артикул, что и запрос (выделены зелёным как правильные ответы), а часть из них имеют другой (выделены красным как ошибки). Для всех трёх запросов . Снемного сложнее, так как нам необходимо знать, сколько всего правильных ответов можно вернуть для запроса, чтобы не штрафовать модель в случаях, когда даже теоретически не из чего выбрать 5 правильных ответов. Допустим, для первого запроса существует 5 правильных ответов в поисковом индексе, для второго — 3, для третьего — 4. Тогда метрика для первого равна , для второго , для третьего .
О библиотеке OpenMetricLearning
OML это новая библиотека для representation learning, написанная поверх PyTorch. Для удобства понимания ниже приведены примеры, написанные на "голом" PyTorch. Вероятно, на практике вы захотите использовать примеры с PyTorch Lightning или Config API (о них дальше), но внутри они устроены так же.
Код тренировки
import torch
from tqdm import tqdm
from oml.datasets.base import DatasetWithLabels
from oml.losses.triplet import TripletLossWithMiner
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.models.vit.vit import ViTExtractor
from oml.samplers.balance import BalanceSampler
from oml.utils.download_mock_dataset import download_mock_dataset
# скачиваем небольшой игрушечный датасет
dataset_root = "mock_dataset/"
df_train, _ = download_mock_dataset(dataset_root)
# создаем модель на основе претренированного Self-Supervised чекпоинта
model = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).train()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)
train_dataset = DatasetWithLabels(df_train, dataset_root=dataset_root)
# создаем criterion, включающий в себя и фунцию потерь, и майнер
criterion = TripletLossWithMiner(margin=0.1, miner=AllTripletsMiner())
# создаем сэмплер, который в каждый класс будет класть 2 представителя 2-х классов
sampler = BalanceSampler(train_dataset.get_labels(), n_labels=2, n_instances=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler)
# т.к. логика, специфическая для metric learning скрыта в criterion, тренировка
# ничем не отличается от обычной
for batch in tqdm(train_loader):
embeddings = model(batch["input_tensors"])
loss = criterion(embeddings, batch["labels"])
loss.backward()
optimizer.step()
optimizer.zero_grad()
Код валидации
import torch
from tqdm import tqdm
from oml.datasets.base import DatasetQueryGallery
from oml.metrics.embeddings import EmbeddingMetrics
from oml.models.vit.vit import ViTExtractor
from oml.utils.download_mock_dataset import download_mock_dataset
# скачиваем небольшой игрушечный датасет
dataset_root = "mock_dataset/"
_, df_val = download_mock_dataset(dataset_root)
# создаем модель на основе претренированного Self-Supervised чекпоинта
model = ViTExtractor("vits16_dino", arch="vits16", normalise_features=False).eval()
val_dataset = DatasetQueryGallery(df_val, dataset_root=dataset_root)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
# создаем калькулятор метрик, в котором будем накапливать вектора
calculator = EmbeddingMetrics()
calculator.setup(num_samples=len(val_dataset))
with torch.no_grad():
for batch in tqdm(val_loader):
batch["embeddings"] = model(batch["input_tensors"])
calculator.update_data(batch) # накапливаем вектора
metrics = calculator.compute_metrics() # вычисляем метрики: cmc@k, precision@k, map@k
Подробнее про OML в контексте сравнения с PyTorchMetricLearning
Всё познается в сравнении, поэтому, чтобы больше узнать об OML, давайте сравним его с популярной библиотекой PyTorchMetricLearning (PML). Изначально, в нашем проекте мы использовали именно её, но в итоге создали свой проект, более ориентированный на пайплайн обучения и практическое применение.
Сильная сторона PML в реализации большого количества лоссов, майнеров и функций расстояний. По этой причине мы добавили примеры их использования с OML.
OML имеет Config API, который позволяет тренировать модель без написания кода. Требуется только подготовить данные в нужном формате и адаптировать конфигурационный файл под свои нужны. (Как тренировка детектора в mmdetection, когда нужно подготовить датасет в формате COCO).
OML ориентирован на end-to-end тренировку и рецепты по практическому применению, поэтому мы предлагаем примеры с набором гиперпараметров, которые хорошо показали себя на датасетах, приближенных к реальной жизни (содержащих тысячи классов). В то же время, PML скорее является набором реализаций различных функции потерь, что подтверждает сам автор, а примеры предлагаются на игрушечных датасетах CIFAR и MNIST.
В OML есть зоопарк моделей, натренированных другими исследователями в self-supervised режиме, и нами. Получить доступ к таким моделям так же просто, как в
torchvision
, когда вы пишетеresnet(pretrained=True)
.OML интегрирован с PyTroch Lightning, поэтому мы можем использовать всю функциональность его Trainer'a, что особенно полезно для работы в DDP режиме с несколькими GPU: вы можете сравнить наш и их примеры. Хотя в PML тоже есть модуль Trainers, он давно не обновлялся и не используется в примерах, где каждый раз реализуются
train
иtest
функции, что заставляет самостоятельно реализовывать ту функциональность, которую обычно выносят в колбэки trainer'a (сохранение лучших весов, early stopping, изменение learning rate, визуализацию, логирование, и прочее).
Насколько хорошую модель можно обучить с OML?
На уровне лучших существующих моделей. Например, сопоставимо с Hyp-ViT, который представляет собой ViT обученный с contrastive loss, поверх выходов которого применяются геометрические преобразования в гиперболическом пространстве.
Мы обучили такую же архитектуру с triplet loss, зафиксировав для чистоты сравнения другие параметры, такие как тренировочные и тестовые трансформации, размер изображений и оптимизатор. При этом мы использовали наши сэмплер и майнер, которые реализуют простые эвристики:
Сэмплер кладёт в батч только классы (артикулы) из ограниченного количества категории . Например, при в батч попадают, например, только куртки, что автоматически делает негативные пары сложными: для модели гораздо важнее выучить, почему отличаются какие-то две куртки, чем какая-то куртка и какие-то брюки.
Майнер ещё больше усложняет задачу, оставляя только самые сложные триплеты (с максимальным позитивным и минимальным негативным расстояниями).
Таким образом, нам удалось получить модель на уровне SotA, обойдясь простыми эвристиками и не прибегая к сложной математике.
UPD: После дополнительной серии экспериментов мы обнаружили, что даже без ограничения количества различных категорий в батче, ViT + triplet loss показывает примерно такие же результаты. Другими словами, обычного hard mining достаточно, чтобы показывать результаты, сопоставимые со SotA.
Заключение
Если Вам захотелось поработать с данным типом задач на практике, приглашаем Вас поучаствовать в OpenMetricLearning. Можно взяться за одну из существующих задач (у нас есть и инженерные задачи, и ориентированные на ресёрч) или предложить свою идею, создав новый issue.