Как стать автором
Поиск
Написать публикацию
Обновить
606.78
OTUS
Развиваем технологии, обучая их создателей

fit() для новичков

Уровень сложностиПростой
Время на прочтение5 мин
Количество просмотров2K

Привет, Хабр!

Эта статья для тех, кто только‑только погружается в машинное обучение и ещё не до конца понимает, что скрывается за интересным вызовом model.fit(). Вы, возможно, уже настраивали ноутбуки, пробовали разные датасеты и, может, даже словили пару неожиданных ошибок — и это нормально.


Зачем копать глубже за fit()

На старте может казаться, что достаточно написать:

model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)

— и всё заработает. Но стоит проекту вырасти, можно столкнуться с подвохами:

  • Неожиданные NotFittedError при predict()

  • Упавшая память на больших выборках

  • Странное поведение при дообучении

  • Сложности в интеграции конвейеров и трансформеров

Зная почему так происходит, можно оптимизировать время обучения, контролировать ошибки и выстроить гибкий пайплайн. Сначала разберём, как fit() проверяет наши данные, а затем пройдёмся по остальным этапам.

Валидация данных

Сразу после вызова fit(X, y) у модели запускается внутренняя проверка — validatedata. И да, это не просто формальность:

  • Преобразование: если вы передали pandas.DataFrame, он конвертируется в numpy.ndarray.

  • Сверка размеров: число строк в X должно совпадать с длиной y.

  • Обработка пропусков: np.nan и разреженные форматы распознаются и обрабатываются.

  • Приведение типов: целочисленные и низкоточные данные автоматически приводятся к float64, чтобы алгоритмы могли нормально считать градиенты.

Представьте, вы случайно передали 100 строк признаков, а меток всего 99 — без этой проверки обучение просто рухнет где‑то в глубинах C‑библиотек с непонятным «segmentation fault». А так вы получите понятную ошибку и сможете сразу исправить проблему.

Куда же уходят ваши настройки и гиперпараметры?

Сохраняем гиперпараметры: BaseEstimator в действии

Все алгоритмы в scikit-learn наследуют BaseEstimator. При создании объекта, допустим:

model = LogisticRegression(C=0.1, penalty='l2')

— параметры C и penalty аккуратно ложатся в атрибут dict. Благодаря этому:

Модель можно клонировать clone(model) с теми же настройками. GridSearchCV переберёт все комбинации гиперпараметров. При сериализации joblib.dump можно быть уверенным, что вы не потеряете ни одного значения.

С настройками разобрались, дальше — где и как происходит само обучение.

Собственно обучение: что таится в _fit()

Метод fit() передаёт дело приватному _fit(), где и идёт основная математика.

  • Линейная регрессия решает нормальные уравнения:

    X_aug = np.hstack([np.ones((n,1)), X])  # добавляем единичный столбец
    theta = np.linalg.solve(X_aug.T.dot(X_aug), X_aug.T.dot(y))
    self.intercept_, self.coef_ = theta[0], theta[1:]

    Решаем систему уравнений, только на миллионах строк.

  • Стохастический градиентный спуск (SGDClassifier):

    w = np.zeros(n_features)
    for epoch in range(max_iter):
        for Xi, yi in shuffle(X, y):
            grad = compute_gradient(w, Xi, yi)
            w -= eta * grad

    Здесь важно правильно подобрать learning_rate (eta): слишком большой — не схватится за минимум, слишком маленький — будет ползти вечность.

  • Деревья решений рекурсивно разбивают выборку по лучшим признакам, чтобы минимизировать энтропию или MSE — об этом можно писать отдельный роман, но суть в том, что каждый сплит — это отдельное вычисление статистики.

Когда вычисления завершены, результаты куда‑то записываются — давайте взглянем, куда именно.

Куда деваются результаты — атрибуты с подчёркиванием

После обучения у модели появляются атрибуты, оканчивающиеся на _:

  • coef_, intercept_ у линейных моделей

  • classes_ у классификаторов

  • feature_importances_ у ансамблей деревьев

  • статистические буферы (n_iter_, история градиентов и пр.)

Чтобы убедиться, что модель обучена, я всегда использую:

from sklearn.utils.validation import check_is_fitted
check_is_fitted(model)

— и если что‑то упущено, получите дружелюбный NotFittedError.

Различия между fit(), transform() и fit_transform()

  • fit(X, y): готовит модель к работе — валидация и вычисление внутренних параметров.

  • transform(X): применяет уже посчитанные параметры для преобразования данных (нормализация, PCA и др.).

  • fit_transform(X): объединение первых двух шагов для трансформеров, экономя одну итерацию по данным.

Например:

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)  # за один проход
X_new = scaler.transform(X_test)    # применяем те же параметры

Иногда модели умеют оптимизировать fit_transform(), объединяя вычисления в один цикл.

Но что если объём данных огромен или они приходят потоком? Тогда без partial_fit() не обойтись.

Онлайн-обучение и partial_fit()

Когда данные не помещаются в память или приходят постоянно, используем partial_fit():

from sklearn.linear_model import SGDClassifier
clf = SGDClassifier(max_iter=1, tol=None)
# Первый батч: нужно явно указать все классы
clf.partial_fit(X_batch1, y_batch1, classes=np.unique(y_full))
for Xb, yb in get_next_batches():
    clf.partial_fit(Xb, yb)

Каждый вызов берёт новую порцию данных и продолжает обучение с текущих весов, не забывая историю градиентов.

А если хочется добавить новые деревья в лес без пересчёта старых? Тогда выручит warm_start.

warm_start:

У многих ансамблей RandomForest, GradientBoosting есть опция warm_start. Вот как она работает на практике:

  1. Создаём лес из 50 деревьев:

    rf = RandomForestClassifier(n_estimators=50, warm_start=True)
    rf.fit(X_train, y_train)
  2. Хотим добавить ещё 50 деревьев:

    rf.n_estimators = 100
    rf.fit(X_train, y_train)  # новые 50 деревьев дописываются к существующим

Это позволяет постепенно наращивать сложность модели, не теряя уже обученные структуры.

Часто мы объединяем разные шаги в единый конвейер — посмотрим, как Pipeline и GridSearchCV взаимодействуют с fit().

Pipeline и GridSearchCV

С Pipeline мы объединяем несколько шагов:

from sklearn.pipeline import Pipeline

pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('clf', LogisticRegression())
])
pipe.fit(X_train, y_train)

Последовательность действий:

  1. scaler.fit_transform(X_train)

  2. clf.fit(X_scaled, y_train)

Если этот пайплайн завернуть в GridSearchCV, для каждого набора параметров он будет прогонять fit() заново, сохранять результаты и выбирать наилучший вариант.

Параметры обучения и колбэки

В чистом scikit-learn у fit() самые распространённые доп. параметры — это sample_weight и флаги для проверки входных данных. Но многие сторонние библиотеки (XGBoost, LightGBM) через знакомый интерфейс fit() принимают:

  • early_stopping_rounds — остановка по валидационной метрике

  • eval_set — данные для валидации

  • verbose — подробный лог обучения

А что делать, если модель может упаковаться в мультиядерный режим? Тут пригодится n_jobs.

Отладка и профилирование fit()

Чтобы понять, куда уходит время и память, рекомендую:

  • cProfile:

    python -m cProfile train.py
  • line_profiler и @profile для детального разбора функций

  • memory_profiler:

    from memory_profiler import profile
    @profile
    def train():
        model.fit(X, y)
  • verbose у моделей для по‑шаговых логов

Что важно запомнить

В итоге — fit() это последовательность: валидация, сохранение параметров, запуск fit() и запись результатов. Для потоковых данных используйте partialfit(), для дообучения ансамблей — warm_start. Собирайте всё в Pipeline, подключайте GridSearchCV и не забывайте про логирование и профилирование.


Хотите освоить машинное обучение с нуля и стать уверенным Junior-специалистом? Онлайн-курс Otus «Machine Learning. Basic» — это ваш шанс получить знания от экспертов отрасли и прокачать навыки на реальных проектах.

За время обучения вы не только разберётесь в теории, но и будете решать реальные задачи, анализировать данные и строить свои первые модели. Прокачайте компетенции, которые востребованы на рынке, и начните карьеру в одном из самых перспективных направлений!

Теги:
Хабы:
Всего голосов 8: ↑6 и ↓2+7
Комментарии0

Публикации

Информация

Сайт
otus.ru
Дата регистрации
Дата основания
Численность
101–200 человек
Местоположение
Россия
Представитель
OTUS