Как стать автором
Обновить

Исследование и оптимизации RNN

Время на прочтение4 мин
Количество просмотров681

Поскольку я не профессионал, буду использовать свои находки и предположения.
И никакой математики.

В этой статье я буду анализировать и улучшать ATR, LRN, LSTM, GRU в задаче генерации текста. Обучать их я буду на классике Достоевского: «gs://trax‑ml/reformer/crime‑and‑punishment-2554.txt». Он первый под руку попался.
Чтобы не мучаться с точной настройкой планировщика, я использовал свой, который понижает скорость, когда растут потери.

Также я использовал label_smoothing 0.5. Даёт ~1% точности и эффект стабилизации.
Планировщик и label_smoothing позволили сохранить стабильность.
Pre и Post нормализации дали до 8%.

Каждая сеть была общим размером ~96т параметров, эмбеддинг [100, 64], 1 слой.
Шаг 1024 токена, точность mixed_float16, Adam с lr 0.003.
TF версия 2.15.1, в 2.16 Keras сломан.

Я использовал CuDNN реализацию LSTM, GRU. Код LSTM для справки:

  def call(self, inputs, states):
    h, c = states
    f = tf.sigmoid(self.wf_l(inputs) + self.uf_l(h))
    i = tf.sigmoid(self.wi_l(inputs) + self.ui_l(h))
    o = tf.sigmoid(self.wo_l(inputs) + self.uo_l(h))
    c_new = tf.tanh(self.wc_l(inputs) + self.uc_l(h))
    c_new = f * c + i * c_new
    h_new = o * tf.tanh(c_new)
    return h_new, [h_new, c_new]

(f * c) + (i * c_new) и (wx + wh) похожи на подобие остаточного соединения.
В LSTM(1997) задолго до ResNet(2015).

Базовые результаты точности(%) по эпохам. Перечисление заканчивается там где рост < 1%.

LSTM

48, 56, 57, 58.

GRU

49, 56, 58, 59.

GRU(4.47м)

54, 59, сломался.

#Код ATR (Addition-Subtraction Twin-Gated RNN):
class ATRCell(Layer):
  def __init__(self, units):
    super().__init__()
    self.state_size = units

  def build(self, input_shape):
    self.p_l = Dense(self.state_size)
    self.q_l = Dense(self.state_size)
    self.built = True

  def call(self, inputs, states):
    s = states[0]
    p = self.p_l(inputs)
    q = self.q_l(s)
    f = tf.sigmoid(p - q)
    i = tf.sigmoid(p + q)
    o = i * p + f * s
    return o, [o]

ATR

49, 56, 57, 58.

  1. Опыт от предыдущего эксперимента подсказал, что активаторы могут мешать даже если теоретически нужны. Особенно Sigmoid.
    Объясняю это угасанием градиентов и тем, что обучаемые параметры могут компенсировать удаление зажимов и взаимно стабилизироваться.
    Вообще, я думаю, есть несколько средств стабилизации RNN:
    Активаторы-зажимы(Sigmoid, TanH, ...), ворота ввода и забывания(i, f), Bias, умножение матриц.
    По логике предполагается, что значения ворот f и i должны быть в пределах 0-1.
    Замена их на TanH дало улучшение: 50, 56, 58, 59, 60. Выходит, что отрицательные значения всё же полезны.

  2. Удалил активатор i: +1%. Разработчик предлагает замену для i на 1 - f.
    Попробовал: -1%.

  3. i и q выглядят вторичными параметрами в текущей конструкции.
    Решил прибавить значимость i, заменив им q, а f выразил через p - i. +1%.
    Это добавило и стабильности позволив окончательно избавиться от активаторов.

  4. Покрутил я это так и сяк. И заменил f на p. И знаете что?
    Стабильность снова повысилась.
    Далее всё же применил немного математики и переписал это в o = p * (i + s).

  5. Bias в i_l является необходимым для обучения.
    Тут, Bias это просто циклически прибавляемая константа.
    Какой в этом смысл? Это позволяет смещать значения от нуля. p * (0*x + 0) = 0.
    Заменил s на 0.1. Обучать Bias больше не нужно.

Ещё можно вынести p из цикла. Поскольку от ATR ничего не осталось, сменю название.
Далее предполагается, что закомментированные строки находятся вне цикла.

#Код SMR (Simple Multiplicative RNN):
  def call(self, inputs, states):
    s = states[0]
    #p = self.p_l(inputs)
    i = self.i_l(s)
    o = p * (i + 0.1)
    return o, [o]

SMR(46т)

50, 55, 56, 57.

SMR(96т)

53, 58, 59, 60 > 62.

SMR(340т)

55, 61, 63, 64, 65, 66.

SMR(1.19м)

57, 62, 65, 67, 69, 71, 72, 73.

SMR(4.47м)

57, 63, 66, 69, 73, 75, 78, 79.

Это уже абсурд какой то. Простейший SMR унижает LSTM "как Тузик грелку".

Рассмотрим LRN. Его автор предлагает убрать из цикла все умножения матриц(аналогично q, k, v).

#Код LRN (Lightweight Recurrent Network):
  def call(self, inputs, states):
    s = states[0]
    #p = self.p_l(inputs)
    #q = self.q_l(inputs)
    #r = self.r_l(inputs)
    f = tf.sigmoid(q - s)
    i = tf.sigmoid(p + s)
    o = i * r + f * s
    return o, [o]

LRN

45, 52, 54, 55.

LRN(stack=3)

46, 54, 56, 57.

Как видно, это очень похоже на ATR. Результаты плохие.

#Код оптимизированного LRN:
  def call(self, inputs, states):
    s = states[0]
    #p = self.p_l(inputs)
    #q = self.q_l(inputs)
    #r = self.r_l(inputs)
    o = p * r + q * s
    o = tf.tanh(o)
    return o, [o]

ILRN

48, 52, 54, 55.

ILRN(stack=3)

51, 56, 57, 59.

Цикл h -> self.h_l(h) -> o может использоваться для генерации потока, но с inputs так нельзя. Так почему же это работает? Я думаю, что сама входная последовательность является источником нелинейности. Умножение матриц может смешивать значения вектора. Поскольку их тут нет, то остаётся только словарь(Embedding). Его достаточно для предсказания следующего токена. Выходит, что для RNN не обязательны умножения матриц и активаторы. Хотя p, q, r умножения матриц вычисляются вне цикла, они всего лишь аугментируют данные из словаря. По мне это логично. Подобные действия это "предсказания в вакууме", поскольку нет зависимостей. Смешивают значения вектора словаря с помощью своих обучаемых параметров и создают производные. Но эти данные можно поместить непосредственно в словарь, и аугментация станет не нужна. Хотя словарь понадобится намного больше. Теперь можно создать SLR.

#Код SLR (Simplest Language RNN):
  def call(self, p, states):
    s = states[0]
    o = p * (s + 0.1)
    return o, [o]

SLR(emb=223)

46, 50, 51 > 52.

SLR(emb=512)

48, 53, 54 > 56.

SLR(emb=1024, lr=0.001)

48, нестабильно > 57.

Это, конечно, примитивная RNN, но она может моделировать язык.
Даже простой Embedding + Head(без RNN) по сути даёт Bigram модель.
Хватает для достижения 26%. Для сравнения:

  def call(self, p, states):
    s = states[0]
    o = tf.tanh(p + s)
    return o, [o]

(emb=223): 42, 47, 48, 49.
(emb=512): 42, 48, 50, 51, 52.

  def call(self, inputs, states):
    s = states[0]
    #p = self.p_l(inputs)
    o = p + self.s_l(s)
    o = tf.tanh(o)
    return o, [o]

48, 55, 57, 58, 59.

Похоже, LSTM худший выбор для задач LM. Складывается впечатление, что он сделан по принципу - навалим побольше матриц и активаторов, может поможет. Но почему результаты всех этих RNN так слабо отличаются? Полагаю, что основная точность достигается запоминанием устойчивых последовательностей. И для дальнейшего улучшения принципиально чего то не хватает.

Качество генерации текста у всех архитектур схожее.

Теги:
Хабы:
+3
Комментарии0

Публикации

Истории

Работа

Data Scientist
84 вакансии

Ближайшие события