# All the libraries we need.
import pandas as pd
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json

# Silence warnings so they do not clutter the output.
import warnings
warnings.filterwarnings('ignore')
# Sensible plot defaults so we do not have to repeat colour / size / font
# parameters on every figure; keep them all in matplotlib's rcParams.
import matplotlib as mlp

# Grid
mlp.rcParams['axes.grid'] = True
mlp.rcParams['grid.color'] = '#D3D3D3'
mlp.rcParams['grid.linestyle'] = '--'
mlp.rcParams['grid.linewidth'] = 1

# Background colours
mlp.rcParams['axes.facecolor'] = '#F9F9F9'    # light-grey plot area
mlp.rcParams['figure.facecolor'] = '#FFFFFF'  # white canvas

# Legend
mlp.rcParams['legend.fontsize'] = 14
mlp.rcParams['legend.frameon'] = True
mlp.rcParams['legend.framealpha'] = 0.9
mlp.rcParams['legend.edgecolor'] = '#333333'

# Default figure size
mlp.rcParams['figure.figsize'] = (10, 6)

# Fonts
mlp.rcParams['font.family'] = 'DejaVu Sans'  # swap for 'Arial', 'Roboto', etc.
mlp.rcParams['font.size'] = 16

# Axes (spines)
mlp.rcParams['axes.edgecolor'] = '#333333'
mlp.rcParams['axes.linewidth'] = 2

# Main text colour
mlp.rcParams['text.color'] = '#222222'
# Load the train and test sets separately.
train_df = pd.read_csv('../data/train.csv')
test_df = pd.read_csv('../data/test.csv')

# Peek at the first 10 rows of the train set.
train_df.head(10)
PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
5 | 6 | 0 | 3 | Moran, Mr. James | male | NaN | 0 | 0 | 330877 | 8.4583 | NaN | Q |
6 | 7 | 0 | 1 | McCarthy, Mr. Timothy J | male | 54.0 | 0 | 0 | 17463 | 51.8625 | E46 | S |
7 | 8 | 0 | 3 | Palsson, Master. Gosta Leonard | male | 2.0 | 3 | 1 | 349909 | 21.0750 | NaN | S |
8 | 9 | 1 | 3 | Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) | female | 27.0 | 0 | 2 | 347742 | 11.1333 | NaN | S |
9 | 10 | 1 | 2 | Nasser, Mrs. Nicholas (Adele Achem) | female | 14.0 | 1 | 0 | 237736 | 30.0708 | NaN | C |
# Column dtypes and non-null counts for the train set.
train_df.info()
RangeIndex: 891 entries, 0 to 890 Data columns (total 12 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 PassengerId 891 non-null int64 1 Survived 891 non-null int64 2 Pclass 891 non-null int64 3 Name 891 non-null object 4 Sex 891 non-null object 5 Age 714 non-null float64 6 SibSp 891 non-null int64 7 Parch 891 non-null int64 8 Ticket 891 non-null object 9 Fare 891 non-null float64 10 Cabin 204 non-null object 11 Embarked 889 non-null object dtypes: float64(2), int64(5), object(5) memory usage: 83.7+ KB
Видно, что есть пропуски в признаке возраст. Очень много пропусков в признаке кабина
# Basic statistics for the numeric columns (mean, quartiles, std, ...).
train_df.describe()
PassengerId | Survived | Pclass | Age | SibSp | Parch | Fare | |
|---|---|---|---|---|---|---|---|
count | 891.000000 | 891.000000 | 891.000000 | 714.000000 | 891.000000 | 891.000000 | 891.000000 |
mean | 446.000000 | 0.383838 | 2.308642 | 29.699118 | 0.523008 | 0.381594 | 32.204208 |
std | 257.353842 | 0.486592 | 0.836071 | 14.526497 | 1.102743 | 0.806057 | 49.693429 |
min | 1.000000 | 0.000000 | 1.000000 | 0.420000 | 0.000000 | 0.000000 | 0.000000 |
25% | 223.500000 | 0.000000 | 2.000000 | 20.125000 | 0.000000 | 0.000000 | 7.910400 |
50% | 446.000000 | 0.000000 | 3.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 |
75% | 668.500000 | 1.000000 | 3.000000 | 38.000000 | 1.000000 | 0.000000 | 31.000000 |
max | 891.000000 | 1.000000 | 3.000000 | 80.000000 | 8.000000 | 6.000000 | 512.329200 |
# The same summary for the object (string) columns.
train_df.describe(include='object')
Name | Sex | Ticket | Cabin | Embarked | |
|---|---|---|---|---|---|
count | 891 | 891 | 891 | 204 | 889 |
unique | 891 | 2 | 681 | 147 | 3 |
top | Braund, Mr. Owen Harris | male | 347082 | G6 | S |
freq | 1 | 577 | 7 | 4 | 644 |
# Distribution of the target variable.
sns.countplot(x='Survived', data=train_df)
plt.title('Распределение выживших')
plt.show()

Дисбаланса классов не наблюдаем
# Target distribution broken down by individual features.

# Sex
plt.figure(figsize=(5, 4))
sns.countplot(x='Sex', hue='Survived', data=train_df)
plt.title('Пол и выживание')
plt.show()

# Passenger class
plt.figure(figsize=(5, 4))
sns.countplot(x='Pclass', hue='Survived', data=train_df)
plt.title('Класс каюты и выживание')
plt.show()

# Port of embarkation
plt.figure(figsize=(5, 4))
sns.countplot(x='Embarked', hue='Survived', data=train_df)
plt.title('Порт посадки и выживание')
plt.show()



Видно, что все признаки являются важными для таргета. В противном случае графики для разных признаков были бы одинаковыми.
# Box-and-whisker plots for Age and Fare against the target.

# Age
plt.figure(figsize=(6, 5))
sns.boxplot(x='Survived', y='Age', data=train_df)
plt.title('Возраст и выживание (boxplot)')
plt.show()

# Fare
plt.figure(figsize=(6, 5))
sns.boxplot(x='Survived', y='Fare', data=train_df)
plt.title('Стоимость билета и выживание (boxplot)')
plt.show()


Тоже видны различия в зависимости от таргета
# Feature lists: categorical vs numeric.
categorical_cols = ['Sex', 'Pclass', 'Embarked', 'Cabin']
numeric_cols = ['Age', 'Fare', 'SibSp', 'Parch']
# Correlation heat map over the numeric features.
plt.figure(figsize=(10, 8))
sns.heatmap(train_df.corr(numeric_only=True), annot=True, cmap='coolwarm', fmt=".2f")
plt.title('Корреляция числовых признаков')
plt.show()

Визуально видно, что мультиколлинеарных признаков нет, но проверим с помощью функции
### Helper (adapted from Stack Overflow) for ranking pairwise correlations.

def get_redundant_pairs(df):
    """Return the set of (col_i, col_j) pairs on or below the diagonal."""
    pairs_to_drop = set()
    cols = df.columns
    for i in range(df.shape[1]):
        pairs_to_drop.update((cols[i], cols[j]) for j in range(i + 1))
    return pairs_to_drop


def get_top_abs_correlations(df, n=5):
    """Return the n largest absolute correlations, one entry per column pair."""
    stacked = df.corr().abs().unstack()
    stacked = stacked.drop(labels=get_redundant_pairs(df))
    return stacked.sort_values(ascending=False).head(n)


print("Top Absolute Correlations")
print(get_top_abs_correlations(train_df[numeric_cols], 10))
Top Absolute Correlations SibSp Parch 0.414838 Age SibSp 0.308247 Fare Parch 0.216225 Age Parch 0.189119 Fare SibSp 0.159651 Age Fare 0.096067 dtype: float64
Мультиколлинеарности нет - подтверждаем
# Count missing values per column.
train_df.isnull().sum()
PassengerId 0 Survived 0 Pclass 0 Name 0 Sex 0 Age 177 SibSp 0 Parch 0 Ticket 0 Fare 0 Cabin 687 Embarked 2 dtype: int64
# Drop the rows with missing 'Embarked' — there are only two of them.
# Doing this now, BEFORE train and test are concatenated, guarantees
# that no test rows can ever be removed.
train_df = train_df.dropna(subset=['Embarked']).copy()

# Re-check the missing-value counts.
train_df.isnull().sum()
PassengerId 0 Survived 0 Pclass 0 Name 0 Sex 0 Age 177 SibSp 0 Parch 0 Ticket 0 Fare 0 Cabin 687 Embarked 0 dtype: int64
После удаления можно объединять строки. Удаляли до объединения, чтобы не повредить тестовые данные.
Теперь можно готовить данные к объединению и обработке общего датафрейма.
Обязательно нужно всё сделать правильно, чтобы не испортить данные
# Flag column so the combined frame can be split back correctly later.
train_df['is_train'] = 1
test_df['is_train'] = 0

# Dummy 'Survived' column in test so both frames share the same structure.
test_df['Survived'] = np.nan

# Keep the test PassengerId values — required for the submission file.
passenger_ids = test_df['PassengerId'].copy()
(колонка не важна для обучения, но требуется в итоговом файле решения)
# PassengerId carries no signal for the model — drop it before concatenating.
train_df = train_df.drop(columns=['PassengerId'])
test_df = test_df.drop(columns=['PassengerId'])

# Concatenate train and test so preprocessing is applied identically to both.
full_df = pd.concat([train_df, test_df], axis=0).reset_index(drop=True)

# Fill missing 'Age' with the median computed on the TRAIN rows only
# (the median is robust to outliers; train-only statistics avoid leakage).
full_df['Age'] = full_df['Age'].fillna(train_df['Age'].median())

# Check the missing values in the combined frame.
full_df.isnull().sum()
Survived 418 Pclass 0 Name 0 Sex 0 Age 0 SibSp 0 Parch 0 Ticket 0 Fare 1 Cabin 1014 Embarked 0 is_train 0 dtype: int64
# Another look at the data to decide what to do with the 'Fare' feature.
full_df.head(20)
Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | is_train | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S | 1 |
1 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | 1 |
2 | 1.0 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S | 1 |
3 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S | 1 |
4 | 0.0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S | 1 |
5 | 0.0 | 3 | Moran, Mr. James | male | 28.0 | 0 | 0 | 330877 | 8.4583 | NaN | Q | 1 |
6 | 0.0 | 1 | McCarthy, Mr. Timothy J | male | 54.0 | 0 | 0 | 17463 | 51.8625 | E46 | S | 1 |
7 | 0.0 | 3 | Palsson, Master. Gosta Leonard | male | 2.0 | 3 | 1 | 349909 | 21.0750 | NaN | S | 1 |
8 | 1.0 | 3 | Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) | female | 27.0 | 0 | 2 | 347742 | 11.1333 | NaN | S | 1 |
9 | 1.0 | 2 | Nasser, Mrs. Nicholas (Adele Achem) | female | 14.0 | 1 | 0 | 237736 | 30.0708 | NaN | C | 1 |
10 | 1.0 | 3 | Sandstrom, Miss. Marguerite Rut | female | 4.0 | 1 | 1 | PP 9549 | 16.7000 | G6 | S | 1 |
11 | 1.0 | 1 | Bonnell, Miss. Elizabeth | female | 58.0 | 0 | 0 | 113783 | 26.5500 | C103 | S | 1 |
12 | 0.0 | 3 | Saundercock, Mr. William Henry | male | 20.0 | 0 | 0 | A/5. 2151 | 8.0500 | NaN | S | 1 |
13 | 0.0 | 3 | Andersson, Mr. Anders Johan | male | 39.0 | 1 | 5 | 347082 | 31.2750 | NaN | S | 1 |
14 | 0.0 | 3 | Vestrom, Miss. Hulda Amanda Adolfina | female | 14.0 | 0 | 0 | 350406 | 7.8542 | NaN | S | 1 |
15 | 1.0 | 2 | Hewlett, Mrs. (Mary D Kingcome) | female | 55.0 | 0 | 0 | 248706 | 16.0000 | NaN | S | 1 |
16 | 0.0 | 3 | Rice, Master. Eugene | male | 2.0 | 4 | 1 | 382652 | 29.1250 | NaN | Q | 1 |
17 | 1.0 | 2 | Williams, Mr. Charles Eugene | male | 28.0 | 0 | 0 | 244373 | 13.0000 | NaN | S | 1 |
18 | 0.0 | 3 | Vander Planke, Mrs. Julius (Emelia Maria Vande... | female | 31.0 | 1 | 0 | 345763 | 18.0000 | NaN | S | 1 |
19 | 1.0 | 3 | Masselmani, Mrs. Fatima | female | 28.0 | 0 | 0 | 2649 | 7.2250 | NaN | C | 1 |
# Impute the single missing 'Fare' with the train-only median.
# (We must not drop rows after concatenation — that could remove a test row.)
full_df['Fare'] = full_df['Fare'].fillna(train_df['Fare'].median())

# Verify
full_df.isnull().sum()
Survived 418 Pclass 0 Name 0 Sex 0 Age 0 SibSp 0 Parch 0 Ticket 0 Fare 0 Cabin 1014 Embarked 0 is_train 0 dtype: int64
Три четверти данных по колонке Cabin в тестовых данных являются NaN. Скорее всего, это пассажиры второго или третьего класса, у которых просто не было собственной кабины. Совсем избавляться от этой колонки, наверное, не стоит — и не обязательно. Вместо этого мы закодируем её как наличие или отсутствие палубы.
# New binary feature: was a cabin recorded for this passenger?
full_df['Has_Cabin'] = full_df['Cabin'].notnull().astype(int)

# The original 'Cabin' column is mostly NaN and no longer needed — drop it.
full_df = full_df.drop(columns='Cabin')

# Inspect the updated frame.
full_df
Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Embarked | is_train | Has_Cabin | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | S | 1 | 0 |
1 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C | 1 | 1 |
2 | 1.0 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | S | 1 | 0 |
3 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | S | 1 | 1 |
4 | 0.0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | S | 1 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1302 | NaN | 3 | Spector, Mr. Woolf | male | 28.0 | 0 | 0 | A.5. 3236 | 8.0500 | S | 0 | 0 |
1303 | NaN | 1 | Oliva y Ocana, Dona. Fermina | female | 39.0 | 0 | 0 | PC 17758 | 108.9000 | C | 0 | 1 |
1304 | NaN | 3 | Saether, Mr. Simon Sivertsen | male | 38.5 | 0 | 0 | SOTON/O.Q. 3101262 | 7.2500 | S | 0 | 0 |
1305 | NaN | 3 | Ware, Mr. Frederick | male | 28.0 | 0 | 0 | 359309 | 8.0500 | S | 0 | 0 |
1306 | NaN | 3 | Peter, Master. Michael J | male | 28.0 | 1 | 1 | 2668 | 22.3583 | C | 0 | 0 |
1307 rows × 12 columns
Начинаем избавляться от ненужных, не несущих полезной информации для обучения модели, признаков
# 'Name' and 'Ticket' carry no useful signal for the model — drop them.
full_df = full_df.drop(columns=['Name','Ticket'])

# Inspect the result.
full_df
Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | is_train | Has_Cabin | |
|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 3 | male | 22.0 | 1 | 0 | 7.2500 | S | 1 | 0 |
1 | 1.0 | 1 | female | 38.0 | 1 | 0 | 71.2833 | C | 1 | 1 |
2 | 1.0 | 3 | female | 26.0 | 0 | 0 | 7.9250 | S | 1 | 0 |
3 | 1.0 | 1 | female | 35.0 | 1 | 0 | 53.1000 | S | 1 | 1 |
4 | 0.0 | 3 | male | 35.0 | 0 | 0 | 8.0500 | S | 1 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1302 | NaN | 3 | male | 28.0 | 0 | 0 | 8.0500 | S | 0 | 0 |
1303 | NaN | 1 | female | 39.0 | 0 | 0 | 108.9000 | C | 0 | 1 |
1304 | NaN | 3 | male | 38.5 | 0 | 0 | 7.2500 | S | 0 | 0 |
1305 | NaN | 3 | male | 28.0 | 0 | 0 | 8.0500 | S | 0 | 0 |
1306 | NaN | 3 | male | 28.0 | 1 | 1 | 22.3583 | C | 0 | 0 |
1307 rows × 10 columns
В целом не стесняемся часто смотреть и проверять результат. В процессе постоянного отсмотра данных может прийти идея, которая улучшит качество модели
Теперь закодируем колонки, которые являются объектами (object), как категории (category). Это особенно важно, если мы собираемся использовать модель CatBoost — она умеет напрямую работать с категориальными признаками и не требует их one-hot-кодирования.
CatBoost сам обработает эти признаки, если они будут иметь тип category, поэтому просто приведём нужные колонки к этому типу.
# Cast 'Sex' and 'Embarked' to pandas 'category' dtype —
# CatBoost consumes categorical features natively, no one-hot needed.
full_df['Sex'] = full_df['Sex'].astype('category')
full_df['Embarked'] = full_df['Embarked'].astype('category')

# Check
full_df.describe(include='all')
Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | is_train | Has_Cabin | |
|---|---|---|---|---|---|---|---|---|---|---|
count | 889.000000 | 1307.000000 | 1307 | 1307.000000 | 1307.000000 | 1307.000000 | 1307.000000 | 1307 | 1307.000000 | 1307.000000 |
unique | NaN | NaN | 2 | NaN | NaN | NaN | NaN | 3 | NaN | NaN |
top | NaN | NaN | male | NaN | NaN | NaN | NaN | S | NaN | NaN |
freq | NaN | NaN | 843 | NaN | NaN | NaN | NaN | 914 | NaN | NaN |
mean | 0.382452 | 2.296863 | NaN | 29.471821 | 0.499617 | 0.385616 | 33.209595 | NaN | 0.680184 | 0.224178 |
std | 0.486260 | 0.836942 | NaN | 12.881592 | 1.042273 | 0.866092 | 51.748768 | NaN | 0.466584 | 0.417199 |
min | 0.000000 | 1.000000 | NaN | 0.170000 | 0.000000 | 0.000000 | 0.000000 | NaN | 0.000000 | 0.000000 |
25% | 0.000000 | 2.000000 | NaN | 22.000000 | 0.000000 | 0.000000 | 7.895800 | NaN | 0.000000 | 0.000000 |
50% | 0.000000 | 3.000000 | NaN | 28.000000 | 0.000000 | 0.000000 | 14.454200 | NaN | 1.000000 | 0.000000 |
75% | 1.000000 | 3.000000 | NaN | 35.000000 | 1.000000 | 0.000000 | 31.275000 | NaN | 1.000000 | 0.000000 |
max | 1.000000 | 3.000000 | NaN | 80.000000 | 8.000000 | 9.000000 | 512.329200 | NaN | 1.000000 | 1.000000 |
Кроме того, колонка Pclass изначально имеет тип int, но на самом деле это категориальный признак (класс обслуживания: 1, 2 или 3). Если оставить её как числовую, модель может ошибочно посчитать, что класс 3 «больше» и важнее, чем класс 2, а тот — важнее, чем класс 1. Чтобы избежать этого, мы также приведём Pclass к категориальному типу.
# 'Pclass' looks numeric but is really categorical — cast it too,
# so the model does not treat class 3 as "greater than" class 2.
full_df['Pclass'] = full_df['Pclass'].astype('category')
Обработка данных завершена и теперь разделяем данные обратно:
# Split the combined frame back into train / test parts using the flag.
train_mask = full_df['is_train'] == 1
X_train = full_df[train_mask].drop(['is_train', 'Survived'], axis=1)
y_train = full_df[train_mask]['Survived']
X_test = full_df[full_df['is_train'] == 0].drop(['is_train', 'Survived'], axis=1)

# Sanity-check the shapes.
print(X_train.shape, y_train.shape, X_test.shape)
(889, 8) (889,) (418, 8)
Начинаем обучение модели
# Hold out 20% of the train data for validation, stratified on the target.
X_train_split, X_valid, y_train_split, y_valid = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

# Collect the columns cast to 'category' — CatBoost needs their names.
cat_features = X_train.select_dtypes(include='category').columns.tolist()

# Check
cat_features
['Pclass', 'Sex', 'Embarked']
# Baseline model with moderate parameters; no hyper-parameter search yet.
model = CatBoostClassifier(
    iterations=1000,
    learning_rate=0.05,
    depth=6,
    eval_metric='Accuracy',
    random_seed=42,
    early_stopping_rounds=50,
    verbose=100
)
model.fit(
    X_train_split,
    y_train_split,
    eval_set=(X_valid, y_valid),
    cat_features=cat_features
)
0: learn: 0.8227848 test: 0.7977528 best: 0.7977528 (0) total: 160ms remaining: 2m 39s Stopped by overfitting detector (50 iterations wait) bestTest = 0.8314606742 bestIteration = 32 Shrink model to first 33 iterations.
# Accuracy on the hold-out validation set.
y_pred = model.predict(X_valid)
acc = accuracy_score(y_valid, y_pred)
print(f"Validation Accuracy: {acc:.4f}")
Validation Accuracy: 0.8315
Доля правильных ответов 83,15%


# Predict on the test set.
test_preds = model.predict(X_test)

# Look at the model's predictions.
test_preds
array([0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 0.])
# Build and save the submission file.
submission = pd.DataFrame({
    'PassengerId': passenger_ids,
    'Survived': test_preds.astype(int)
})
submission.to_csv('../submissions/submission.csv', index=False)
print("✅ Submission файл сохранён как submission.csv")
✅ Submission файл сохранён как submission.csv
# Посмотрим файл submission
PassengerId | Survived | |
|---|---|---|
0 | 892 | 0 |
1 | 893 | 0 |
2 | 894 | 0 |
3 | 895 | 0 |
4 | 896 | 0 |
... | ... | ... |
413 | 1305 | 0 |
414 | 1306 | 1 |
415 | 1307 | 0 |
416 | 1308 | 0 |
417 | 1309 | 0 |
418 rows × 2 columns

Результат Топ 2'400
Теперь попробуем улучшить качество модели с помощью подбора гиперпараметров
Буду использовать случайный подбор параметров с помощью RandomizedSearchCV
# Импортируем модуль from sklearn.model_selection import RandomizedSearchCV
#Сетка гиперпараметров param_grid = { 'depth': [4, 6, 8, 10], # Максимальная глубина дерева (чем глубже — тем сложнее модель) 'learning_rate': [0.01, 0.05, 0.1], # Скорость обучения (маленькое значение = медленнее обучение, но может быть точнее) 'iterations': [300, 500, 1000], # Количество деревьев (итераций бустинга) 'l2_leaf_reg': [1, 3, 5, 7, 9], # L2-регуляризация — предотвращает переобучение 'border_count': [32, 64, 128] # Количество бинов для дискретизации числовых признаков } #Randomized Search с кросс-валидацией random_search = RandomizedSearchCV( estimator=model, param_distributions=param_grid, n_iter=45, # Сколько случайных комбинаций попробовать scoring='accuracy', # Метрика качества, которую нужно максимизировать cv=10, # Количество фолдов (разбиений) для кросс-валидации verbose=2, # Показывать процесс обучения в терминале n_jobs=-1 # Использовать все доступные ядра CPU для ускорения )
# Создаём экземпляр модели model = CatBoostClassifier(silent=True, random_state=42) # random state фиксированный
# Фиксируем некоторые параметры для модели fit_params = { "eval_set": [(X_valid, y_valid)], # Набор валидационных данных (для контроля переобучения и использования early stopping) "early_stopping_rounds": 100, # Если метрика не улучшается в течение 100 итераций — обучение остановится "cat_features": cat_features, # Указываем, какие признаки являются категориальными (CatBoost работает с ними нативно) "verbose": 1 # Показывать прогресс обучения во время тренировки }
# Launch the hyper-parameter search on the training split.
search_data = (X_train_split, y_train_split)
random_search.fit(*search_data, **fit_params)
Fitting 10 folds for each of 45 candidates, totalling 450 fits 0: learn: 0.7988748 test: 0.7752809 best: 0.7752809 (0) total: 18.8ms remaining: 5.63s 1: learn: 0.8016878 test: 0.7808989 best: 0.7808989 (1) total: 39.8ms remaining: 5.93s 2: learn: 0.8101266 test: 0.7921348 best: 0.7921348 (2) total: 56.4ms remaining: 5.58s 3: learn: 0.8045007 test: 0.7865169 best: 0.7921348 (2) total: 76.9ms remaining: 5.69s 4: learn: 0.8030942 test: 0.7865169 best: 0.7921348 (2) total: 97.1ms remaining: 5.73s 5: learn: 0.8087201 test: 0.7977528 best: 0.7977528 (5) total: 118ms remaining: 5.78s 6: learn: 0.8087201 test: 0.7977528 best: 0.7977528 (5) total: 139ms remaining: 5.82s 7: learn: 0.8101266 test: 0.8033708 best: 0.8033708 (7) total: 160ms remaining: 5.85s 8: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 181ms remaining: 5.86s 9: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 201ms remaining: 5.82s 10: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 220ms remaining: 5.77s 11: learn: 0.8115331 test: 0.7977528 best: 0.8033708 (7) total: 241ms remaining: 5.78s 12: learn: 0.8171589 test: 0.7977528 best: 0.8033708 (7) total: 262ms remaining: 5.78s 13: learn: 0.8185654 test: 0.7977528 best: 0.8033708 (7) total: 293ms remaining: 5.98s 14: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 322ms remaining: 6.12s 15: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 348ms remaining: 6.17s 16: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 369ms remaining: 6.14s 17: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 385ms remaining: 6.03s 18: learn: 0.8199719 test: 0.8033708 best: 0.8033708 (7) total: 407ms remaining: 6.03s 19: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 430ms remaining: 6.02s 20: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 452ms remaining: 6s 21: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 473ms remaining: 5.98s 22: learn: 0.8255977 test: 
0.8033708 best: 0.8033708 (7) total: 495ms remaining: 5.96s 23: learn: 0.8270042 test: 0.8033708 best: 0.8033708 (7) total: 501ms remaining: 5.76s 24: learn: 0.8270042 test: 0.8033708 best: 0.8033708 (7) total: 521ms remaining: 5.73s 25: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 536ms remaining: 5.65s 26: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 558ms remaining: 5.64s 27: learn: 0.8241913 test: 0.7977528 best: 0.8033708 (7) total: 576ms remaining: 5.6s 28: learn: 0.8255977 test: 0.7977528 best: 0.8033708 (7) total: 596ms remaining: 5.57s 29: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 615ms remaining: 5.54s 30: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 635ms remaining: 5.51s 31: learn: 0.8255977 test: 0.8089888 best: 0.8089888 (31) total: 655ms remaining: 5.48s 32: learn: 0.8284107 test: 0.8089888 best: 0.8089888 (31) total: 674ms remaining: 5.45s 33: learn: 0.8298172 test: 0.8146067 best: 0.8146067 (33) total: 694ms remaining: 5.43s 34: learn: 0.8340366 test: 0.8146067 best: 0.8146067 (33) total: 706ms remaining: 5.35s 35: learn: 0.8354430 test: 0.8146067 best: 0.8146067 (33) total: 728ms remaining: 5.34s 36: learn: 0.8354430 test: 0.8146067 best: 0.8146067 (33) total: 749ms remaining: 5.32s 37: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 764ms remaining: 5.27s 38: learn: 0.8382560 test: 0.8089888 best: 0.8146067 (33) total: 780ms remaining: 5.22s 39: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 799ms remaining: 5.19s 40: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 820ms remaining: 5.18s 41: learn: 0.8382560 test: 0.8089888 best: 0.8146067 (33) total: 840ms remaining: 5.16s 42: learn: 0.8382560 test: 0.8202247 best: 0.8202247 (42) total: 862ms remaining: 5.15s 43: learn: 0.8410689 test: 0.8146067 best: 0.8202247 (42) total: 884ms remaining: 5.14s 44: learn: 0.8396624 test: 0.8146067 best: 0.8202247 (42) total: 907ms remaining: 5.14s 45: learn: 
0.8438819 test: 0.8258427 best: 0.8258427 (45) total: 930ms remaining: 5.14s 46: learn: 0.8466948 test: 0.8258427 best: 0.8258427 (45) total: 953ms remaining: 5.13s 47: learn: 0.8466948 test: 0.8258427 best: 0.8258427 (45) total: 976ms remaining: 5.12s 48: learn: 0.8481013 test: 0.8258427 best: 0.8258427 (45) total: 999ms remaining: 5.12s 49: learn: 0.8452883 test: 0.8314607 best: 0.8314607 (49) total: 1.02s remaining: 5.11s 50: learn: 0.8438819 test: 0.8314607 best: 0.8314607 (49) total: 1.04s remaining: 5.09s 51: learn: 0.8438819 test: 0.8314607 best: 0.8314607 (49) total: 1.06s remaining: 5.07s 52: learn: 0.8452883 test: 0.8370787 best: 0.8370787 (52) total: 1.08s remaining: 5.05s 53: learn: 0.8424754 test: 0.8370787 best: 0.8370787 (52) total: 1.1s remaining: 5.04s 54: learn: 0.8396624 test: 0.8370787 best: 0.8370787 (52) total: 1.13s remaining: 5.01s 55: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.15s remaining: 4.99s 56: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.17s remaining: 4.97s 57: learn: 0.8382560 test: 0.8314607 best: 0.8370787 (52) total: 1.18s remaining: 4.94s 58: learn: 0.8382560 test: 0.8314607 best: 0.8370787 (52) total: 1.2s remaining: 4.89s 59: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.22s remaining: 4.87s 60: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.24s remaining: 4.85s 61: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.27s remaining: 4.89s 62: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.29s remaining: 4.87s 63: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.32s remaining: 4.86s 64: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.34s remaining: 4.85s 65: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.36s remaining: 4.84s 66: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.39s remaining: 4.82s 67: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.41s 
remaining: 4.8s 68: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.42s remaining: 4.75s 69: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.44s remaining: 4.74s 70: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.46s remaining: 4.72s 71: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.49s remaining: 4.7s 72: learn: 0.8438819 test: 0.8314607 best: 0.8370787 (52) total: 1.51s remaining: 4.69s 73: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.52s remaining: 4.66s 74: learn: 0.8382560 test: 0.8258427 best: 0.8370787 (52) total: 1.55s remaining: 4.64s 75: learn: 0.8410689 test: 0.8258427 best: 0.8370787 (52) total: 1.57s remaining: 4.63s 76: learn: 0.8424754 test: 0.8202247 best: 0.8370787 (52) total: 1.59s remaining: 4.6s 77: learn: 0.8424754 test: 0.8202247 best: 0.8370787 (52) total: 1.61s remaining: 4.58s 78: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.63s remaining: 4.56s 79: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.65s remaining: 4.54s 80: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.67s remaining: 4.52s 81: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.69s remaining: 4.49s 82: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.71s remaining: 4.47s 83: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.73s remaining: 4.44s 84: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.75s remaining: 4.42s 85: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.77s remaining: 4.39s 86: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.79s remaining: 4.37s 87: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.8s remaining: 4.35s 88: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 1.83s remaining: 4.33s 89: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.85s remaining: 4.31s 90: learn: 0.8438819 test: 0.8146067 best: 
0.8370787 (52) total: 1.87s remaining: 4.29s 91: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.89s remaining: 4.26s 92: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.91s remaining: 4.24s 93: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 1.92s remaining: 4.22s 94: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 1.93s remaining: 4.16s 95: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 1.95s remaining: 4.14s 96: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.97s remaining: 4.12s 97: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 1.99s remaining: 4.1s 98: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.01s remaining: 4.08s 99: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.02s remaining: 4.05s 100: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.04s remaining: 4.03s 101: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.06s remaining: 4.01s 102: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.09s remaining: 3.99s 103: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.11s remaining: 3.97s 104: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.13s remaining: 3.95s 105: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.15s remaining: 3.93s 106: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.2s remaining: 3.97s 107: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.24s remaining: 3.99s 108: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.28s remaining: 3.99s 109: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.3s remaining: 3.98s 110: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.32s remaining: 3.96s 111: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.35s remaining: 3.94s 112: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.37s remaining: 3.92s 113: 
learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.39s remaining: 3.9s 114: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.41s remaining: 3.88s 115: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.43s remaining: 3.86s 116: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.45s remaining: 3.84s 117: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.48s remaining: 3.82s 118: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.5s remaining: 3.8s 119: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.52s remaining: 3.78s 120: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.54s remaining: 3.76s 121: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.56s remaining: 3.74s 122: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.59s remaining: 3.72s 123: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.61s remaining: 3.71s 124: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.63s remaining: 3.69s 125: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.65s remaining: 3.67s 126: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.68s remaining: 3.65s 127: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.7s remaining: 3.63s 128: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.72s remaining: 3.61s 129: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.73s remaining: 3.57s 130: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.75s remaining: 3.55s 131: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.77s remaining: 3.53s 132: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.79s remaining: 3.51s 133: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.8s remaining: 3.47s 134: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.83s remaining: 3.46s 135: learn: 0.8466948 test: 0.8146067 best: 
0.8370787 (52) total: 2.85s remaining: 3.43s 136: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.87s remaining: 3.42s 137: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.89s remaining: 3.4s 138: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.91s remaining: 3.37s 139: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.94s remaining: 3.35s 140: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.96s remaining: 3.33s 141: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.98s remaining: 3.31s 142: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 3s remaining: 3.29s 143: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.02s remaining: 3.27s 144: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.04s remaining: 3.25s 145: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.06s remaining: 3.23s 146: learn: 0.8509142 test: 0.8146067 best: 0.8370787 (52) total: 3.08s remaining: 3.21s 147: learn: 0.8509142 test: 0.8146067 best: 0.8370787 (52) total: 3.1s remaining: 3.19s 148: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.12s remaining: 3.16s 149: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.14s remaining: 3.14s 150: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.16s remaining: 3.12s 151: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.18s remaining: 3.1s 152: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.2s remaining: 3.08s Stopped by overfitting detector (100 iterations wait) bestTest = 0.8370786517 bestIteration = 52 Shrink model to first 53 iterations.
RandomizedSearchCV(cv=10, estimator=<catboost.core.CatBoostClassifier object at 0x000002E16600D400>, n_iter=45, n_jobs=-1, param_distributions={'border_count': [32, 64, 128], 'depth': [4, 6, 8, 10], 'iterations': [300, 500, 1000], 'l2_leaf_reg': [1, 3, 5, 7, 9], 'learning_rate': [0.01, 0.05, 0.1]}, scoring='accuracy', verbose=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomizedSearchCV
?Documentation for RandomizedSearchCViFitted
Parameters
<tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">estimator </td> <td class="value"><catboost.cor...002E16600D400></td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">param_distributions </td> <td class="value">{'border_count': [32, 64, ...], 'depth': [4, 6, ...], 'iterations': [300, 500, ...], 'l2_leaf_reg': [1, 3, ...], ...}</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">n_iter </td> <td class="value">45</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">scoring </td> <td class="value">'accuracy'</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">n_jobs </td> <td class="value">-1</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">refit </td> <td class="value">True</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">cv </td> <td class="value">10</td> </tr> <tr class="user-set"> <td><i class="copy-paste-icon"></i></td> <td class="param">verbose </td> <td class="value">2</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">pre_dispatch </td> <td class="value">'2*n_jobs'</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">random_state </td> <td class="value">None</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">error_score </td> <td class="value">nan</td> </tr> <tr class="default"> <td><i class="copy-paste-icon"></i></td> <td class="param">return_train_score </td> <td class="value">False</td> </tr> </tbody> </table> </details> </div> </div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label fitted sk-toggleable"><input type="checkbox" id="sk-estimator-id-2" class="sk-toggleable__control 
sk-hidden--visually"><label class="sk-toggleable__label fitted sk-toggleable__label-arrow" for="sk-estimator-id-2"><div><div>best_estimator_: CatBoostClassifier</div></div></label><div data-param-prefix="best_estimator___" class="sk-toggleable__content fitted"><pre><catboost.core.CatBoostClassifier object at 0x000002E169667020></pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator fitted sk-toggleable"><input type="checkbox" id="sk-estimator-id-3" class="sk-toggleable__control sk-hidden--visually"><label class="sk-toggleable__label fitted sk-toggleable__label-arrow" for="sk-estimator-id-3"><div><div>CatBoostClassifier</div></div></label><div data-param-prefix="best_estimator___" class="sk-toggleable__content fitted"><pre><catboost.core.CatBoostClassifier object at 0x000002E169667020></pre></div></div></div></div></div></div></div></div></div></div>
# Выведем лучшие параметры random_search.best_params_
{'learning_rate': 0.05, 'l2_leaf_reg': 3, 'iterations': 300, 'depth': 4, 'border_count': 32}
# Выведем лучший скор random_search.best_score_
np.float64(0.82981220657277)
# Persist the best hyper-parameters to a text file so they aren't lost.
# FIX: open with "w", not "a" — append mode concatenates a new JSON object
# onto the file on every re-run, which is no longer valid JSON.
with open("best_params.txt", "w", encoding="utf-8") as f:
    json.dump(random_search.best_params_, f, indent=4)
Отступление
Дважды при малом early_stopping_rounds, равном 30, n_iter, равном 15, и cv, равном 3,
модель показывала лучший accuracy, но при валидации на Kaggle давала результаты хуже.
Увеличил early_stopping_rounds, n_iter и cv — и тогда
получилось улучшить итоговый результат
# Посмотрим лучшую модель best_model = random_search.best_estimator_
# Retrain the best model on the training split, with early stopping
# against the validation set (same fit_params as the search).
train_data = (X_train_split, y_train_split)
best_model.fit(*train_data, **fit_params)
0: learn: 0.7988748 test: 0.7752809 best: 0.7752809 (0) total: 21.9ms remaining: 6.54s 1: learn: 0.8016878 test: 0.7808989 best: 0.7808989 (1) total: 48.9ms remaining: 7.29s 2: learn: 0.8101266 test: 0.7921348 best: 0.7921348 (2) total: 74.5ms remaining: 7.37s 3: learn: 0.8045007 test: 0.7865169 best: 0.7921348 (2) total: 105ms remaining: 7.76s 4: learn: 0.8030942 test: 0.7865169 best: 0.7921348 (2) total: 133ms remaining: 7.85s 5: learn: 0.8087201 test: 0.7977528 best: 0.7977528 (5) total: 159ms remaining: 7.77s 6: learn: 0.8087201 test: 0.7977528 best: 0.7977528 (5) total: 184ms remaining: 7.7s 7: learn: 0.8101266 test: 0.8033708 best: 0.8033708 (7) total: 209ms remaining: 7.63s 8: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 234ms remaining: 7.57s 9: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 258ms remaining: 7.49s 10: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 286ms remaining: 7.52s 11: learn: 0.8115331 test: 0.7977528 best: 0.8033708 (7) total: 314ms remaining: 7.54s 12: learn: 0.8171589 test: 0.7977528 best: 0.8033708 (7) total: 341ms remaining: 7.54s 13: learn: 0.8185654 test: 0.7977528 best: 0.8033708 (7) total: 369ms remaining: 7.53s 14: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 392ms remaining: 7.45s 15: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 416ms remaining: 7.38s 16: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 438ms remaining: 7.29s 17: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 456ms remaining: 7.14s 18: learn: 0.8199719 test: 0.8033708 best: 0.8033708 (7) total: 479ms remaining: 7.08s 19: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 501ms remaining: 7.02s 20: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 524ms remaining: 6.96s 21: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 546ms remaining: 6.9s 22: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 576ms remaining: 6.93s 23: 
learn: 0.8270042 test: 0.8033708 best: 0.8033708 (7) total: 584ms remaining: 6.72s 24: learn: 0.8270042 test: 0.8033708 best: 0.8033708 (7) total: 607ms remaining: 6.68s 25: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 624ms remaining: 6.57s 26: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 646ms remaining: 6.53s 27: learn: 0.8241913 test: 0.7977528 best: 0.8033708 (7) total: 670ms remaining: 6.51s 28: learn: 0.8255977 test: 0.7977528 best: 0.8033708 (7) total: 692ms remaining: 6.47s 29: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 716ms remaining: 6.44s 30: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 739ms remaining: 6.41s 31: learn: 0.8255977 test: 0.8089888 best: 0.8089888 (31) total: 760ms remaining: 6.37s 32: learn: 0.8284107 test: 0.8089888 best: 0.8089888 (31) total: 784ms remaining: 6.34s 33: learn: 0.8298172 test: 0.8146067 best: 0.8146067 (33) total: 807ms remaining: 6.31s 34: learn: 0.8340366 test: 0.8146067 best: 0.8146067 (33) total: 820ms remaining: 6.21s 35: learn: 0.8354430 test: 0.8146067 best: 0.8146067 (33) total: 845ms remaining: 6.2s 36: learn: 0.8354430 test: 0.8146067 best: 0.8146067 (33) total: 868ms remaining: 6.17s 37: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 886ms remaining: 6.11s 38: learn: 0.8382560 test: 0.8089888 best: 0.8146067 (33) total: 908ms remaining: 6.08s 39: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 936ms remaining: 6.08s 40: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 959ms remaining: 6.05s 41: learn: 0.8382560 test: 0.8089888 best: 0.8146067 (33) total: 979ms remaining: 6.01s 42: learn: 0.8382560 test: 0.8202247 best: 0.8202247 (42) total: 1s remaining: 5.98s 43: learn: 0.8410689 test: 0.8146067 best: 0.8202247 (42) total: 1.02s remaining: 5.94s 44: learn: 0.8396624 test: 0.8146067 best: 0.8202247 (42) total: 1.04s remaining: 5.92s 45: learn: 0.8438819 test: 0.8258427 best: 0.8258427 (45) total: 1.07s 
remaining: 5.89s 46: learn: 0.8466948 test: 0.8258427 best: 0.8258427 (45) total: 1.09s remaining: 5.85s 47: learn: 0.8466948 test: 0.8258427 best: 0.8258427 (45) total: 1.11s remaining: 5.82s 48: learn: 0.8481013 test: 0.8258427 best: 0.8258427 (45) total: 1.13s remaining: 5.78s 49: learn: 0.8452883 test: 0.8314607 best: 0.8314607 (49) total: 1.15s remaining: 5.75s 50: learn: 0.8438819 test: 0.8314607 best: 0.8314607 (49) total: 1.17s remaining: 5.71s 51: learn: 0.8438819 test: 0.8314607 best: 0.8314607 (49) total: 1.19s remaining: 5.67s 52: learn: 0.8452883 test: 0.8370787 best: 0.8370787 (52) total: 1.21s remaining: 5.63s 53: learn: 0.8424754 test: 0.8370787 best: 0.8370787 (52) total: 1.23s remaining: 5.59s 54: learn: 0.8396624 test: 0.8370787 best: 0.8370787 (52) total: 1.25s remaining: 5.56s 55: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.27s remaining: 5.53s 56: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.29s remaining: 5.5s 57: learn: 0.8382560 test: 0.8314607 best: 0.8370787 (52) total: 1.31s remaining: 5.46s 58: learn: 0.8382560 test: 0.8314607 best: 0.8370787 (52) total: 1.32s remaining: 5.41s 59: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.34s remaining: 5.38s 60: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.36s remaining: 5.35s 61: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.38s remaining: 5.31s 62: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.4s remaining: 5.27s 63: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.42s remaining: 5.24s 64: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.44s remaining: 5.21s 65: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.46s remaining: 5.18s 66: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.48s remaining: 5.15s 67: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.5s remaining: 5.13s 68: learn: 0.8410689 test: 0.8314607 best: 
0.8370787 (52) total: 1.51s remaining: 5.07s 69: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.53s remaining: 5.05s 70: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.56s remaining: 5.03s 71: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.58s remaining: 5s 72: learn: 0.8438819 test: 0.8314607 best: 0.8370787 (52) total: 1.6s remaining: 4.98s 73: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.61s remaining: 4.93s 74: learn: 0.8382560 test: 0.8258427 best: 0.8370787 (52) total: 1.64s remaining: 4.91s 75: learn: 0.8410689 test: 0.8258427 best: 0.8370787 (52) total: 1.66s remaining: 4.88s 76: learn: 0.8424754 test: 0.8202247 best: 0.8370787 (52) total: 1.68s remaining: 4.86s 77: learn: 0.8424754 test: 0.8202247 best: 0.8370787 (52) total: 1.7s remaining: 4.83s 78: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.72s remaining: 4.81s 79: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.74s remaining: 4.78s 80: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.76s remaining: 4.76s 81: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.78s remaining: 4.73s 82: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.8s remaining: 4.71s 83: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.82s remaining: 4.68s 84: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.84s remaining: 4.67s 85: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.87s remaining: 4.65s 86: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.89s remaining: 4.64s 87: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.92s remaining: 4.62s 88: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 1.94s remaining: 4.59s 89: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.96s remaining: 4.58s 90: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.98s remaining: 4.55s 91: learn: 0.8438819 
test: 0.8146067 best: 0.8370787 (52) total: 2s remaining: 4.53s 92: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 2.02s remaining: 4.51s 93: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.05s remaining: 4.49s 94: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.05s remaining: 4.43s 95: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.08s remaining: 4.41s 96: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 2.1s remaining: 4.39s 97: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.12s remaining: 4.36s 98: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.14s remaining: 4.34s 99: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.16s remaining: 4.32s 100: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.19s remaining: 4.31s 101: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.21s remaining: 4.29s 102: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.23s remaining: 4.27s 103: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.25s remaining: 4.25s 104: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.28s remaining: 4.23s 105: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.3s remaining: 4.21s 106: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.32s remaining: 4.19s 107: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.35s remaining: 4.17s 108: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.37s remaining: 4.15s 109: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.39s remaining: 4.13s 110: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.41s remaining: 4.11s 111: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.44s remaining: 4.09s 112: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.46s remaining: 4.07s 113: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.48s 
remaining: 4.05s 114: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.5s remaining: 4.02s 115: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.52s remaining: 4s 116: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.54s remaining: 3.97s 117: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.56s remaining: 3.95s 118: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.58s remaining: 3.92s 119: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.6s remaining: 3.9s 120: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.62s remaining: 3.87s 121: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.64s remaining: 3.85s 122: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.66s remaining: 3.83s 123: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.68s remaining: 3.8s 124: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.7s remaining: 3.78s 125: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.72s remaining: 3.75s 126: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.74s remaining: 3.73s 127: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.76s remaining: 3.71s 128: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.78s remaining: 3.68s 129: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.79s remaining: 3.65s 130: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.81s remaining: 3.62s 131: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.82s remaining: 3.59s 132: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.84s remaining: 3.57s 133: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.85s remaining: 3.53s 134: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.87s remaining: 3.51s 135: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.89s remaining: 3.49s 136: learn: 0.8466948 test: 
0.8146067 best: 0.8370787 (52) total: 2.91s remaining: 3.47s 137: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.94s remaining: 3.44s 138: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.96s remaining: 3.42s 139: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.97s remaining: 3.4s 140: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 3s remaining: 3.38s 141: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 3.02s remaining: 3.35s 142: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 3.03s remaining: 3.33s 143: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.05s remaining: 3.31s 144: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.07s remaining: 3.29s 145: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.1s remaining: 3.27s 146: learn: 0.8509142 test: 0.8146067 best: 0.8370787 (52) total: 3.12s remaining: 3.24s 147: learn: 0.8509142 test: 0.8146067 best: 0.8370787 (52) total: 3.14s remaining: 3.22s 148: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.16s remaining: 3.2s 149: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.18s remaining: 3.18s 150: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.2s remaining: 3.15s 151: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.22s remaining: 3.13s 152: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.24s remaining: 3.11s Stopped by overfitting detector (100 iterations wait) bestTest = 0.8370786517 bestIteration = 52 Shrink model to first 53 iterations.
# Evaluate the tuned model on the validation split.
# FIX: the original scored a stale `y_pred` left over from the earlier
# baseline model (printed 0.8315), not the tuned model — whose actual
# validation accuracy is 0.8371 (see the later accuracy_score cell).
# Recompute predictions from best_model before scoring.
y_pred = best_model.predict(X_valid)
acc = accuracy_score(y_valid, y_pred)
print(f"Validation Accuracy: {acc:.4f}")
Validation Accuracy: 0.8315
#Предсказание на тесте best_test_preds = best_model.predict(X_test)
# Assemble the second submission file and write it to disk.
predictions = best_test_preds.astype(int)
submission_V2 = pd.DataFrame(
    {
        'PassengerId': passenger_ids,
        'Survived': predictions,
    }
)
submission_V2.to_csv('submission_V2.csv', index=False)
print("✅ Submission файл сохранён как submission_V2.csv")
✅ Submission файл сохранён как submission_V2.csv
# Смотрим оценку качества accuracy_score(y_valid, best_model.predict(X_valid))
0.8370786516853933


