Обновить
4K+
2
Александра Насекайло@AlexaN123

Пользователь

8
Рейтинг
Отправить сообщение

Отличное замечание - действительно, это стоит учитывать. Пример упрощен.

Колебания loss могут указывать на слишком большие шаги градиентного спуска. В таких случаях целесообразно уменьшить скорость обучения (learning rate). 

  1. Запоминаем лучший loss и счетчик эпох без улучшения.

  2. Вычисляем loss текущей эпохи. Если loss улучшился хотя бы на LOSS_MIN_DELTA относительно best_loss, то обновляем best_loss и сбрасываем epochs_without_improvement. Иначе увеличиваем счетчик epochs_without_improvement.

  3. Если количество эпох без улучшения достигло допустимого, то уменьшаем LR.

LR_DECAY = 0.5           # во сколько раз уменьшать LR
PATIENCE = 5             # сколько проверок ждать улучшения
MIN_LR = 1e-6            # минимальный LR
LOSS_MIN_DELTA = 1e-6    # минимальное улучшение, чтобы шум не считался прогрессом

def adjust_learning_rate_on_plateau(loss, best_loss, learning_rate, epochs_without_improvement):
    if loss < best_loss - LOSS_MIN_DELTA:
        return loss, 0, learning_rate

    epochs_without_improvement += 1

    if epochs_without_improvement >= PATIENCE:
        learning_rate = max(learning_rate * LR_DECAY, MIN_LR)
        epochs_without_improvement = 0
        print(f"Learning rate reduced to {learning_rate:.6f}")

    return best_loss, epochs_without_improvement, learning_rate

class FeedForwardNeuralNetwork:

  # Код класса
  # .................

    def train(self, input_data, target_output, epochs, learning_rate):
        best_loss = float('inf')
        epochs_without_improvement = 0

        for epoch in range(epochs):
            hidden_activation, output_activation = self.forward(input_data)
            self.backward(input_data, target_output, learning_rate, hidden_activation, output_activation)

            loss = np.mean(np.square(target_output - output_activation))

            best_loss, epochs_without_improvement, learning_rate = adjust_learning_rate_on_plateau(
                loss, 
                best_loss, 
                learning_rate, 
                epochs_without_improvement
            )

            if epoch % 100 == 0:
                print(f'Epoch {epoch}, Loss: {loss:.4f}')

Информация

В рейтинге
889-й
Зарегистрирован
Активность

Специализация

Бэкенд разработчик