Как стать автором
Обновить

Пишем «Змейку» в 12 строк кода на PyTorch

Уровень сложностиПростой
Время на прочтение5 мин
Количество просмотров16K
Автор оригинала: Elias F. Fyksen

Привет, Хабр! 🖖🏻

Меня зовут Олег Булыгин, я data scientist, аналитик, автор и спикер IT-курсов.

Я готовлю разный полезный контент, туториалы и руководства по Python, которыми бы хотел делиться с вами :)

Давайте рассмотрим, как использовать линейную алгебру и тензорные операции, чтобы создать всем известную игру в 12 строк.

И у вас сразу точно возникает несколько вопросов:

  1. Насколько длинные эти 12 строк?

    Не волнуйтесь, все они соответствуют стандарту PEP8.

  2. Зачем это вообще делать?

    Иногда надо писать код просто ради фана. Кроме того, это отличный способ познакомиться с PyTorch и возможностями, которые предоставляют тензоры.

  3. Но этом же нет никакой практической пользы?

    Напротив. Методы, используемые в этой материале, на самом деле являются фундаментальными. И они лежат в основе модуля TensorSnake, который может эмулировать параллельно 100 миллионов игр "Змейка" на карте NVIDIA A6000 с задержкой 20 миллисекунд.

Сегодня мы программируем версию "Змейки", в которой она может перетекать за границу поля и выходить с другой стороны. Тем не менее, можно будет изменить 2 строки, чтобы реализовать стандартную версию.

Будем использовать PyTorch и NumPy. Можно было использовать даже какую-то одну из библиотек, но у PyTorch прекрасное Tensor API, а в NumPy есть хорошая функция под названием unravel_index, которую мы и будем использовать.

И договоримся, что в подсчёт строк не будут входить импорты и строка с определением функции ;)

Вопросы закрыли, зафиксировали договорённости, поехали!

Кодировка

Важнейшей частью этого кода является кодировка состояния змейки — формализация хранения информации об её положении.

В результате кодировки нам хотелось бы получить матрицу, которая при выводе через plt.imshow покажет состояние игры. И эту информацию должно быть легко обновлять.

Поэтому всю игру мы представим в виде матрицы целых чисел, где каждая пустая ячейка в игре будет иметь значение 0, хвост змеи будет 1, и по мере приближения хвоста к голове значение клеток будет увеличиваться на ещё одну единицу. Места с едой (целью) определяются значением -1. Итак, для змеи размера N клеток хвост будет равен 1, а голова — N.

Теперь нам нужно как-то формализовать действия. Вместо традиционной кодировки действий [вверх, вправо, вниз, влево] мы будем использовать кодировку [влево, вперёд, вправо], для определения направления движения относительно текущего направления змейки. Игроку может быть не очень привычно, но такой подход не является избыточным, т.к. в каждый момент любое действие является валидным (поскольку змейка не может двигаться назад).

Реализация

Ну и наконец, пишем код.

Все используемые функции хорошо описаны в документацией PyTorch API, подглядывайте туда, если что-то не понимаете.

Получение текущей позиции

Первое, что нужно сделать — это получить текущее и предыдущее положение головы змеюки. Мы можем сделать это с помощью topk(2), так как голова всегда является самым большим целым числом, а предыдущая её позиция — второе по величине число. Единственная проблема, с которой мы сталкиваемся, заключается в том, что метод topk делает расчёт только по одному измерению. Поэтому нам нужно сначала разгладить тензор с помощью метода flatten(), получить максимальные k элементов, а затем использовать вышеупомянутый unravel_index, чтобы преобразовать его обратно в двухмерное состояние. И нам надо полученные два индекса в тензоры, чтобы мы могли выполнять математические вычисления и с ними.

Вычисление следующей позиции

Чтобы вычислить следующую позицию, мы сделаем pos_cur - pos_prev. Эта операция вернёт вектор, указывающий на текущее направление движения змеи. Далее мы хотим повернуть его, но насколько?

Мы хотим повернуть его на 270 + 90 * action градусов. Таким образом, когда мы будем передавать 0, то мы поворачиваем налево, 1 — мы двигаемся прямо, а 2 — поворачиваем направо.

Для получения результата мы применяем матрицу вращения. Если матрица применяется к самой себе, это даёт нам матрицу, которая эквивалентна двойному применению преобразования. Следовательно, мы можем взять вектор направления и применить матрицу вращения на 90 градусов против часовой стрелки T([[0, -1], [1, 0]]), возведённую в степень 3 + action.

Наконец, мы добавляем текущую позицию к этому новому вектору направления, чтобы получить следующую позицию. Затем мы берём новое местоположение и взятие остатка от деления на размером поля, чтобы создать функциональность "перетекания" змейки за границу.

Как умереть

Ах, извечный вопрос. Но пока мы о змейке.

Поскольку теперь у нас есть следующая позиция, становится довольно просто определить, должна ли змейка умереть или нет. Нам просто нужно проверить, является ли snake[tuple(pos_next)] > 0, так как единственными клетками со значениями больше 0 являются те, в которых в данный момент находится змея.

Если змейка умирает, мы хотим вернуть счёт текущей игры. Это также довольно просто, поскольку счёт в игре равен длине змейки минус 2 (предполагая, что мы начинаем игру при длине змеи 2). Чтобы получить длину, нам просто нужно получить значение pos_cur, так как это текущая голова змеи. Это означает, что текущий счет равен snake[tuple(pos_cur)] — 2.

Как кушать

Время собраться с духом, следующие 3 строчки — самые сложные в игре.

Чтобы проверить, поймала ли змейка еду, мы сравниваемsnake[pos_next] с -1. Если они равны, то нам нужно найти все позиции на доске, которые в данный момент равны 0. Это пустые ячейки, куда мы потенциально можем положить следующую цель.

Когда у нас будут все эти позиции, нам нужно случайным образом выбрать один из этих индексов и обновить его значение до -1. Нам не нужно редактировать текущую запись -1, так как змея перезапишет её при перемещении.

Чтобы найти все места, которые в данный момент равны 0, мы просто используем snake == 0 (это возвращает логический тензор). Далее мы делаем .multinomial(1) для того, чтобы выбрать одну из позиций наугад. Функция multinominal(n) выбирает n случайных индексов из тензора с вероятностью, основанной на значении элемента.

Однако multinomial работает только с одной размерностью (как и topk), а также принимает только значения с плавающей точкой. Следовательно, нам нужно сначала использовать методы flatten() и .to(t.float). Таким образом, каждый индекс, значение которого равно 0, имеет одинаковую вероятность выбора, а каждый индекс, значение которого не равно 0, имеет нулевую вероятность выбора.

Как только это будет сделано, нам нужно снова использовать unravel, чтобы вернуть всё к двухмерному состоянию и обновить тензор snake.

Чтобы переместить змею, мы уменьшаем текущую змейку и добавляем новую голову на следующую позицию.

Однако, мы хотим уменьшить змею только в том случае, если змейка не поймала цель. Если же поймала, то мы хотим увеличить её размер на 1 ячейку.

Поэтому мы добавляем блок else к ветке if на случай уменьшения. Поскольку каждая ячейка змейки пронумерована, у хвоста значение 1, мы можем вычесть 1 из значения каждой ячейки, размер которой больше 0 (так как только ячейки самой змейки вообще больше нуля). Вот мы и подрезали змейку на 1 клетку.

Теперь нам нужно добавить ей голову на новую позицию, т.е. установить значение следующей позиции, как значение предыдущей +1.

Заключение

А вот и всё. Вы написали "Змейку" в 12 строк кода.

def do(snake: t.Tensor, action: int):
    positions = snake.flatten().topk(2)[1]
    [pos_cur, pos_prev] = [T(unravel(x, snake.shape)) for x in positions]
    rotation = T([[0, -1], [1, 0]]).matrix_power(3 + action)
    pos_next = (pos_cur + (pos_cur - pos_prev) @ rotation) % T(snake.shape)
    
    if (snake[tuple(pos_next)] > 0).any():
        return (snake[tuple(pos_cur)] - 2).item() 
    
    if snake[tuple(pos_next)] == -1:
        pos_food = (snake == 0).flatten().to(t.float).multinomial(1)[0]
        snake[unravel(pos_food, snake.shape)] = -1
    else:
        snake[snake > 0] -= 1  
        
    snake[tuple(pos_next)] = snake[tuple(pos_cur)] + 1

Интерфейс

А, дак вы и поиграть в неё еще хотите?

Создание простенького графического интерфейса будет стоить нам ещё 15 строк, держите:

snake = t.zeros((32, 32), dtype=t.int)
snake[0, :3] = T([1, 2, -1]) 

fig, ax = plt.subplots(1, 1)
img = ax.imshow(snake)
action = {'val': 1}
action_dict = {'a': 0, 'd': 2}

fig.canvas.mpl_connect('key_press_event',
                       lambda e: action.__setitem__('val', action_dict[e.key]))

score = None
while score is None: 
    img.set_data(snake)
    fig.canvas.draw_idle()
    plt.pause(0.1) 
    score = do(snake, action['val']) 
    action['val'] = 1 
    
print('Score:', score)

Теперь можете играть сколько душе угодно :)

🐍Если тебе интересны и другие полезные материалы по Python и IT, то подписывайся на мой канал в tg: PythonTalk 🫶

Теги:
Хабы:
+29
Комментарии26

Публикации

Истории

Работа

Python разработчик
123 вакансии
Data Scientist
58 вакансий

Ближайшие события

One day offer от ВСК
Дата16 – 17 мая
Время09:00 – 18:00
Место
Онлайн
Конференция «Я.Железо»
Дата18 мая
Время14:00 – 23:59
Место
МоскваОнлайн
Антиконференция X5 Future Night
Дата30 мая
Время11:00 – 23:00
Место
Онлайн
Конференция «IT IS CONF 2024»
Дата20 июня
Время09:00 – 19:00
Место
Екатеринбург
Summer Merge
Дата28 – 30 июня
Время11:00
Место
Ульяновская область