«Вспомнить все» или решение проблемы катастрофической забывчивости для чайников

    Эта моя статья будет посвящена проблеме катастрофической забывчивости и новейшим методам ее решения. Будут приведены примеры реализации этих методов, которые легко адаптировать под почти любую конфигурацию нейронной сети.

    Сначала напомним, что это, собственно, за проблема. Если вдруг так оказалось, что вам нужно обучать нейронную сетку сначала на одном датасете, а затем на другом, то вы обнаружите, что по мере обучения на втором датасете сетка быстро забывает первый датасет, то есть теряет навык, полученный при обучении на нем. Или же если вы используете transfer learning и доучиваете готовую сетку на своих примерах, то будет наблюдаться тот же эффект – сетка успешно доучится на ваших данных, но при этом существенно утеряет предыдущие навыки, то есть то, ради чего весь transfer learning и затевался. Если вдруг датасетов, на которых надо последовательно учиться, не два а, к примеру, пять, то к концу обучения на пятом сетка забудет первый датасет практически полностью.

    С этой проблемой мы имеем дело с середины прошлого века, то есть с момента, когда начали проводиться эксперименты по обучению нейронных сетей. За прошедшее время об проблему обломали немало зубов, и она успела впитаться в коллективное бессознательное как нечто трудноразрешимое, с чем приходится мириться или придумывать разные трюки для того, чтобы хоть как-то снизить ее негативное влияние.

    Однако четыре года назад группой «британских ученых» (на самом деле это были парни из британского отделения DeepMind) был сделан значительный прорыв – они придумали метод эластичного закрепления весов (EWC). Если объяснять на пальцах, то его суть в следующем. Если навык обученной нейронной сетки заключен в весах ее связей (и не только связей, а вообще параметров), то «не все они одинаково полезны». То есть какие-то связи более важны для выученного навыка, какие-то менее. Идея в том, чтобы, когда сетка уже обучена навыку (датасету), при дальнейшем обучении другим навыкам не давать важным весам уходить далеко от эталонных значений, то есть полученных после обучения первому навыку. Реализуется это добавлением слагаемого-регуляризатора в функцию стоимости, так что при изменении параметра регуляризатор тянет его назад, в сторону, противоположную изменению. И чем «важнее» параметр, тем сильнее его тянет обратно. Получается, что веса сети как будто привязаны к эталонным значениям резинками разной упругости.

    Все это выглядит просто и замечательно, однако основная проблема – как понять, насколько каждый вес или параметр сети важен или наоборот – не важен. Тут «британские ученые» тоже не оплошали и предложили использовать в качестве «важности» весов диагональные элементы информационной матрицы Фишера. И это отлично сработало! То есть, если после обучения на первом датасете посчитать таким способом важности весов, а потом, при обучении на втором датасете, добавить регуляризатор в функцию стоимости, то сетка научится второму датасету не потеряв навыка первого. Ну почти не потеряв. И таким образом можно дообучить сетку последовательно хоть на десятке датасетов, добавляя соответствующие регуляризаторы. И все навыки будут сохраняться пока будет хватать емкости сетки.

    К счастью, потом еще один математик доказал, что правильнее не регуляризаторы добавлять, а важности весов суммировать, что существенно проще. К сожалению, у метода все же есть пара недостатков: чтоб все работало, выход сети должен быть вероятностным распределением (то есть softmax), и из этого распределения нужно сэмплировать, а потом и строить дополнительный граф вычислений для диагонали информационной матрицы Фишера и считать его для каждого примера датасета. Хорошая новость в том, что tensorflow или torch все сделает за вас, плохая в том, что посчитать придется много.

    Далее привожу пример реализации EWC для глубокой сетки из нескольких полносвязных слоев. Последовательное обучение сетки нескольким датасетам проводится следующим образом: берем датасет, открываем урок для сетки, учим ее, закрываем урок, считая важности на текущем датасете и, суммируя с предыдущими значениями важности, повторяем все для следующего датасета.

    Пример реализации EWC тут.
    # код рассчитан на python 3.7 и tensorflow 1.15.0
    
    import datetime
    from copy import deepcopy
    
    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    
    def _weight_variable(shape):
        # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
        stddev = 2. / np.sqrt(shape[0])
        initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    def _bias_variable(shape):
        # смещения инициализируем нулями
        initial = tf.constant(0.0, shape=shape)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    class Model:
        def __init__(self, shape, session):
            """
            :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                            от входа к выходу справа налево, например, [784, 100, 10]
            :param session: tensorflow-сессия для расчетов сети
            """
    
            self.session = session
            self._shape = shape
            depth = len(shape) - 1
            if depth < 1:
                raise ValueError("Недопустимая структура сети!")
    
            # заглушки для входных данных
            self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
            self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])
    
            # все веса слоев сети будем хранить в списке
            self.var_list = []
            for ins, outs in zip(shape[:-1], shape[1:]):
                self.var_list.append(_weight_variable([ins, outs]))
                self.var_list.append(_bias_variable([outs]))
    
            # инициализируем веса сети
            for v in self.var_list:
                session.run(v.initializer)
    
            # список для хранения важностей весов сети
            self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]
    
            # строим вычислительный граф
            x, y, z = self.x, None, None
            for i in range(depth):
                z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
                y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
                x = y
    
            # функция стоимости
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))
    
            # точность (accuracy)
            self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))
    
            # сэмплируем из выхода сети как из вероятностного распределения
            class_ind = tf.cast(tf.random.categorical(tf.math.log(y), 1)[0][0], tf.int32)
            # вычисляем градиенты вероятности
            self.prob_grads = tf.gradients(tf.math.log(y[0, class_ind]), self.var_list)
    
            self.train_step = None
    
        def open_lesson(self, learning_rate=1.0, lmbda=0.0):
            """
            Открытие урока обучения сети на отдельном датасете
            :param learning_rate: скорость обучения для SGD
            :param lmbda:         коэффициент влияния важностей - насколько сильно
                                  важности тянут веса к эталонным значениям
            """
            loss = self.loss
    
            if hasattr(self, "star_vars") and lmbda != 0:
                # добавляем к функции стоимости слагаемые-регуляризаторы
                for v in range(len(self._shape)*2-2):
                    loss += tf.reduce_sum(tf.multiply(
                        tf.constant(lmbda / 2. * self.wb_importance[v], tf.float32),
                        tf.square(self.var_list[v] - tf.constant(self.star_vars[v], tf.float32))
                    ))
            # устанавливаем шаг оптимизатора
            self.train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    
        def close_lesson(self, closing_set=None):
            """
            Закрытие урока обучения сети на отдельном датасете. Расчет и накопление важностей весов.
            :param closing_set: датасет, на котором будут рассчитаны важности весов после обучения
            :return:
            """
    
            # рассчитываем важности весов на закрывающем датасете
            addendum = self._compute_fisher(closing_set)
            # добавляем рассчитанные важности к сохраненным
            for i, a in zip(self.wb_importance, addendum):
                i += a
            # запоминаем текущие эталонные веса сети после обучения
            self._store_weights_and_biases()
    
        def _store_weights_and_biases(self):
            self.star_vars = [v.eval() for v in self.var_list]
    
        def _compute_fisher(self, closing_set):
            # вычисление диагональных элементов информационной
            # матрицы Фишера для весов сети на заданном датасете
            num_samples = len(closing_set)
    
            # инициализируем значения нулями
            fisher = [np.zeros(self.var_list[v].shape, dtype=np.float32) for v in range(len(self.var_list))]
    
            for i in range(num_samples):
                # вычисляем первые производные логарифма вероятности
                feed_dict = {self.x: closing_set[i:i+1]}
                derivatives = self.session.run(self.prob_grads, feed_dict=feed_dict)
                # возводим их в квадрат т. к. это диагональные элементы
                for f, d in zip(fisher, derivatives):
                    f += np.square(d)
    
            # усредняем по количеству примеров в датасете.
            # такой подход применим если мы хотим чтоб каждый урок оказывал одинаковый вклад в важности.
            # если же мы хотим чтоб каждый пример при закрытии урока оказывал одинаковое влияние, то
            # нужно делить на некую подобранную константу, а не на число примеров.
            for f in fisher:
                f /= num_samples
    
            return fisher
    
    
    # функция случайным образом переставляет входы одинаково для всех примеров датасета
    def permute_mnist(mnist):
        perm_inds = list(range(mnist.train.images.shape[1]))
        np.random.shuffle(perm_inds)
        mnist2 = deepcopy(mnist)
        sets = ["train", "validation", "test"]
        for set_name in sets:
            this_set = getattr(mnist2, set_name)
            this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
        return mnist2
    
    
    def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
        """
        Обучение модели
        :param model:       обучаемая модель
        :param train_set:   обучающий датасет
        :param test_sets:   список датасетов, на которых будет считаться средняя точность
        :param batch_size:  размер батча
        :param epochs:      количество эпох обучения
        :return:            средняя точность на тестовых датасетах после обучения
        """
        num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
        for idx in range(num_iters):
            train_batch = train_set.train.next_batch(batch_size)
            feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
            model.train_step.run(feed_dict=feed_dict)
            print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')
    
        print(f'\rTraining  {num_iters}/{num_iters} iterations done.')
    
        accuracy = 0.
        for t, test_set in enumerate(test_sets):
            feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
            accuracy += model.accuracy.eval(feed_dict=feed_dict)
        accuracy /= len(test_sets)
        print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
        return accuracy
    
    
    def continual_learning(net_struct, data_sets, session, lr, lmbda):
        """
        Последовательное обучение на нескольких обучающих наборах
        :param net_struct: структура сети
        :param data_sets:  список обучающих датасетов для последовательного обучения
        :param session:    tf-сессия
        :param lr:         скорость обучения
        :param lmbda:      степень влияния важностей на обучение
        :return:           список усредненных по выученным датасетам оценок
        """
        model = Model(net_struct, session)
        test_sets = []
        accuracies = []
        for data_set in data_sets:
            test_sets.append(data_set)
            model.open_lesson(lr, lmbda)
            accuracy = train_model(model, data_set, test_sets, 100, 4)
            accuracies.append(accuracy)
            model.close_lesson(data_set.validation.images)
        del model
        return accuracies
    
    # считываем данные MNIST
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    # создаем tf-сессию
    sess = tf.InteractiveSession()
    
    # создаем 10 различных обучающих наборов для последовательного обучения
    mnist0 = mnist
    mnist1 = permute_mnist(mnist)
    mnist2 = permute_mnist(mnist)
    mnist3 = permute_mnist(mnist)
    mnist4 = permute_mnist(mnist)
    mnist5 = permute_mnist(mnist)
    mnist6 = permute_mnist(mnist)
    mnist7 = permute_mnist(mnist)
    mnist8 = permute_mnist(mnist)
    mnist9 = permute_mnist(mnist)
    
    start_time = datetime.datetime.now()
    
    # определим параметры обучения
    data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
    net_struct = [784, 300, 150, 10]
    lmbda = 15.
    learning_rate = 0.2
    
    accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
    print ('Total time spent', datetime.datetime.now() - start_time)
    
    dataset_num = range(1, len(accuracies) + 1)
    
    # нарисуем график деградации средней точности на всех выученных датасетах
    plt.figure(figsize=(7, 3.5))
    plt.ylim(0.40, 1.)
    plt.xlim(1, len(accuracies))
    plt.ylabel('Total accuracy')
    plt.xlabel('Number of tasks')
    plt.plot(dataset_num, accuracies, marker=".")
    #plt.legend()
    plt.show()

    Для иллюстрации проблемы катастрофической забывчивости в этом коде можно выставить значение лямбда в 0, и увидеть, как сильно начнет падать средняя точность

    Примерно через три месяца изобрели другой способ считать важности весов Synaptic Intelligence (SI) – это насколько менялась функция стоимости при изменениях веса в процессе обучения. В качестве достоинства метода авторы указывают, что важности считаются прямо в процессе обучения сетки. Авторы также утверждают, что с вычисленными таким способом важностями навыки сохраняются лучше, чем у EWC. Но строгих доказательств не приводят, кроме эксперимента, проведенного на конкретных значениях гиперпараметров сети

    В этом способе хорошо то, что он тоже работает, а также, что его использование не зависит от типа выхода сети (то есть выход не обязательно должен быть распределением/softmax). Вычислительно метод существенно легче, чем оригинальный EWC. Но также требует построения дополнительного графа вычислений и неотделимо сопровождает процесс обучения, то есть имеет пропорциональную обучению вычислительную стоимость, что не есть хорошо.

    Пример реализации SI тут.
    # код рассчитан на python 3.7 и tensorflow 1.15.0
    
    import datetime
    from copy import deepcopy
    
    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    
    _epsilon = 0.1
    
    
    def _weight_variable(shape):
        # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
        stddev = 2. / np.sqrt(shape[0])
        initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    def _bias_variable(shape):
        # смещения инициализируем нулями
        initial = tf.constant(0.0, shape=shape)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    class Model:
        def __init__(self, shape, session):
            """
            :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                            от входа к выходу справа налево, например, [784, 100, 10]
            :param session: tensorflow-сессия для расчетов сети
            """
    
            self.session = session
            self._shape = shape
            depth = len(shape) - 1
            if depth < 1:
                raise ValueError("Недопустимая структура сети!")
    
            # заглушки для входных данных
            self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
            self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])
    
            # все веса слоев сети будем хранить в списке
            self.var_list = []
            for ins, outs in zip(shape[:-1], shape[1:]):
                self.var_list.append(_weight_variable([ins, outs]))
                self.var_list.append(_bias_variable([outs]))
    
            # инициализируем веса сети
            for v in self.var_list:
                session.run(v.initializer)
    
            # список для накопления важностей весов сети за текущий урок
            self._accums = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]
    
            # список для хранения важностей весов сети за все завершенные уроки
            self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]
    
            # строим вычислительный граф
            x, y, z = self.x, None, None
            for i in range(depth):
                z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
                y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
                x = y
    
            # функция стоимости
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))
    
            # точность (accuracy)
            self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))
    
            self.grads = tf.gradients(self.loss, self.var_list)
    
            self._train_step = None
    
        def open_lesson(self, learning_rate=1.0, lmbda=0.0):
            """
            Открытие урока обучения сети на отдельном датасете
            :param learning_rate: скорость обучения для SGD
            :param lmbda:         коэффициент влияния важностей - насколько сильно
                                  важности тянут веса к эталонным значениям
            """
            loss = self.loss
    
            if hasattr(self, "star_vars"):
                if lmbda != 0:
                    # добавляем к функции стоимости слагаемые-регуляризаторы
                    for v in range(len(self._shape)*2-2):
                        loss += tf.reduce_sum(tf.multiply(
                            tf.constant(lmbda / 2. * self.wb_importance[v], tf.float32),
                            tf.square(self.var_list[v] - tf.constant(self.star_vars[v], tf.float32))
                        ))
            else:
                # запоминаем текущие веса
                self.star_vars = [v.eval() for v in self.var_list]
    
            # устанавливаем шаг оптимизатора
            self._train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    
            #
            self.cur_vars = [v.eval() for v in self.var_list]
    
        def train_step_run(self, feed_dict):
            self._train_step.run(feed_dict=feed_dict)
    
            # считаем обновленные оптимизатором значения параметров сети
            new_vars = [v.eval() for v in self.var_list]
    
            # рассчитаем градиенты
            grads = self.session.run(self.grads, feed_dict=feed_dict)
    
            # аккумулируем изменение функции стоимости (производная
            # по параметру, умноженная на изменение параметра)
            for acc, grad, prev_var, new_var in zip(self._accums, grads, self.cur_vars, new_vars):
                acc -= grad * (new_var - prev_var)
    
            # мастера tensorflow вероятно смогут из train_step.run вытащить
            # уже подсчитанные градиенты по весам и сами изменения весов чтобы
            # не считать их еще раз и сэкономить вычислительные ресурсы
    
            # сохраним новые значения весов для следующей итерации
            self.cur_vars = new_vars
    
        def close_lesson(self):
            """
            Закрытие урока обучения сети. Накопление важностей весов.
            :return:
            """
    
            # рассчитаем квадраты смещений параметров
            deltas = [np.square(v.eval() - prev_v) + _epsilon
                      for v, prev_v in zip(self.var_list, self.star_vars)]
    
            # добавляем рассчитанные важности к сохраненным
            for i, a, d in zip(self.wb_importance, self._accums, deltas):
                i += a / d
    
            # запоминаем текущие эталонные веса сети после обучения
            self.star_vars = [v.eval() for v in self.var_list]
    
    
    # функция случайным образом переставляет входы одинаково для всех примеров датасета
    def permute_mnist(mnist):
        perm_inds = list(range(mnist.train.images.shape[1]))
        np.random.shuffle(perm_inds)
        mnist2 = deepcopy(mnist)
        sets = ["train", "validation", "test"]
        for set_name in sets:
            this_set = getattr(mnist2, set_name)
            this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
        return mnist2
    
    
    def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
        """
        Обучение модели
        :param model:       обучаемая модель
        :param train_set:   обучающий датасет
        :param test_sets:   список датасетов, на которых будет считаться средняя точность
        :param batch_size:  размер батча
        :param epochs:      количество эпох обучения
        :return:            средняя точность на тестовых датасетах после обучения
        """
        num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
        for idx in range(num_iters):
            train_batch = train_set.train.next_batch(batch_size)
            feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
            model.train_step_run(feed_dict=feed_dict)
            print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')
    
        print(f'\rTraining  {num_iters}/{num_iters} iterations done.')
    
        accuracy = 0.
        for t, test_set in enumerate(test_sets):
            feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
            accuracy += model.accuracy.eval(feed_dict=feed_dict)
        accuracy /= len(test_sets)
        print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
        return accuracy
    
    
    def continual_learning(net_struct, data_sets, session, lr, lmbda):
        """
        Последовательное обучение на нескольких обучающих наборах
        :param net_struct: структура сети
        :param data_sets:  список обучающих датасетов для последовательного обучения
        :param session:    tf-сессия
        :param lr:         скорость обучения
        :param lmbda:      степень влияния важностей на обучение
        :return:           список усредненных по выученным датасетам оценок
        """
        model = Model(net_struct, session)
        test_sets = []
        accuracies = []
        for data_set in data_sets:
            test_sets.append(data_set)
            model.open_lesson(lr, lmbda)
            accuracy = train_model(model, data_set, test_sets, 100, 4)
            accuracies.append(accuracy)
            model.close_lesson()
        del model
        return accuracies
    
    # считываем данные MNIST
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    # создаем tf-сессию
    sess = tf.InteractiveSession()
    
    # создаем 10 различных обучающих наборов для последовательного обучения
    mnist0 = mnist
    mnist1 = permute_mnist(mnist)
    mnist2 = permute_mnist(mnist)
    mnist3 = permute_mnist(mnist)
    mnist4 = permute_mnist(mnist)
    mnist5 = permute_mnist(mnist)
    mnist6 = permute_mnist(mnist)
    mnist7 = permute_mnist(mnist)
    mnist8 = permute_mnist(mnist)
    mnist9 = permute_mnist(mnist)
    
    start_time = datetime.datetime.now()
    
    # определим параметры обучения
    data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
    net_struct = [784, 300, 150, 10]
    lmbda = 0.1
    learning_rate = 0.02
    
    accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
    print ('Total time spent', datetime.datetime.now() - start_time)
    
    dataset_num = range(1, len(accuracies) + 1)
    
    # нарисуем график деградации средней точности на всех выученных датасетах
    plt.figure(figsize=(7, 3.5))
    plt.ylim(0.40, 1.)
    plt.xlim(1, len(accuracies))
    plt.ylabel('Total accuracy')
    plt.xlabel('Number of tasks')
    plt.plot(dataset_num, accuracies, marker=".")
    #plt.legend()
    plt.show()

    Еще примерно через полгода важность весов предложили считать по степени зависимости выходных сигналов обученной сетки от весов на каждом примере датасета, что довольно логично. Этот метод назвали Memory Aware Synapses (MAS). Метод показывает довольно хорошие результаты – лучше, чем у обоих предыдущих, но также без строгого доказательства. Метод не требует softmax-выхода сети. Полноценная версия MAS вычислительно существенно легче EWC и сравнима с SI. Однако, в отличие от SI, отделима от обучения.

    Интересно, что при использовании нейросетки с полносвязными слоями и ReLU-активациями в методе MAS важность связи между нейронами равна произведению выхода нейрона-источника на выход нейрона-приемника. И это прямо один в один обучение по Хеббу, только в хеббовом обучении так считается изменение веса связи, а в MAS это важность связи.

    Пример реализации полноценного MAS тут.
    # код рассчитан на python 3.7 и tensorflow 1.15.0
    
    import datetime
    from copy import deepcopy
    
    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    
    def _weight_variable(shape):
        # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
        stddev = 2. / np.sqrt(shape[0])
        initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    def _bias_variable(shape):
        # смещения инициализируем нулями
        initial = tf.constant(0.0, shape=shape)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    class Model:
        def __init__(self, shape, session):
            """
            :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                            от входа к выходу справа налево, например, [784, 100, 10]
            :param session: tensorflow-сессия для расчетов сети
            """
    
            self.session = session
            self._shape = shape
            depth = len(shape) - 1
            if depth < 1:
                raise ValueError("Недопустимая структура сети!")
    
            # заглушки для входных данных
            self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
            self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])
    
            # все веса слоев сети будем хранить в списке
            self.var_list = []
            for ins, outs in zip(shape[:-1], shape[1:]):
                self.var_list.append(_weight_variable([ins, outs]))
                self.var_list.append(_bias_variable([outs]))
    
            # инициализируем веса сети
            for v in self.var_list:
                session.run(v.initializer)
    
            # список для хранения важностей весов сети
            self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]
    
            # строим вычислительный граф
            x, y, z = self.x, None, None
            for i in range(depth):
                z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
                y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
                x = y
    
            # функция стоимости
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))
    
            # точность (accuracy)
            self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))
    
            # значения выходов сети
            F = tf.reduce_mean(tf.reduce_sum(tf.square(y), axis=1))
    
            # вычисляем градиенты значений по весам
            self.cur_importances = [tf.abs(grad) for grad in tf.gradients(F, self.var_list)]
    
            self.train_step = None
    
        def open_lesson(self, learning_rate=1.0, lmbda=0.0):
            """
            Открытие урока обучения сети на отдельном датасете
            :param learning_rate: скорость обучения для SGD
            :param lmbda:         коэффициент влияния важностей - насколько сильно
                                  важности тянут веса к эталонным значениям
            """
            loss = self.loss
    
            if hasattr(self, "star_vars") and lmbda != 0:
                # добавляем к функции стоимости слагаемые-регуляризаторы
                for v in range(len(self._shape)*2-2):
                    loss += tf.reduce_sum(tf.multiply(
                        tf.constant(lmbda / 2. * self.wb_importance[v], tf.float32),
                        tf.square(self.var_list[v] - tf.constant(self.star_vars[v], tf.float32))
                    ))
    
            # устанавливаем шаг оптимизатора
            self.train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    
        def close_lesson(self, closing_set=None):
            """
            Закрытие урока обучения сети на отдельном датасете. Расчет и накопление важностей весов.
            :param closing_set: датасет, на котором будут рассчитаны важности весов после обучения
            :return:
            """
    
            # рассчитываем важности весов на закрывающем датасете
            addendum = self.session.run(self.cur_importances, feed_dict={self.x: closing_set})
    
            # добавляем рассчитанные важности к сохраненным
            for i, a in zip(self.wb_importance, addendum):
                i += a
    
            # запоминаем текущие эталонные веса сети после обучения
            self.star_vars = [v.eval() for v in self.var_list]
    
    
    # функция случайным образом переставляет входы одинаково для всех примеров датасета
    def permute_mnist(mnist):
        perm_inds = list(range(mnist.train.images.shape[1]))
        np.random.shuffle(perm_inds)
        mnist2 = deepcopy(mnist)
        sets = ["train", "validation", "test"]
        for set_name in sets:
            this_set = getattr(mnist2, set_name)
            this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
        return mnist2
    
    
    def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
        """
        Обучение модели
        :param model:       обучаемая модель
        :param train_set:   обучающий датасет
        :param test_sets:   список датасетов, на которых будет считаться средняя точность
        :param batch_size:  размер батча
        :param epochs:      количество эпох обучения
        :return:            средняя точность на тестовых датасетах после обучения
        """
        num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
        for idx in range(num_iters):
            train_batch = train_set.train.next_batch(batch_size)
            feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
            model.train_step.run(feed_dict=feed_dict)
            print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')
    
        print(f'\rTraining  {num_iters}/{num_iters} iterations done.')
    
        accuracy = 0.
        for t, test_set in enumerate(test_sets):
            feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
            accuracy += model.accuracy.eval(feed_dict=feed_dict)
        accuracy /= len(test_sets)
        print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
        return accuracy
    
    
    def continual_learning(net_struct, data_sets, session, lr, lmbda):
        """
        Последовательное обучение на нескольких обучающих наборах
        :param net_struct: структура сети
        :param data_sets:  список обучающих датасетов для последовательного обучения
        :param session:    tf-сессия
        :param lr:         скорость обучения
        :param lmbda:      степень влияния важностей на обучение
        :return:           список усредненных по выученным датасетам оценок
        """
        model = Model(net_struct, session)
        test_sets = []
        accuracies = []
        for data_set in data_sets:
            test_sets.append(data_set)
            model.open_lesson(lr, lmbda)
            accuracy = train_model(model, data_set, test_sets, 100, 4)
            accuracies.append(accuracy)
            model.close_lesson(data_set.validation.images)
        del model
        return accuracies
    
    # считываем данные MNIST
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    # создаем tf-сессию
    sess = tf.InteractiveSession()
    
    # создаем 10 различных обучающих наборов для последовательного обучения
    mnist0 = mnist
    mnist1 = permute_mnist(mnist)
    mnist2 = permute_mnist(mnist)
    mnist3 = permute_mnist(mnist)
    mnist4 = permute_mnist(mnist)
    mnist5 = permute_mnist(mnist)
    mnist6 = permute_mnist(mnist)
    mnist7 = permute_mnist(mnist)
    mnist8 = permute_mnist(mnist)
    mnist9 = permute_mnist(mnist)
    
    start_time = datetime.datetime.now()
    
    # определим параметры обучения
    data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
    net_struct = [784, 300, 150, 10]
    lmbda = 15.
    learning_rate = 0.2
    
    accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
    print ('Total time spent', datetime.datetime.now() - start_time)
    
    dataset_num = range(1, len(accuracies) + 1)
    
    # нарисуем график деградации средней точности на всех выученных датасетах
    plt.figure(figsize=(7, 3.5))
    plt.ylim(0.40, 1.)
    plt.xlim(1, len(accuracies))
    plt.ylabel('Total accuracy')
    plt.xlabel('Number of tasks')
    plt.plot(dataset_num, accuracies, marker=".")
    #plt.legend()
    plt.show()

    Наконец недавно автор (этого опуса) экспериментально обнаружил, что в качестве важности веса связи можно использовать просто суммарный по модулю сигнал, прошедший через связь обученной сетки в процессе пропускания через нее обучающего набора. И такой метод EWC-signal (EWC-S) тоже работает, хоть и немногим хуже, чем предыдущие методы EWC, SI и MAS. В смысле же вычислительной сложности этот метод заслуженно можно назвать китайским – настолько он вычислительно дешев. Он не требует ни выхода сети в виде распределения/softmax, ни расчета каких-либо производных или градиентов. Кроме того, с его помощью важности можно считать прямо в процессе завершающих этапов обучения. Однако, для этого метода tensorflow и torch не построят вычисления для каждой важности за вас, и код придется писать руками.

    Пример реализации EWC-S тут.
    # код рассчитан на python 3.7 и tensorflow 1.15.0
    
    import datetime
    from copy import deepcopy
    
    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    
    def _weight_variable(shape):
        # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
        stddev = 2. / np.sqrt(shape[0])
        initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    def _bias_variable(shape):
        # смещения инициализируем нулями
        initial = tf.constant(0.0, shape=shape)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    class Model:
        def __init__(self, shape, session):
            """
            :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                            от входа к выходу справа налево, например, [784, 100, 10]
            :param session: tensorflow-сессия для расчетов сети
            """
    
            self.session = session
            self._shape = shape
            depth = len(shape) - 1
            if depth < 1:
                raise ValueError("Недопустимая структура сети!")
    
            # заглушки для входных данных
            self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
            self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])
    
            # все веса слоев сети будем хранить в списке
            self.var_list = []
            for ins, outs in zip(shape[:-1], shape[1:]):
                self.var_list.append(_weight_variable([ins, outs]))
                self.var_list.append(_bias_variable([outs]))
    
            # инициализируем веса сети
            for v in self.var_list:
                session.run(v.initializer)
    
            # список для хранения важностей весов сети
            self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]
    
            # строим вычислительный граф
            outputs = []
            x, y, z = self.x, None, None
            for i in range(depth):
                z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
                y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
                outputs.append(y)
                x = y
    
            # функция стоимости
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))
    
            # точность (accuracy)
            self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))
    
            # вычисляем суммарный по модулю прошедший сигнал
            self.signals = []
    
            os = tf.reduce_mean(tf.abs(self.x), axis=0)
            for i in range(depth):
                ws = tf.transpose(tf.multiply(os, tf.transpose(tf.abs(self.var_list[i*2]))))
                self.signals.append(ws)
                os = tf.reduce_mean(tf.abs(outputs[i]), axis=0)
                self.signals.append(os)
    
            self.train_step = None
    
        def open_lesson(self, learning_rate=1.0, lmbda=0.0):
            """
            Открытие урока обучения сети на отдельном датасете
            :param learning_rate: скорость обучения для SGD
            :param lmbda:         коэффициент влияния важностей - насколько сильно
                                  важности тянут веса к эталонным значениям
            """
            loss = self.loss
    
            if hasattr(self, "star_vars") and lmbda != 0:
                # добавляем к функции стоимости слагаемые-регуляризаторы
                for v in range(len(self._shape)*2-2):
                    loss += tf.reduce_sum(tf.multiply(
                        tf.constant(lmbda / 2. * self.wb_importance[v], tf.float32),
                        tf.square(self.var_list[v] - tf.constant(self.star_vars[v], tf.float32))
                    ))
    
            # устанавливаем шаг оптимизатора
            self.train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    
        def close_lesson(self, closing_set=None):
            """
            Закрытие урока обучения сети на отдельном датасете. Расчет и накопление важностей весов.
            :param closing_set: датасет, на котором будут рассчитаны важности весов после обучения
            :return:
            """
    
            # рассчитываем важности весов на закрывающем датасете
            addendum = self.session.run(self.signals, feed_dict={self.x: closing_set})
    
            # добавляем рассчитанные важности к сохраненным
            for i, a in zip(self.wb_importance, addendum):
                i += a
    
            # запоминаем текущие эталонные веса сети после обучения
            self.star_vars = [v.eval() for v in self.var_list]
    
    
    # функция случайным образом переставляет входы одинаково для всех примеров датасета
    def permute_mnist(mnist):
        perm_inds = list(range(mnist.train.images.shape[1]))
        np.random.shuffle(perm_inds)
        mnist2 = deepcopy(mnist)
        sets = ["train", "validation", "test"]
        for set_name in sets:
            this_set = getattr(mnist2, set_name)
            this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
        return mnist2
    
    
    def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
        """
        Обучение модели
        :param model:       обучаемая модель
        :param train_set:   обучающий датасет
        :param test_sets:   список датасетов, на которых будет считаться средняя точность
        :param batch_size:  размер батча
        :param epochs:      количество эпох обучения
        :return:            средняя точность на тестовых датасетах после обучения
        """
        num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
        for idx in range(num_iters):
            train_batch = train_set.train.next_batch(batch_size)
            feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
            model.train_step.run(feed_dict=feed_dict)
            print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')
    
        print(f'\rTraining  {num_iters}/{num_iters} iterations done.')
    
        accuracy = 0.
        for t, test_set in enumerate(test_sets):
            feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
            accuracy += model.accuracy.eval(feed_dict=feed_dict)
        accuracy /= len(test_sets)
        print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
        return accuracy
    
    
    def continual_learning(net_struct, data_sets, session, lr, lmbda):
        """
        Последовательное обучение на нескольких обучающих наборах
        :param net_struct: структура сети
        :param data_sets:  список обучающих датасетов для последовательного обучения
        :param session:    tf-сессия
        :param lr:         скорость обучения
        :param lmbda:      степень влияния важностей на обучение
        :return:           список усредненных по выученным датасетам оценок
        """
        model = Model(net_struct, session)
        test_sets = []
        accuracies = []
        for data_set in data_sets:
            test_sets.append(data_set)
            model.open_lesson(lr, lmbda)
            accuracy = train_model(model, data_set, test_sets, 100, 4)
            accuracies.append(accuracy)
            model.close_lesson(data_set.validation.images)
        del model
        return accuracies
    
    # считываем данные MNIST
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    # создаем tf-сессию
    sess = tf.InteractiveSession()
    
    # создаем 10 различных обучающих наборов для последовательного обучения
    mnist0 = mnist
    mnist1 = permute_mnist(mnist)
    mnist2 = permute_mnist(mnist)
    mnist3 = permute_mnist(mnist)
    mnist4 = permute_mnist(mnist)
    mnist5 = permute_mnist(mnist)
    mnist6 = permute_mnist(mnist)
    mnist7 = permute_mnist(mnist)
    mnist8 = permute_mnist(mnist)
    mnist9 = permute_mnist(mnist)
    
    start_time = datetime.datetime.now()
    
    # определим параметры обучения
    data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
    net_struct = [784, 300, 150, 10]
    lmbda = 1.
    learning_rate = 0.2
    
    accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
    print ('Total time spent', datetime.datetime.now() - start_time)
    
    dataset_num = range(1, len(accuracies) + 1)
    
    # нарисуем график деградации средней точности на всех выученных датасетах
    plt.figure(figsize=(7, 3.5))
    plt.ylim(0.40, 1.)
    plt.xlim(1, len(accuracies))
    plt.ylabel('Total accuracy')
    plt.xlabel('Number of tasks')
    plt.plot(dataset_num, accuracies, marker=".")
    #plt.legend()
    plt.show()

    Также было экспериментально обнаружено, что, для сохранения выученных навыков нейросетки, можно не только привязывать веса к эталонным на резиночках разной упругости (где упругость пропорциональна важности связи), но и просто замедлять скорость изменения (то есть градиенты) весов, как будто для них увеличивается сила трения при сдвиге пропорционально важности веса связи. Это опять слегка ухудшает способность метода сохранять навыки при последовательном обучении, но сильно экономит память, потому что не надо хранить эталонные веса. Такой метод называется Weight Velocity Attenuation (WVA).

    Пример реализации WVA на базе суммарного по модулю сигнала тут.
    # код рассчитан на python 3.7 и tensorflow 1.15.0
    
    import datetime
    from copy import deepcopy
    
    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    from tensorflow.python.training.optimizer import Optimizer
    
    
    def _weight_variable(shape):
        # для весов полносвязного слоя инициализируем значения по Каймину Ге (Хе)
        stddev = 2. / np.sqrt(shape[0])
        initial = tf.random.truncated_normal(shape=shape, mean=0.0, stddev=stddev)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    def _bias_variable(shape):
        # смещения инициализируем нулями
        initial = tf.constant(0.0, shape=shape)
        return tf.Variable(initial, dtype=tf.float32)
    
    
    class _WVA_SGD(tf.train.GradientDescentOptimizer):
    
        def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
            super(_WVA_SGD, self).__init__(learning_rate, use_locking, name)
    
        def minimize(self, loss, global_step=None, var_list=None,
                     gate_gradients=Optimizer.GATE_OP, aggregation_method=None,
                     colocate_gradients_with_ops=False, name=None,
                     grad_loss=None, impacts=None):
            """ comments """
            grads_and_vars = self.compute_gradients(
                loss, var_list=var_list, gate_gradients=gate_gradients,
                aggregation_method=aggregation_method,
                colocate_gradients_with_ops=colocate_gradients_with_ops,
                grad_loss=grad_loss)
    
            vars_with_grad = [v for g, v in grads_and_vars if g is not None]
            if not vars_with_grad:
                raise ValueError(
                    "No gradients provided for any variable, check your graph for ops"
                    " that do not support gradients, between variables %s and loss %s." %
                    ([str(v) for _, v in grads_and_vars], loss))
    
            if impacts is None:
                processed_grads_and_vars = grads_and_vars
            else:
                impact = iter(impacts)
                processed_grads_and_vars = []
                for g, v in grads_and_vars:
                    if g is None:
                        processed_grads_and_vars.append((g, v))
                    else:
                        processed_grads_and_vars.append((tf.multiply(g, next(impact)), v))
    
            return self.apply_gradients(processed_grads_and_vars,
                                        global_step=global_step,
                                        name=name)
    
    
    class Model:
        def __init__(self, shape, session):
            """
            :param shape:   структура сети - список из чисел нейронов в каждом слое сети
                            от входа к выходу справа налево, например, [784, 100, 10]
            :param session: tensorflow-сессия для расчетов сети
            """
    
            self.session = session
            self._shape = shape
            depth = len(shape) - 1
            if depth < 1:
                raise ValueError("Недопустимая структура сети!")
    
            # заглушки для входных данных
            self.x = tf.placeholder(tf.float32, shape=[None, shape[0]])
            self.labels = tf.placeholder(tf.float32, shape=[None, shape[-1]])
    
            # все веса слоев сети будем хранить в списке
            self.var_list = []
            for ins, outs in zip(shape[:-1], shape[1:]):
                self.var_list.append(_weight_variable([ins, outs]))
                self.var_list.append(_bias_variable([outs]))
    
            # инициализируем веса сети
            for v in self.var_list:
                session.run(v.initializer)
    
            # список для хранения важностей весов сети
            self.wb_importance = [np.zeros(v.shape, dtype=np.float32) for v in self.var_list]
    
            # строим вычислительный граф
            outputs = []
            x, y, z = self.x, None, None
            for i in range(depth):
                z = tf.matmul(x, self.var_list[i * 2]) + self.var_list[i * 2 + 1]
                y = tf.nn.softmax(z) if i == depth-1 else tf.nn.leaky_relu(z)
                outputs.append(y)
                x = y
    
            # функция стоимости
            self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=z, labels=self.labels))
    
            # точность (accuracy)
            self.correct_preds = tf.equal(tf.argmax(z, axis=1), tf.argmax(self.labels, axis=1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_preds, tf.float32))
    
            # вычисляем суммарный по модулю прошедший сигнал
            self.signals = []
    
            os = tf.reduce_mean(tf.abs(self.x), axis=0)
            for i in range(depth):
                ws = tf.transpose(tf.multiply(os, tf.transpose(tf.abs(self.var_list[i*2]))))
                self.signals.append(ws)
                os = tf.reduce_mean(tf.abs(outputs[i]), axis=0)
                self.signals.append(os)
    
            self.train_step = None
    
        def open_lesson(self, learning_rate=1.0, lmbda=0.0):
            """
            Открытие урока обучения сети на отдельном датасете
            :param learning_rate: скорость обучения для SGD
            :param lmbda:         коэффициент влияния важностей - насколько сильно
                                  важности тянут веса к эталонным значениям
            """
            impacts = [tf.constant(1. / (1. + lmbda * v)) for v in self.wb_importance]
    
            # устанавливаем шаг оптимизатора
            self.train_step = _WVA_SGD(learning_rate).minimize(self.loss, impacts=impacts)
    
        def close_lesson(self, closing_set=None):
            """
            Закрытие урока обучения сети на отдельном датасете. Расчет и накопление важностей весов.
            :param closing_set: датасет, на котором будут рассчитаны важности весов после обучения
            :return:
            """
    
            # рассчитываем важности весов на закрывающем датасете
            addendum = self.session.run(self.signals, feed_dict={self.x: closing_set})
    
            # добавляем рассчитанные важности к сохраненным
            for i, a in zip(self.wb_importance, addendum):
                i += a
    
    
    # функция случайным образом переставляет входы одинаково для всех примеров датасета
    def permute_mnist(mnist):
        perm_inds = list(range(mnist.train.images.shape[1]))
        np.random.shuffle(perm_inds)
        mnist2 = deepcopy(mnist)
        sets = ["train", "validation", "test"]
        for set_name in sets:
            this_set = getattr(mnist2, set_name)
            this_set._images = np.transpose(np.array([this_set.images[:, c] for c in perm_inds]))
        return mnist2
    
    
    def train_model(model, train_set, test_sets, batch_size=100, epochs=1):
        """
        Обучение модели
        :param model:       обучаемая модель
        :param train_set:   обучающий датасет
        :param test_sets:   список датасетов, на которых будет считаться средняя точность
        :param batch_size:  размер батча
        :param epochs:      количество эпох обучения
        :return:            средняя точность на тестовых датасетах после обучения
        """
        num_iters = int(np.ceil(len(train_set.train.labels) * epochs / batch_size))
        for idx in range(num_iters):
            train_batch = train_set.train.next_batch(batch_size)
            feed_dict = {model.x: train_batch[0], model.labels: train_batch[1]}
            model.train_step.run(feed_dict=feed_dict)
            print(f'\rTraining  {idx + 1}/{num_iters} done.', end='')
    
        print(f'\rTraining  {num_iters}/{num_iters} iterations done.')
    
        accuracy = 0.
        for t, test_set in enumerate(test_sets):
            feed_dict = {model.x: test_set.test.images, model.labels: test_set.test.labels}
            accuracy += model.accuracy.eval(feed_dict=feed_dict)
        accuracy /= len(test_sets)
        print(f'Evaluating on {len(test_sets)} test sets done. Accuracy {accuracy}')
        return accuracy
    
    
    def continual_learning(net_struct, data_sets, session, lr, lmbda):
        """
        Последовательное обучение на нескольких обучающих наборах
        :param net_struct: структура сети
        :param data_sets:  список обучающих датасетов для последовательного обучения
        :param session:    tf-сессия
        :param lr:         скорость обучения
        :param lmbda:      степень влияния важностей на обучение
        :return:           список усредненных по выученным датасетам оценок
        """
        model = Model(net_struct, session)
        test_sets = []
        accuracies = []
        for data_set in data_sets:
            test_sets.append(data_set)
            model.open_lesson(lr, lmbda)
            accuracy = train_model(model, data_set, test_sets, 100, 4)
            accuracies.append(accuracy)
            model.close_lesson(data_set.validation.images)
        del model
        return accuracies
    
    # считываем данные MNIST
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    # создаем tf-сессию
    sess = tf.InteractiveSession()
    
    # создаем 10 различных обучающих наборов для последовательного обучения
    mnist0 = mnist
    mnist1 = permute_mnist(mnist)
    mnist2 = permute_mnist(mnist)
    mnist3 = permute_mnist(mnist)
    mnist4 = permute_mnist(mnist)
    mnist5 = permute_mnist(mnist)
    mnist6 = permute_mnist(mnist)
    mnist7 = permute_mnist(mnist)
    mnist8 = permute_mnist(mnist)
    mnist9 = permute_mnist(mnist)
    
    start_time = datetime.datetime.now()
    
    # определим параметры обучения
    data_sets = [mnist0, mnist1, mnist2, mnist3, mnist4, mnist5, mnist6, mnist7, mnist8, mnist9]
    net_struct = [784, 300, 150, 10]
    lmbda = 250.
    learning_rate = 0.2
    
    accuracies = continual_learning(net_struct, data_sets, sess, lr=learning_rate, lmbda=lmbda)
    print ('Total time spent', datetime.datetime.now() - start_time)
    
    dataset_num = range(1, len(accuracies) + 1)
    
    # нарисуем график деградации средней точности на всех выученных датасетах
    plt.figure(figsize=(7, 3.5))
    plt.ylim(0.40, 1.)
    plt.xlim(1, len(accuracies))
    plt.ylabel('Total accuracy')
    plt.xlabel('Number of tasks')
    plt.plot(dataset_num, accuracies, marker=".")
    #plt.legend()
    plt.show()

    Резюмируем. Если вы планируете учить вашу нейросетку последовательно и разному, и у вас мало времени и памяти, то стоит использовать WVA. Если времени мало, но памяти завались, то стоит посмотреть на EWC-S. Если времени вагон, а памяти мало, то стоит важности весов рассчитывать как в MAS, а использовать как в WVA, то есть сделать гибрид WVA-MAS. Если есть и время, и память, и требуется наилучшее сохранение навыков без компромиссов, и код писать ну очень лениво, то стоит использовать полноценный MAS.

    P.S. Подозреваю, что MAS будет выбираться чаще всего именно по последней причине…

    P.P.S. У всех перечисленных методов есть одна тонкость – они работают только если для каждого выхода сети в каждом датасете для последовательного обучения есть примеры его (выход) активирующие. Если датасет содержит примеры, активирующие только часть выходов, нужно применять специальные трюки (считать функцию потерь только на активируемых в датасете выходах). Подробности можно посмотреть в статье про Synaptic Intelligence – см. Split MNIST.

    Средняя зарплата в IT

    120 000 ₽/мес.
    Средняя зарплата по всем IT-специализациям на основании 3 391 анкеты, за 1-ое пол. 2021 года Узнать свою зарплату
    Реклама
    AdBlock похитил этот баннер, но баннеры не зубы — отрастут

    Подробнее

    Комментарии 8

      +1

      В закладки. Не забыть бы прочитать.

        0

        Спасибо! 4 года назад читал оригинальную статью, но потом не смог вспомнить ни авторов, ни название метода — так бы и не нашёл ее, если бы не Вы! С Новым Годом)

          0
          Решением забывшивости нейросеток является динамическое наращивание обучающей способности сети. Вы даете ей новые данные а сеть растет по мере их обработки с новыми данными растут возможности самой сети.
            +1

            А чем это лучше запуска новой параллельной сети?

            0
            Можно такой вопрос: модуль произведения веса на производную лосса по этому весу (как критерий важности, используемый в методе прунинга, предложенном ребятами из Nvidia) кем-нибудь использовался в данной задаче? (Может, прячится за какой-то из аббревиатур, что Вы привели — я пока в первоисточники не ходил)
              0
              С Новым Годом! То, что вы говорите похоже на метод SI — там тоже производная лосса используется в важности. Или на EWC-S — там вес умножается на входной сигнал чтоб важность получить. Вообще у меня есть подозрение, что все эти методы примерно одинаковый результат дают — надо бы посчитать корреляцию важностей, полученных разными методами на одной и той же обученной сетке. Но пока руки не дошли.
                0
                Да, хорошая мысль) Если дойдут руки, поделитесь результатом?
                  0
                  Конечно

            Только полноправные пользователи могут оставлять комментарии. Войдите, пожалуйста.

            Самое читаемое