Любой, кто хоть раз обучал нейронки, знает, что принято на каждой эпохе шаффлить датасет, чтобы не повторялся порядок батчей. А зачем это делать? Обычно это объясняют тем, что шаффлинг улучшает генерализацию сетей, делает точнее эстимейт градиента на батчах и уменьшает вероятность застревания SGD в локальных минимумах. Здесь можно посмотреть визуализацию поведения градиентов батчей с шаффлингом и без шаффлинга. Ну и самый простой и традиционный для ML аргумент - наши эксперименты подтверждают, что отключение шаффлинга действительно ухудшает метрики, так что проверяйте, не забагован ли ваш трейн-луп ? Еще больше полезной информации в нашем telegram-канале Варим ML
Почему вообще может быть полезно использовать хитрые стратегии сэмплинга данных? Помимо уже упомянутого дизбаланса классов ваш датасет может иметь ещё ряд особенностей:
Шумная разметка - из-за ошибок разметчиков или при использовании псевдолейбелинга.
Различная сложность сэмплов - в датасете могут превалировать простые примеры, которые сеть легко и быстро выучивает.
Динамичность - датасет может "обрастать" новыми сэмплами, в которых могут содержаться новые классы или паттерны.
Часто эти свойства ещё и идут в комплекте - например, в medical imaging довольно много шумной и неконсистентной разметки, много простых случаев с очевидными признаками заболевания, и часто в ходе развития продукта добавляются новые патологии и случаи. Давайте посмотрим какие способы работы с такими данными может предложить нам DL-литература.
Hard-example mining
Идея hard-example mining известна всем любителям детекторов. Популярнейшая статья Training Region-based Object Detectors with OHEM развила старую идею бутстрапинга примеров для обучения - для бэкворд-пасса выбираем те регионы картинки, которые дают наибольший лосс. Такой подход может увеличить метрики и ускорить обучение. Подобная идея может использоваться не только для детекции - например, можно чаще сэмплить картинки с высоким лоссом, картинки со сложными классами и даже использовать предсказание лосса для активного обучения.
Недостаток такого подхода - он может не сработать, а то и ухудшить метрики ? Особенно это вероятно, если датасет шумный - ведь тогда наибольший лосс будут иметь как раз примеры с неверной или неточной разметкой. В таких случаях можно попробовать semi-hard sampling, к примеру, выбирать не самые сложные примеры, а сэмплы с лоссом между 85 и 95 перцентилем;
Curriculum learning
Идея curriculum learning (очень сложное слово, поэтому дальше CL) черпает вдохновение из процесса обучения детей - ведь в школе мы начинаем учиться простым вещам и концепциям, а затем уже переходим к более сложным. Собственно говоря, слово curriculum (последний раз!) переводится как "учебный план". Чтобы применить подобный подход к нейронкам нам нужно иметь два основных компонента - измеритель сложности сэмпла и scheduler, который будет выстраивать процесс обучения на основе сложности.
Как можно измерить сложность сэмпла для нейронки?
Эвристки. Мы можем предположить, что сложность предсказания будет коррелировать с каким-то расчётным показателем - количеством объектов на изображении, их размером, согласованностью разметчиков, классом изображения или какой-то ещё мета-информацией.
Self-taught. В этом подходе мы обучаем сетку фиксированное количество эпох, рассчитываем лосс каждого сэмпла и используем его для дальнейшего шедулинга.
Self-paced. Это группа онлайн-методов, которые динамически измеряют сложность примеров в батче на основе текущего лосса, по норме градиентов или другим способом.
Какие бывают шедулеры?
CL - по ходу обучения увеличиваем вероятность сэмплинга сложных кейсов.
Anti-CL - наоборот, идём от сложного к простому (я вот знаю людей, которые любят так делать).
Гибридные способы - например, учитываем сложность примера, но стараемся поддерживать diversity по классам, чтобы каждый класс хотя бы минимально участвовал в каждой эпохе.
При применении методов CL есть ряд нюансов, которые стоит иметь в виду:
Важность сэмпла для сетки может меняться по ходу обучения.
Априори определить оптимальное расписание для конкретной задачи практически невозможно - нужны эксперименты.
Статьи по CL обычно написаны на примере задач классификации, и часто требуется их адаптация под детекцию и сегментацию.
CL увеличивает время обучения - часто нужен лишний форвард-пасс или большой батч.
Можно изменять вероятность сэмплинга, а можно вес каждого сэмпла при расчёте лосса.
Для желающих подробнее ознакомиться с CL-методами оставляю ссылки на surveys - раз и два.
Data-driven learning
Сэмплы с низким лоссом могут быть слишком лёгкими и бесполезными, с высоким - шумными, с высоким uncertainty - странными выбросами, нерелевантными для задачи. Ааааа, что делать?
Отдельная крайне любопытная группа методов фокусируется на том, чтобы выучить оптимальное расписание обучения или схему ревейтинга сэмплов. Естественно, чтобы что-то выучить, нам нужны данные. Если конкретнее, то для данной задачи нам потребуется хороший, сбалансированный, соответствующий тест-распределению валидационный сет. Предположим, он есть (хаха). Что можно делать дальше?
Prioritized Training on points that are learnable, worth learning, and not yet learned. Давайте затреним небольшую модель на нашем валидационном сете, а затем посчитаем и запомним лосс каждого сэмпла на трейн-сете. Назовём это irreducible holdout loss - ниже этого лосса мы не можем спуститься без обучения на трейне. Заново инициализируем модель и начинаем учить её на трейне. На каждом сэмпле большого батча размера B мы считаем текущий трейн-лосс и вычитаем из него зафиксированный ранее irreducible loss. Полученную разницу назовём reducible loss. Дальше всё просто - берём b примеров из батча с наибольшим reducible loss и добавляем их айдишники в список. Этот список в итоге и будет расписанием, по которому мы будем обучать настоящую модель. Почему это может работать? Для простых кейсов reducible loss будет низким, потому что у них будет низкий training loss. Для шумных сэмлов он тоже будет низким, ведь у них будет высокий irreducible loss - шумные кейсы нереально предсказать корректно после обучения на валидационном сете. Наконец, выбросы из трейн-сета, которых нет в нашей чистой валидации, тоже будут иметь высокий irreducible loss. Бинго, мы набрали список полезных и чистеньких сэмплов!
Learning to Reweight Examples for Robust DL. В данном методе мы пытаемся выучить оптимальные веса сэмплов в трейн-сете. Для этого мы выучиваем эти веса с помощью градиентного спуска. Для этого мы двигаем веса примеров на текущем трейн-батче в сторону, которая минимизировала бы лосс на текущем вал-батче.
MentorNet. Если сетки - это дети, то им нужен учитель, который будет оберегать их от вредного воздействия. В качестве такого учителя выступает отдельная сетка. В неё в качестве фичей залетают лоссы данного сэмпла на последних нескольких эпохах и разницы лосса и moving average лосса (кодируются через bidirectional LSTM), а также эмбеддингы класса и текущей эпохи обучения. В качестве таргета выступает бинарная переменная, которая равна 1, если мы подали лосс для корректного лейбла, и 0, есля для шумного неправильного. Таким образом, MentorNet учится присваивать низкие веса примерам с неправильной разметкой. Здесь можно посмотреть симпатичные визуализации того, как такая сетка может выучивать фиксированные расписания типа hard mining или CL, а также трушное data-driven расписание.
Звучит всё это крайне привлекательно, но главная проблема - где же взять такой чистый, красивый, репрезентативный валидационный сет? ?
Incremental learning
Пожалуй, incremental learning - это тема для отдельного поста, но я всё-таки оставлю краткое упоминание здесь, ведь очень-очень редко все требования к аутпутам ML-систем известны заранее и фиксированы. По крайней мере у нас новые патологии в разметке появляются регулярно.
В таких случаях можно заново переучить всю сеть, дообучить её только на новых данных (и столкнуться с catastrophic forgetting), а можно использовать методы class-incremental learning. Из интересного в этой области могу выделить статью Retrospective Class Incremental Learning, в которой рассматривается реалистичный кейс, в котором у нас сохраняется доступ ко всей старой датке (в дополнение к новой), но нам не хочется тратить много времени на полное переобучение на всех старых данных.
Выводы
Главный вывод, как и всегда, неутешительный... В любом случае понадобится мало-мальски адекватный валидационный сет и правильная метрика - как минимум хочется уметь сравнивать разные методы сэмплинга, а в лучше случае можно даже попробовать выучить data-driven расписание. Но для начала я рекомендовал бы попробовать какие-то простые методы с фиксированным расписанием и посмотреть, сильно ли просядет метрика ?
Есть ещё интересная задача core-set selection - выбор сабсета большого набора данных, который даст нам адекватное качество на нашей задаче, но об этом, конечно, же в другой раз.
Если вы хотите узнать ещё больше об организации процессов ML-разработки, подписывайтесь на наш Телеграм-канал Варим ML