Pull to refresh

Хитрые методики сэмплинга данных

Reading time6 min
Views2.7K

Любой, кто хоть раз обучал нейронки, знает, что принято на каждой эпохе шаффлить датасет, чтобы не повторялся порядок батчей. А зачем это делать? Обычно это объясняют тем, что шаффлинг улучшает генерализацию сетей, делает точнее эстимейт градиента на батчах и уменьшает вероятность застревания SGD в локальных минимумах. Здесь можно посмотреть визуализацию поведения градиентов батчей с шаффлингом и без шаффлинга. Ну и самый простой и традиционный для ML аргумент - наши эксперименты подтверждают, что отключение шаффлинга действительно ухудшает метрики, так что проверяйте, не забагован ли ваш трейн-луп ? Еще больше полезной информации в нашем telegram-канале Варим ML

Почему вообще может быть полезно использовать хитрые стратегии сэмплинга данных? Помимо уже упомянутого дизбаланса классов ваш датасет может иметь ещё ряд особенностей:

  • Шумная разметка - из-за ошибок разметчиков или при использовании псевдолейбелинга.

  • Различная сложность сэмплов - в датасете могут превалировать простые примеры, которые сеть легко и быстро выучивает.

  • Динамичность - датасет может "обрастать" новыми сэмплами, в которых могут содержаться новые классы или паттерны.

Часто эти свойства ещё и идут в комплекте - например, в medical imaging довольно много шумной и неконсистентной разметки, много простых случаев с очевидными признаками заболевания, и часто в ходе развития продукта добавляются новые патологии и случаи. Давайте посмотрим какие способы работы с такими данными может предложить нам DL-литература.

Hard-example mining

Идея hard-example mining известна всем любителям детекторов. Популярнейшая статья Training Region-based Object Detectors with OHEM развила старую идею бутстрапинга примеров для обучения - для бэкворд-пасса выбираем те регионы картинки, которые дают наибольший лосс. Такой подход может увеличить метрики и ускорить обучение. Подобная идея может использоваться не только для детекции - например, можно чаще сэмплить картинки с высоким лоссом, картинки со сложными классами и даже использовать предсказание лосса для активного обучения.

Недостаток такого подхода - он может не сработать, а то и ухудшить метрики ? Особенно это вероятно, если датасет шумный - ведь тогда наибольший лосс будут иметь как раз примеры с неверной или неточной разметкой. В таких случаях можно попробовать semi-hard sampling, к примеру, выбирать не самые сложные примеры, а сэмплы с лоссом между 85 и 95 перцентилем;

Curriculum learning

Идея curriculum learning (очень сложное слово, поэтому дальше CL) черпает вдохновение из процесса обучения детей - ведь в школе мы начинаем учиться простым вещам и концепциям, а затем уже переходим к более сложным. Собственно говоря, слово curriculum (последний раз!) переводится как "учебный план". Чтобы применить подобный подход к нейронкам нам нужно иметь два основных компонента - измеритель сложности сэмпла и scheduler, который будет выстраивать процесс обучения на основе сложности.

Как можно измерить сложность сэмпла для нейронки?

  • Эвристки. Мы можем предположить, что сложность предсказания будет коррелировать с каким-то расчётным показателем - количеством объектов на изображении, их размером, согласованностью разметчиков, классом изображения или какой-то ещё мета-информацией.

  • Self-taught. В этом подходе мы обучаем сетку фиксированное количество эпох, рассчитываем лосс каждого сэмпла и используем его для дальнейшего шедулинга.

  • Self-paced. Это группа онлайн-методов, которые динамически измеряют сложность примеров в батче на основе текущего лосса, по норме градиентов или другим способом.

Какие бывают шедулеры?

  • CL - по ходу обучения увеличиваем вероятность сэмплинга сложных кейсов.

  • Anti-CL - наоборот, идём от сложного к простому (я вот знаю людей, которые любят так делать).

  • Гибридные способы - например, учитываем сложность примера, но стараемся поддерживать diversity по классам, чтобы каждый класс хотя бы минимально участвовал в каждой эпохе.

При применении методов CL есть ряд нюансов, которые стоит иметь в виду:

  • Важность сэмпла для сетки может меняться по ходу обучения.

  • Априори определить оптимальное расписание для конкретной задачи практически невозможно - нужны эксперименты.

  • Статьи по CL обычно написаны на примере задач классификации, и часто требуется их адаптация под детекцию и сегментацию.

  • CL увеличивает время обучения - часто нужен лишний форвард-пасс или большой батч.

  • Можно изменять вероятность сэмплинга, а можно вес каждого сэмпла при расчёте лосса.

Для желающих подробнее ознакомиться с CL-методами оставляю ссылки на surveys - раз и два.

Data-driven learning

Сэмплы с низким лоссом могут быть слишком лёгкими и бесполезными, с высоким - шумными, с высоким uncertainty - странными выбросами, нерелевантными для задачи. Ааааа, что делать?

Отдельная крайне любопытная группа методов фокусируется на том, чтобы выучить оптимальное расписание обучения или схему ревейтинга сэмплов. Естественно, чтобы что-то выучить, нам нужны данные. Если конкретнее, то для данной задачи нам потребуется хороший, сбалансированный, соответствующий тест-распределению валидационный сет. Предположим, он есть (хаха). Что можно делать дальше?

  • Prioritized Training on points that are learnable, worth learning, and not yet learned. Давайте затреним небольшую модель на нашем валидационном сете, а затем посчитаем и запомним лосс каждого сэмпла на трейн-сете. Назовём это irreducible holdout loss - ниже этого лосса мы не можем спуститься без обучения на трейне. Заново инициализируем модель и начинаем учить её на трейне. На каждом сэмпле большого батча размера B мы считаем текущий трейн-лосс и вычитаем из него зафиксированный ранее irreducible loss. Полученную разницу назовём reducible loss. Дальше всё просто - берём b примеров из батча с наибольшим reducible loss и добавляем их айдишники в список. Этот список в итоге и будет расписанием, по которому мы будем обучать настоящую модель. Почему это может работать? Для простых кейсов reducible loss будет низким, потому что у них будет низкий training loss. Для шумных сэмлов он тоже будет низким, ведь у них будет высокий irreducible loss - шумные кейсы нереально предсказать корректно после обучения на валидационном сете. Наконец, выбросы из трейн-сета, которых нет в нашей чистой валидации, тоже будут иметь высокий irreducible loss. Бинго, мы набрали список полезных и чистеньких сэмплов!

  • Learning to Reweight Examples for Robust DL. В данном методе мы пытаемся выучить оптимальные веса сэмплов в трейн-сете. Для этого мы выучиваем эти веса с помощью градиентного спуска. Для этого мы двигаем веса примеров на текущем трейн-батче в сторону, которая минимизировала бы лосс на текущем вал-батче.

Выглядит жутко, но идея на самом деле не очень сложная
Выглядит жутко, но идея на самом деле не очень сложная
  • MentorNet. Если сетки - это дети, то им нужен учитель, который будет оберегать их от вредного воздействия. В качестве такого учителя выступает отдельная сетка. В неё в качестве фичей залетают лоссы данного сэмпла на последних нескольких эпохах и разницы лосса и moving average лосса (кодируются через bidirectional LSTM), а также эмбеддингы класса и текущей эпохи обучения. В качестве таргета выступает бинарная переменная, которая равна 1, если мы подали лосс для корректного лейбла, и 0, есля для шумного неправильного. Таким образом, MentorNet учится присваивать низкие веса примерам с неправильной разметкой. Здесь можно посмотреть симпатичные визуализации того, как такая сетка может выучивать фиксированные расписания типа hard mining или CL, а также трушное data-driven расписание.

Рекомендую разбираться под пивко
Рекомендую разбираться под пивко

Звучит всё это крайне привлекательно, но главная проблема - где же взять такой чистый, красивый, репрезентативный валидационный сет? ?

Incremental learning

Пожалуй, incremental learning - это тема для отдельного поста, но я всё-таки оставлю краткое упоминание здесь, ведь очень-очень редко все требования к аутпутам ML-систем известны заранее и фиксированы. По крайней мере у нас новые патологии в разметке появляются регулярно.

В таких случаях можно заново переучить всю сеть, дообучить её только на новых данных (и столкнуться с catastrophic forgetting), а можно использовать методы class-incremental learning. Из интересного в этой области могу выделить статью Retrospective Class Incremental Learning, в которой рассматривается реалистичный кейс, в котором у нас сохраняется доступ ко всей старой датке (в дополнение к новой), но нам не хочется тратить много времени на полное переобучение на всех старых данных.

Выводы

Главный вывод, как и всегда, неутешительный... В любом случае понадобится мало-мальски адекватный валидационный сет и правильная метрика - как минимум хочется уметь сравнивать разные методы сэмплинга, а в лучше случае можно даже попробовать выучить data-driven расписание. Но для начала я рекомендовал бы попробовать какие-то простые методы с фиксированным расписанием и посмотреть, сильно ли просядет метрика ?

Есть ещё интересная задача core-set selection - выбор сабсета большого набора данных, который даст нам адекватное качество на нашей задаче, но об этом, конечно, же в другой раз.

Если вы хотите узнать ещё больше об организации процессов ML-разработки, подписывайтесь на наш Телеграм-канал Варим ML

Tags:
Hubs:
Total votes 6: ↑5 and ↓1+4
Comments3

Articles