От диплома до продакшена: Как я создавал архитектуры ИИ-проектов

Часть 4: Обучение и валидация модели — 250 эпох, 94.55% точности и борьба с переобучением

В этой части я расскажу о самом критическом этапе — обучении модели. Здесь 250 эпох отделяют работающую модель от неработающей, а правильная настройка гиперпараметров определяет успех всего проекта.


Содержание


Введение: Почему обучение — это не просто «нажать кнопку»

Многие считают, что обучение нейросети — это просто загрузить данные и нажать «старт». На практике всё сложнее. Неправильная настройка гиперпараметров, отсутствие контроля переобучения или неправильная подготовка данных могут свести на нет всю предыдущую работу.

В этой части я подробно расскажу о том, как проходило обучение моей модели для распознавания голосовых команд, с какими проблемами я столкнулся и как их решал.


Подготовка данных к обучению

Структура датасета

Перед обучением необходимо было подготовить данные в формате, пригодном для подачи в нейросеть. Исходные аудиофайлы были обработаны и преобразованы в числовые признаки.

Параметр

Значение

Описание

Общее количество примеров

273

После отбора и структурирования

Классы команд

4

Комната, Дверь, Камера, Фон

Примеров на класс

~68

В среднем на каждый класс

Признаков на пример

21

9 SSR + 9 CHZ + 3 MFC

Проверочная выборка

20%

55 примеров для валидации

Нормализация данных

Перед обучением все признаки были нормализованы для исключения резких разниц в расчётах:

from sklearn import preprocessing as prep

# Нормализация данных
xTrainSSR = prep.normalize(xTrainSSR)
xTrainCHZ = prep.normalize(xTrainCHZ)
xTrainMFC = prep.normalize(xTrainMFC)

Почему нормализация важна:

Вижу что наборы значений сильно отличаются между собой, там где цифровые значения низкие, будут менее значимы при объединении. Можно использовать вариант с нормализацией данных, а можно написать нейронку с несколькими входами.

Нормализация обеспечивает:

  • Равный вклад всех признаков в обучение

  • Стабильную сходимость градиентного спуска

  • Защиту от доминирования признаков с большими значениями


Гиперпараметры и их обоснование

Таблица гиперпараметров

Параметр

Значение

Обоснование

epochs

250

Экспериментально подобрано. 100 эпох — недостаточная сходимость (val_accuracy ~87%), 500 эпох — начинается переобучение (val_accuracy падает до ~91%)

batch_size

10

Маленький датасет (273 примера). Меньший батч даёт более стабильные градиенты

learning_rate

1e-4

Стандарт для Adam. Слишком высокий — нестабильность, слишком низкий — медленная сходимость

validation_split

0.2

Стандартное соотношение. 20% данных достаточно для контроля переобучения

optimizer

Adam

Адаптивный оптимизатор. Лучше SGD для небольших датасетов

loss

categorical_crossentropy

Для многоклассовой классификации (4 класса)

Компиляция модели

model.compile(optimizer=Adam(1e-4), 
              loss='categorical_crossentropy',
              metrics=['accuracy'])

Процесс обучения: ключевые эпохи

Полный лог обучения содержит 250 эпох. Для наглядности привожу ключевые эпохи, которые показывают динамику обучения:

Эпоха

loss

val_loss

accuracy

val_accuracy

1

1.82

1.26

21%

94.55%

20

0.55

1.34

79%

27%

77

0.38

0.61

83%

94.55%

150

0.28

1.01

87%

85%

247

0.16

0.91

94%

94.55%

Наблюдения на ранних эпохах

Эпоха 1: Точность на обучении 21.10%, на валидации 94.55%

Высокая валидация на первой эпохе — случайное совпадение. Модель ещё не обучилась, но случайная инициализация весов дала хороший результат на маленькой валидационной выборке.

Эпоха 2-3: Точность на валидации упала до 0%

Модель начала переобучаться на обучающей выборке. Это нормальное явление на ранних этапах обучения.

Эпоха 20: Стабилизация на уровне 78.90% (train) и 27.27% (val)

Начало сходимости модели. После этой эпохи точность на валидации начинает расти.

Полные логи обучения (фрагмент)

Epoch 1/250
22/22 [==============================] - 3s 33ms/step 
- loss: 1.8197 - accuracy: 0.2110 
- val_loss: 1.2615 - val_accuracy: 0.9455

.......

Epoch 77/250
22/22 [==============================] - 0s 13ms/step 
- loss: 0.3819 - accuracy: 0.8257 
- val_loss: 0.6105 - val_accuracy: 0.9455

.......

Epoch 247/250
22/22 [==============================] - 0s 16ms/step 
- loss: 0.1618 - accuracy: 0.9404 
- val_loss: 0.9144 - val_accuracy: 0.9455

Проблемы и решения

Проблема 1: Нестабильность на ранних эпохах

Проблема: Точность на валидации падала до 0% на 2-3 эпохах

Причина:

  • Маленький размер батча (10 примеров)

  • Небольшой размер датасета (273 примера)

  • Случайная инициализация весов

Решение:

  • Продолжение обучения до стабилизации

  • Dropout (0.2-0.3) для регуляризации

  • BatchNormalization для стабилизации градиентов

Результат: После эпохи 20 точность на валидации начала стабильно расти.

Проблема 2: Переобучение

Проблема: Разница между точностью на обучении и валидации

Причина:

  • Модель «запоминает» обучающие данные

  • Недостаточная регуляризация

Решение:

  • Dropout (0.2 на ранних слоях, 0.3 на поздних)

  • BatchNormalization после каждого Conv1D и Dense слоя

  • 20% валидационная выборка для контроля

Результат:

Метрика

Значение

Интерпретация

Train Accuracy

94.04%

Модель хорошо выучила данные

Val Accuracy

94.55%

Хорошая обобщающая способность

Разница

0.51%

Признак отсутствия значительного переобучения

Разница между точностью на обучении и валидации менее 1% — признак отсутствия значительного переобучения.

Проблема 3: Маленький датасет

Проблема: Всего 273 примера для обучения

Причина:

  • Ограниченное время на сбор данных

  • Ограниченные ресурсы для разметки

Решение:

  • Тщательный отбор качественных примеров

  • Нормализация данных для стабильности обучения

  • Регуляризация (Dropout + BatchNorm)

Что можно улучшить:

# Data Augmentation для аудио
def add_noise(data, noise_level=0.005):
    noise = np.random.randn(len(data)) * noise_level
    return data + noise

def change_speed(data, speed_factor=1.1):
    return librosa.effects.time_stretch(data, rate=speed_factor)

def change_pitch(data, pitch_factor=0.7):
    return librosa.effects.pitch_shift(data, sr=22050, n_steps=pitch_factor)

Ожидаемый эффект: Увеличение датасета в 5-10 раз, улучшение обобщающей способности.


Результаты и выводы

Финальные метрики

Метрика

Значение

Комментарий

Train Accuracy

94.04%

Модель хорошо выучила данные

Val Accuracy

94.55%

Хорошая обобщающая способность

Train Loss

0.1618

Низкая ошибка на обучении

Val Loss

0.9144

Приемлемая ошибка на валидации

Эпох обучения

250

Достаточно для сходимости

Размер батча

10

Малый размер из-за ограничений Colab

Ключевые выводы

  1. Стабильность важнее скорости — 250 эпох обеспечили стабильную сходимость

  2. Регуляризация работает — Dropout + BatchNormalization предотвратили переобучение

  3. Валидация обязательна — 20% на валидацию позволили контролировать качество

  4. Есть куда расти — Data Augmentation, Transfer Learning, Early Stopping могут улучшить результат

Что можно улучшить в будущем

Метод

Ожидаемый эффект

Сложность реализации

Data Augmentation

+3-5% точности

Низкая

Transfer Learning

+5-10% точности

Средняя

Early Stopping

Экономия времени

Низкая

Квантование

Уменьшение размера модели в 4 раза

Средняя


Что будет в следующей части?

Часть 5: Интеграция с устройствами «Умного дома»

В следующей части я расскажу о том, как обученная модель интегрируется с реальными устройствами умного дома:

  • Протоколы связи (Wi-Fi, Bluetooth, Zigbee)

  • Управление освещением

  • Управление дверями и замками

  • Управление камерами наблюдения

  • Интеграция с бытовой техникой

  • Пример: GSM-звонок и шлагбаум

  • Пример: Детекция госномера

  • Масштабируемость архитектуры


Источники и ресурсы

Исходный код проекта

Файл

Описание

Ссылка

Jupyter Notebook

Код модели и обучение

[SmartHome v4.6.ipynb](SmartHome v4.6.ipynb)

GitHub

Репозиторий проекта

github.com/AlekseyVB/SmartHome

Библиотеки и инструменты

# Основные библиотеки для работы с аудио
import librosa              # Обработка аудио
import librosa.display      # Визуализация аудио

# Библиотеки для нейросетей
import tensorflow as tf     # Фреймворк для глубокого обучения
from tensorflow.keras import layers, models

# Утилиты
from sklearn.preprocessing import StandardScaler  # Нормализация

Вопросы для обсуждения

  1. Какие методы регуляризации вы используете в своих проектах?

  2. Как вы боретесь с переобучением на маленьких датасетах?

  3. Используете ли вы Transfer Learning для аудио-задач?

Делитесь в комментариях! 👇