1. Проблема

Когда мы обучаем модели машинного обучения, почти всегда возникает один и тот же вопрос:

Что именно происходит во время обучения?

Обычно мы смотрим на графики метрик и пытаемся вручную интерпретировать происходящее:

  • Модель недообучена

  • Модель переобучена

  • Имбаланс датасета.

  • Сильно шумные данные.

Можно посмотреть на learning curves и понять, что происходит:

График с рутиной ML)
График с рутиной ML)

Но этот анализ почти всегда выполняется вручную или с помощью простейших эвристических правил. А ведь сколько времени, сил и нервов можно было бы сэкономить, если обучить до 100 эпохи а не до 500 (см картинка выше) :-(

Но можно задать интересный вопрос:

А можно ли автоматически определить состояние обучения модели?

2. Идея

А что если научить отдельную модель, которая будет автоматически определять состояние обучения?

Обучение модели → learning curves → признаки → мета-классификатор → остановка в иде��льный момент (ну или в начале, если будет дисбаланс или данные ужасны)

То есть вместо ручного анализа мы обучаем модель, которая делает это автоматически. Но насколько это эффективно, на проде, интерпретируемы ли эти результаты для разных типов задач и т. д. это вопрос, на который я планирую получить ответ.

3. Генерация датасета

Чтобы обучить такой классификатор, нужен датасет с различными сценариями обучения.

Я решил сгенерировать его программно.

В качестве базового датасета использовался MNIST — классический набор изображений рукописных цифр.

Эксперименты на MNIST для генерации датасета проводились с разными моделями, среди них:

  • logistic regression

  • небольшой MLP

  • большой MLP

  • маленькая CNN

  • большая CNN

Для каждого эксперимента варьировались параметры:

  • размер обучающей выборки

  • случайный seed

  • наличие дисбаланса классов

  • тип сдвига данных

По итогу я обучил 270 моделей и посмотрел их после 1, 5, 6, 11,16,21,26 эпох. По каждой записи были сохранены:

Столбец

Тип

Описание

model

str

Название модели, использованной для обучения (logreg, mlp_small, mlp_large, cnn_small, cnn_large).

train_size

int

Размер выборки для обучения в конкретном эксперименте.

seed

int

Значение random seed для воспроизводимости случайной выборки.

imbalance

bool

Флаг, указывающий, использовался ли искусственный дисбаланс классов (True) или нет (False).

shift_type

str

Тип сдвига данных на тестовой выборке (none, noise, invert).

train_acc

float

Точность модели на тренировочной выборке после текущей эпохи.

val_acc

float

Точность модели на валидационной выборке после текущей эпохи.

test_acc

float

Точность модели на тестовой выборке (с учетом возможного сдвига данных).

gap

float

Разница между тренировочной и валидационной точностью (train_acc - val_acc). Используется для диагностики переобучения.

epochs

int

Количество эпох обучения (для функции train_and_evaluate) — либо номер эпохи в train_with_history.

val_curve

list of list

История точности на валидационной выборке по эпохам до текущей.

epoch

int

Номер текущей эпохи обучения (используется при пошаговом train_with_history).

underfitting

int (0/1)

Диагностический флаг: модель недообучена, если train_acc < 0.7.

overfitting

int (0/1)

Диагностический флаг: модель переобучена, если gap > 0.15.

dataset_shift

int (0/1)

Диагностический флаг: есть смещение тестовых данных, если val_acc - test_acc > 0.15.

С мериками получилось сложно, нельзя точно сказать, что при val_acс 0.9 нет переобучения, однако, в рамках работы я просто тестил всё на test_dataset и ставил метки по нему. правила для меток:

def diagnose(metrics):

    return {
        "underfitting": int(metrics["train_acc"] < 0.7),
        "overfitting": int(metrics["gap"] > 0.15),
        "dataset_shift": int(metrics["val_acc"] - metrics["test_acc"] > 0.15)
    }

В итоге в датасете я получил:

Кол-во меток в  датасете
Кол-во меток в датасете

Касаемо качества датасета, меня устаивает, есть как и ужасные модели, так и неплохие, acc достиг 0.9.

4. Признаки для мета-классификатора

Одним из самых интересных источников информации является форма learning curve. Я вытащил из него много признаков, все признаки на которых я делал метрики (подразумеваются как недоступные я удалил из обучения)

df["curve_start"] = df["val_curve"].apply(lambda x: x[0])
df["curve_mid"] = df["val_curve"].apply(lambda x: x[len(x)//2])
df["curve_end"] = df["val_curve"].apply(lambda x: x[-1])

df["curve_growth"] = df["curve_end"] - df["curve_start"]
df["curve_stability"] = df["val_curve"].apply(np.std)

5. Обучение моделей

Для классификации текущего состояния модели было протестированано несколько алгоритмов:

  • Random Forest

  • XGBoost

  • Logistic Regression

  • ансамбль моделей

Поскольку задача имеет несколько независимых меток, использовался MultiOutputClassifier.

rf = RandomForestClassifier(
    n_estimators=200,
    random_state=42
)

model = MultiOutputClassifier(rf)

model.fit(X_train, y_train)

pred = model.predict(X_test)

Итоги после обучения:

          precision    recall  f1-score   support

       0       0.94      0.89      0.91       177
       1       0.96      0.97      0.96       593
       2       0.97      0.88      0.92       233
       3       0.75      0.73      0.74       419
micro avg      0.90      0.87      0.89      1422
macro avg      0.90      0.87      0.88      1422
weighted avg   0.90      0.87      0.88      1422
samples avg    0.86      0.84      0.83      1422
Важность признаков в случайном лесе.
Важность признаков в случайном лесе.

Лучшие результаты показал Random Forest.

Он хорошо определял:

  • underfitting

  • dataset shift

Логистическая регрессия показала более низкое качество — что ожидаемо, так как она является линейным классификатором. Ансамбль моделей практически не улучшил результат.

6. Результаты

Этот подход можно использовать в ML-pipeline.

Это может позволить:

  • автоматически выявлять переобучение

  • обнаруживать проблемы с данными

  • останавливать обучение раньше

  • экономить вычислительные ресурсы

На этом всё, если кому-то интересно потрогать руками напишите в комментах, на гитхабе надо убраться, буду рад критике, есть что добавить и так, планирую дописать 2 часть. Вот мой гитхаб, в целом там иногда есть что-то интересное.

Спасибо, всем хорошего дня.