Как стать автором
Обновить

Как нейросеть достопримечательности на фотокарточках распознавала

Уровень сложностиСредний
Время на прочтение7 мин
Количество просмотров1.8K

Введение

Всем привет. Это мой первый пост и первый обзор на работу. В двух словах опишу, чем это я тут занимался.

Цель проекта заключалась в распознавании достопримечательностей на фотографиях при помощи машинного обучения, а именно свёрточных нейросетей. Данная тема была выбрана из следующих соображений:

  • у меня уже был некий опыт работы с задачами компьютерного зрения

  • задача звучала так, как будто её можно сделать очень быстро и не прикладывать большое количество усилий, и, что немаловажно, вычислительных ресурсов (все сетки обучались в колабе или на кагле)

  • задача может иметь какое-то практическое применение (ну, в теории...)

Сначала он планировался как исключительно учебный проект, но потом я проникся его идеей и решил доработать его.

Далее, я буду рассказывать о том, как я подходил к решению этой задачи, и при этом буду стараться идти по коду из ноутбука, в котором и происходила вся магия, при этом стараться пояснять какие-то свои действия. Возможно, это поможет кому-нибудь преодолеть боязнь "чистого листа" и увидеть, что подобного рода вещи делаются действительно просто!

Инструменты

Ну и первым делом расскажу об инструментах, которые использовались при реализации проекта.

  • Colab/Kaggle: использовались для того, чтобы обучать сети на ГПУ.

  • Weights And Biases: сервис, в который я сохранял модели, их описания, добавлял лоссы, значения метрик, параметры обучения, препроцессинга. С данными можно ознакомиться по ссылке. В процессе написания кода немного изменялся раздел метадаты, который, по сути, содержит параметры обучения и препроцессинга. В разделе с файлами вы сможете ознакомиться с описанием сети (как устроены её слои), скачать обученные веса сети, а также глянуть на значение лоссов и метрик.

Данные для обучения

Ну, наверное, стоило бы начать с выбора данных для обучения нейронной сети. Для этого я покопался в датасетах на кагле (тык), и вот такой сайт мне еще приглянулся.

Как выяснилось, существует соревнование от Google, связанное как раз таки с распознаванием достопримечательностей. Здесь появилась первая проблема: датасет весит \approx100 гб, и содержит около 200.000 классов и 5.000.000 изображений. Понимая, что сетки в дальнейшем я буду учить не на своей пекарне, от данного варианта пришлось отказаться. Полистав еще, я остановился на этом датасете. В нем содержится 210 классов и примерно по 50 фотокарточек на каждый из классов. Картинки все разного размера, снятые с разных ракурсов, с разного расстояния. В общем, датасет совсем не рафинированный, а пока что я работал только с такими. Ну, самому данные размечать не надо, за это уже лайк! Приведу вам парочку особо удачных фотографий:

Так, например, выглядит "Большой театр в Москве"
Так, например, выглядит "Большой театр в Москве"
А вот здесь нам вместо фотографии Централ Парка в Нью-Йорке подсунули его схему
А вот здесь нам вместо фотографии Централ Парка в Нью-Йорке подсунули его схему

Хранение и обработка данных (ч.1)

В данном разделе я бы хотел рассказать о том, как данные хранились и обрабатывались.

Первым делом, мы проверим картинки на то, сколько каналов они содержат. Привычнее всего все таки работать с изображениями с тремя каналами (RGB), да и выбранный датасет по большей части содержит фотографии именно такого формата. Но помимо трехканальных изображений в данном датасете нам встречаются и черно-белые картинки, и фотографии формата RGBA. Так как таких объектов мало (19), то их удаление не должно вызвать снижение точности итоговой модели.

Для хранения данных я написал несколько классов, которые наследуются от torch.utils.data.Dataset. При реализации таких классов необходимым условием является переопределение методов __getitem__ и __len__ (то есть добавить классу возможность получать элемент по индексу, и возвращать длину экземпляра класса). Этим я и решил заняться, так как реализовывать данный функционал можно разными способами, два из которых я рассмотрел ниже.

Датасет с быстрым доступом (FastDataset)

Первое, что пришло в голову: давайте просто считывать изображения, приводить их к одному размеру, переводить их в тензоры pytorch и хранить тензоры. То есть, весь этап преобработки, а именно: ресайзинг, приведение к тензорам, нормировка, будет производиться при инициализации. Далее, когда мы хотим перебрать элементы датасета, мы просто достаем из памяти тензоры, не выполняя никакой дополнительной обработки. Казалось бы, что может пойти не так... Но ответ, на самом деле, очевиден: хранить обработанные данные - удовольствие недешевое, и за него приходится платить чеканной монетой памятью. Что же делать...

Датасет с медленным доступом (CustomDataset)

Второе, что пришло в голову: а давайте мы просто будем хранить список, содержащий пути до наших картинок. Таким образом, расходы памяти становятся в разы меньше. Но при таком подходе мы жертвуем временем, за которое происходит обход. Ведь при хранении данных в виде списка путей при каждом обращении мы должны считать изображение по его пути, применить операции ресайзинга и приведения к тензорам, и только после этого мы можем работать с полученным объектом. Долго, да, но ничего не поделать.

Обучение сети

В данном разделе мы немного отойдем от ноутбука.

Разбиение данных

Итак, в нашем арсенале уже имеется два класса, позволяющих работать с данными. Давайте уже что-нибудь обучим. Для этого нам нужно написать цикл обучения сети, в котором также будем рассчитывать метрики на валидационной выборке для грамотного подбора гиперпараметров сети. Но для того, чтоб таковая выборка появилась, надо разбить данные на тренировочную и валидационную выборки. Для этого в каждом классе я реализовал метод разбиения исходного датасета на две непересекающихся части. Хочу заметить, что внутри датасета каждый класс обрабатывался отдельно, поэтому на выходе мы получили сбалансированное разделение.

Обучение

При обучении использовался оптимизатор Adam из модуля pytorch, и в качестве функции потерь была выбрана nn.CrossEntropyLoss.

Сначала я пробовал писать и обучать совсем простые сети, которые состояли из двух частей:

  • сверточной части, в которой использовались свертки и пуллинги;

  • полносвязной части, в которой использовались линейные слои и чуть-чуть дропауты (на wandb это - нулевая версия CNN).

Стало понятно, что нужно усложнять архитектуру. Добавил слои батч-нормализации, и качество очень приятно подскочило. В общем, меняя параметры обучения и архитектуру сети, где-то с удалось поднять качество значение метрики F1 на валидационной выборке до 93%. Тогда я подумал, что цель достигнута, и получилось отделаться малой кровью, но не тут то было. И просто для того, чтобы убедиться в том, что все действительно хорошо, решил погуглить про метрику, которую я использовал. Оказалось, что считалось совсем не то, что я ожидал.

Дело в том, что у реализации F1 есть несколько вариаций работы. И тот, который был написан в моем коде, возвращал Accuracy. Для многоклассовой классификации это никуда не годится. И когда я все исправил, значение метрики на валидационной выборке вернулось к 31%, а вот на тренировочной выборке было 96%. Вот это уже по нашему! Сетка хорошо так переобучается. Давайте решать проблему.

Хранение и обработка данных (ч.2)

Первая идея, которая меня посетила: скорее всего, сетка просто не может научиться на 45 изображениях, некоторые из которых еще и не самого лучшего качества. Что можно с этим сделать? Ну давайте применим аугментацию. Не знаю, какой контингент читает эту рукопись, так что дам краткое пояснение. Аугментация, по сути - увеличение объема данных, за счёт которого можно обучить сетку. Давайте попробуем искусственно расширить множество уже имеющихся картинок путем применения к ним неких трансформаций.

Идея следующая: давайте к каждому существующему изображению будем применять набор преобразований, например, поворот на 180 градусов, или осуществлять небольшой поворот изображения.

Вот такие картинки получились из верхней левой.
Вот такие картинки получились из верхней левой.

Мы смогли расширить наш датасет аж в 7 раз! Давайте используем это для обучения сетки.

Далее, я реализовал еще два класса, по аналогии с датасетами из ч.1: с быстрым доступом AugmentedFastDataset и медленным доступом AugmentedCustomDataset. Проблема возникла моментально: уже на данном этапе, при применении 7 различных видов трансформаций, датасет с быстрым доступом сжирал всю память, и все падало. Соответственно, пришлось использовать его менее быструю, но более экономичную (в плане памяти) версию.

Ну и что мы видим (посмотреть можно CNN.v9): модель все равно очень сильно переобучается. Что же еще можно такого придумать...

И пришла в голову следующая идея: зачем применять по одной трансформации(назову так операцию изменения исходного изображения) за раз? Можно ведь применять их последовательно. Тогда, делая различные комбинации, мы сможем еще больше расширить датасет. Давайте попробуем реализовать эту задумку в классе AdvancedCustomDataset. Кратко опишу процесс: в конструктор класса мы передаем аргумент ex_amount, которая отвечает за то, сколько экземпляров для каждого класса мы хотим получить. Далее, проходимся по каждому классу, и до тех пор, пока не получим нужное число изображений, применяем случайный набор трансформаций к случайному изображению. Ниже, можно увидеть пример того, как работает данная задумка.

Из одной картинки мы получили 44
Из одной картинки мы получили 44

Также, произошли еще некоторые минорные изменения, связанные с заменой некоторых функций на их аналоги из других модулей. Причина проста: так как датасет сильно расширился, и доступ к элементам медленный, на один проход уходит уйма времени. Поэтому приятно было бы сэкономить на таких мелочах пару-тройку минут.

  • открытие изображение раньше производилось при помощи библиотеки PIL. Как показали сравнения, открытие изображений при помощи библиотеки cv2 работает гораздо быстрее. Поэтому, в отличии от остальных классов, в последнем датасете используется аналог из cv2.

  • операции по изменению изображений были взяты из модуля torchvision.transforms. Как выяснилось позже, аналогичные функции из модуля albumentations работают быстрее.

Ну что, данных наделали, метрики наладили, пора учиться. Коль теперь я был волен выбирать, сколько картинок для каждого класса я хочу, я выставил значение 2000. И после длительного процесса обучения и валидации получаем модель со значением F1 = 60% на валидационной выборке. Уже хорошо, я считаю.

Немножко про FineTuning

Ну что же, какое-то качество мы уже получили, но все же оно меня не совсем устраивает. Давайте теперь отвлечемся от самописной архитектуры, и попробуем переобучить уже существующую сеть. В качестве такой сети я взял модель VGG13 с батч-нормализацией с предобученными весами. Далее, заморозил всю свёрточную часть, немного поигрался с классификатором, и поставил это все дело учиться. Получилось еще лучше, чем было до этого: метрика на валидационной выборке равна 70% (тык).

Послесловие

Итак, что мы получили на выходе: две сети, которые работают хорошо, и действительно качественно распознают изображения. Допускаю даже то, что та самая тридцати процентная ошибка возникает из-за кривых фотографий в датасете (примеры приводил выше).

Попытался оформить все это дело в мини-проект, который можно скачать с гитхаба.

Прошу писать любые замечания, связанные с написанным, буду рад набраться опыта!

Теги:
Хабы:
Всего голосов 2: ↑2 и ↓0+2
Комментарии1

Публикации

Истории

Работа

Data Scientist
78 вакансий
Python разработчик
120 вакансий

Ближайшие события

7 – 8 ноября
Конференция byteoilgas_conf 2024
МоскваОнлайн
7 – 8 ноября
Конференция «Матемаркетинг»
МоскваОнлайн
15 – 16 ноября
IT-конференция Merge Skolkovo
Москва
22 – 24 ноября
Хакатон «AgroCode Hack Genetics'24»
Онлайн
28 ноября
Конференция «TechRec: ITHR CAMPUS»
МоскваОнлайн
25 – 26 апреля
IT-конференция Merge Tatarstan 2025
Казань