Dive into pyTorch

  • Tutorial

Всем привет. Меня зовут Артур Кадурин, я руковожу исследованиями в области глубокого обучения для разработки новых лекарственных препаратов в компании Insilico Medicine. В Insilico мы используем самые современные методы машинного обучения, а также сами разрабатываем и публикуем множество статей для того чтобы вылечить такие заболевания как рак или болезнь Альцгеймера, а возможно и старение как таковое.


В рамках подготовки своего курса по глубокому обучению я собираюсь опубликовать серию статей на тему Состязательных(Adversarial) сетей с разбором того что же это такое и как этим пользоваться. Эта серия статей не будет очередным обзором GANов(Generative Adversarial Networks), но позволит глубже заглянуть под капот нейронных сетей и охватит более широкий спектр архитектур. Хотя GANы мы конечно тоже разберем.


Для того чтобы дальше беспрепятственно обсуждать состязательные сети я решил сначала сделать небольшое введение в pyTorch. Хочу сразу заметить, что это не введение в нейронные сети, поэтому я исхожу из того, что вы уже знаете такие слова как "слой", "батч", "бэкпроп" и т.д. Помимо базовых знаний о нейросетях, вам, конечно, понадобится понимание языка python.


Для того чтобы было удобно пользоваться pyTorch я подготовил докер-контейнер с jupyter'ом и кодом в ноутбуках. Если вы захотите запускать обучение на видеокарте, то для видеокарт от NVIDIA вам потребуется nvidia-docker, думаю с этой частью у большинства из вас проблем не будет, поэтому остальное я оставляю вам.


Все необходимое для этого поста доступно в моем репозитории spoilt333/adversarial с тегом intro на Docker Hub, или в моем репозитории на GitHub.


После установки докера запустить контейнер можно например с помощью такой команды:


docker run -id --name intro -p 8765:8765 spoilt333/adversarial:intro

В контейнере автоматически запустится сервер jupyter'а, который будет доступен по http://127.0.0.1:8765 с паролем "password"(без кавычек). Если вы не хотите запускать чужой контейнер у себя на машине(правильно!), то собрать свой такой же, предварительно проверив что там все ок, можно из докерфайла который есть в репозитории на GitHub.


Если у вас все запустилось и вы смогли подключиться к jupyter, то давайте перейдем к тому что же из себя представляет pyTorch. pyTorch — это большой фреймворк позволяющий создавать динамические графы вычислений и автоматически вычислять градиенты по этим графам. Для машинного обучения это как раз то что нужно. Но, помимо самой возможности обучать модели, pyTorch это еще и огромная библиотека включающая датасеты, готовые модели, современные слои и комьюнити вокруг всего этого.


В Deep Learning, довольно продолжительное время, было практически стандартом тестировать все новые модели на задаче распознавания рукописных цифр. Датасет MNIST представляет из себя 70.000 размеченных рукописных цифр примерно поровну распределенных между классами. Он сразу же разбит на тренировочное и тестовое множества для того чтобы обеспечить одинаковые условия всем кто тестируется на этом датасете. В pyTorch, естественно, для него есть простые интерфейсы. Несмотря на то что сравнивать между собой state-of-the-art модели на этом датасете уже не имеет большого смысла, для демонстрационных целей он нам подойдет идеально.


Примеры цифр из MNIST

Каждый пример в MNIST представляет из себя изображение размером 28х28 пикселей в оттенках серого. И, как нетрудно заметить, далеко не все цифры легко может "распознать" даже человек. В ноутбуке mnist.ipynb вы можете посмотреть на пример загрузки и отображения датасета, а несколько полезных функций вынесены в файл utils.py. Но давайте перейдем к основному "блюду".


В ноутбуке mnist-basic.ipynb реализована двухслойная полносвязная нейронная сеть решающая задачу классификации. Один из способов сделать нейронную сеть с помощью pyTorch — это наследоваться от класса nn.Module и реализовать свои функции инициализации и forward


def __init__(self):
    super(Net, self).__init__()
    self.fc1 = nn.Linear(784, 100)
    self.fc2 = nn.Linear(100, 10)

Внутри функции __init__ мы объявляем слои будущей нейронной сети. В нашем случае это линейные слои nn.Linear которые имеют вид W'x+b, где W — матрица весов размером (input, output) и b — вектор смещения размером output. Эти самые веса и будут "обучаться" в процессе тренировки нейронной сети.


def forward(self, x):
    x = x.view(-1, 28*28)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    x = F.softmax(x, dim=1)
    return x

Метод forward используется непосредственно для преобразования входных данных с помощью заданной нейройнной сети в ее выходы. Для простоты примера мы будем работать с примерами из MNIST не как с изображениями, а как с векторами каждая размерность которых соответствует одному из пикселей. Функция view() это аналог numpy.reshape(), она переиндексирует тензор с данными заданным образом. "-1" в качестве первого аргумента функции означает, что количество элементов в первой размерности будет вычислено автоматически. Если исходный тензор x имеет размерность (N, 28, 28), то после


x = x.view(-1, 28*28)

его размерность станет равна (N, 784).


x = F.relu(self.fc1(x))

Применение слоев к данным в pyTorch реализовано максимально просто, вы можете "вызвать" слой передав ему в качестве аргумента батч данных и получить на выходе результат преобразования. Аналогичным образом устроены и функции активации. В данном случае я использую relu, так как это наиболее популярная функция активации в задачах компьютерного зрения, однако вы легко можете поэкспериментировать с другими реализованными в pyTorch функциями, благо их там достаточно.


x = self.fc2(x)
x = F.softmax(x, dim=1)

Так как мы решаем задачу классификации на 10 классов, то и выход нашей сети имеет размерность 10. В качестве функции активации на выходе сети мы используем softmax. Теперь значения которые возвращает функция forward можно интерпретировать как вероятности того что входной пример принадлежит к соответствующим классам.


model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)

Теперь мы можем создать экземпляр нашей сети и выбрать функцию оптимизации. Для того чтобы получился симпатичный график обучения я выбрал обыкновенный стохастический градиентный спуск, но в pyTorch, конечно же, реализованы и более продвинутые методы. Вы можете попробовать например RMSProp или Adam.


def train(epoch):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

Функция train содержит основной цикл обучения в котором мы итерируемся по батчам из тренировочного множества. data — это примеры, а target — соответствующие метки. В начале каждой итерации мы обнуляем текущее значение градиентов:


optimizer.zero_grad()

Обработка данных всей сетью в pyTorch ничем не отличается от применения отдельного слоя. За вызовом model(data) скрыт вызов функции forward, поэтому в output попадают выходы сети. Теперь остается только посчитать значение функции ошибки и сделать шаг обратного распространения:


loss = F.cross_entropy(output, target)
loss.backward()

На самом деле, при вызове loss.backward() веса сети еще не обновляются, но для всех весов использовавшихся при вычислении ошибки pyTorch считает градиенты используя построенный граф вычислений. Для того чтобы обновить веса мы вызываем optimizer.step(), который опираясь на свои параметры(у нас это learning rate) обновляет веса.



После 20 эпох обучения наша сеть угадывает цифры с точностью 91%, что, конечно, далеко от SOTA результатов, однако, весьма неплохо для 5 минут программирования. Вот пример из тестового множества с предсказанными ответами



[[1 1 5 2 4 6 9 9 9 9]
 [2 3 5 4 4 1 3 2 4 7]
 [5 0 3 9 4 5 3 2 3 2]
 [0 3 8 2 5 5 8 7 8 6]
 [8 3 6 8 4 8 5 1 3 9]]

В следующих постах я постараюсь рассказать о состязательных сетях в таком же стиле с примерами кода и подготовленными докер-контейнерами, в частности я планирую коснуться таких тем как domain adaptation, style transfer, generative adversarial networks и разобрать несколько наиболее важных статей в этой области.


Upd.1: Как правильно указали в комментариях оборачивать тензоры в Variable больше не нужно, поэтому я удалил соответствующу строчку. В докер-контейнере она естественно останется, однако без нее все тоже работает.
Upd.2: Картинки с цифрами былы перепутаны местами, так что я их поменял

Отус

197,00

Профессиональные онлайн-курсы для разработчиков

Поделиться публикацией
Комментарии 10
    +1
    Напишите что-нибудь поинтереснее циферок. Их даже школьники уже распознают.
    0
    В Insilico мы используем самые современные методы машинного обучения, а также сами разрабатываем и публикуем множество статей для того чтобы вылечить такие заболевания как рак или болезнь Альцгеймера, а возможно и старение как таковое.
    Какое-то пафосное вступление, а далее просто о распознавании циферок.

    Как именно машинное обучение вылечит болезнь Альцгеймера, а возможно и старение как таковое? Или вы просто имели в виду, что ваша компания этим занимается, а вы, как маленький винтик сложной машины, пока что просто распознаете циферки?
      0
      Как я понял, автор публикует серию статей, начал с азов. Будьте терпеливее и терпимее.
      А в медицине машинное обучение используется для ускорения подбора/перебора препаратов.
        +3
        В начале июня я проведу открытый вебинар на платформе Otus в котором расскажу какие задачи есть в разработке лекарственных препаратов и почему их можно решать с помощью нейронных сетей. В этой статье я показываю пример того как пользоваться pyTorch для того чтобы в следующих было понятно что вообще происходит. А в роли маленького винтика я действительно пишу относительно простые статьи в сравнении с тем что делают мои коллеги.
        +3

        Variable в pyTorch — всё (вроде бы уже с версии 0.3 они стали не обязательны), больше можно в них не оборачивать.

          +1

          Спасибо, учту.

            +1

            С 0.4, которая недавно вышла.

            +2
            А почему предсказанные в конце ответы абсолютно не подходят к тестовому множеству?
              0
              А потому что картинки местами перепутал:)

            Только полноправные пользователи могут оставлять комментарии. Войдите, пожалуйста.

            Самое читаемое