#Импортируем все необходимые библиотеки
import pandas as pd
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
# 🔕 Отключаем предупреждения, чтобы не загромождали вывод
import warnings
warnings.filterwarnings('ignore')
# Project-wide matplotlib defaults so every plot below shares the same
# look without repeating colour/size/font parameters for each figure.
import matplotlib as mlp

mlp.rcParams.update({
    # Grid
    'axes.grid': True,
    'grid.color': '#D3D3D3',
    'grid.linestyle': '--',
    'grid.linewidth': 1,
    # Background colours
    'axes.facecolor': '#F9F9F9',    # light grey inside the axes
    'figure.facecolor': '#FFFFFF',  # white canvas
    # Legend
    'legend.fontsize': 14,
    'legend.frameon': True,
    'legend.framealpha': 0.9,
    'legend.edgecolor': '#333333',
    # Default figure size
    'figure.figsize': (10, 6),
    # Fonts
    'font.family': 'DejaVu Sans',   # swap for 'Arial', 'Roboto', etc. if preferred
    'font.size': 16,
    # Axis spines
    'axes.edgecolor': '#333333',
    'axes.linewidth': 2,
    # Main text colour
    'text.color': '#222222',
})
# Load the training split...
train_df = pd.read_csv('../data/train.csv')
# ...and the test split separately.
test_df = pd.read_csv('../data/test.csv')
# Peek at the first 10 rows of the training data.
train_df.head(10)
PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
5 | 6 | 0 | 3 | Moran, Mr. James | male | NaN | 0 | 0 | 330877 | 8.4583 | NaN | Q |
6 | 7 | 0 | 1 | McCarthy, Mr. Timothy J | male | 54.0 | 0 | 0 | 17463 | 51.8625 | E46 | S |
7 | 8 | 0 | 3 | Palsson, Master. Gosta Leonard | male | 2.0 | 3 | 1 | 349909 | 21.0750 | NaN | S |
8 | 9 | 1 | 3 | Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) | female | 27.0 | 0 | 2 | 347742 | 11.1333 | NaN | S |
9 | 10 | 1 | 2 | Nasser, Mrs. Nicholas (Adele Achem) | female | 14.0 | 1 | 0 | 237736 | 30.0708 | NaN | C |
# Concise summary of the training frame: dtypes, non-null counts, memory.
train_df.info()
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 PassengerId 891 non-null int64
1 Survived 891 non-null int64
2 Pclass 891 non-null int64
3 Name 891 non-null object
4 Sex 891 non-null object
5 Age 714 non-null float64
6 SibSp 891 non-null int64
7 Parch 891 non-null int64
8 Ticket 891 non-null object
9 Fare 891 non-null float64
10 Cabin 204 non-null object
11 Embarked 889 non-null object
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
Видно, что есть пропуски в признаке возраст. Очень много пропусков в признаке кабина
# Basic statistics for the numeric columns (mean, quartiles, std, etc.).
train_df.describe()
PassengerId | Survived | Pclass | Age | SibSp | Parch | Fare | |
---|---|---|---|---|---|---|---|
count | 891.000000 | 891.000000 | 891.000000 | 714.000000 | 891.000000 | 891.000000 | 891.000000 |
mean | 446.000000 | 0.383838 | 2.308642 | 29.699118 | 0.523008 | 0.381594 | 32.204208 |
std | 257.353842 | 0.486592 | 0.836071 | 14.526497 | 1.102743 | 0.806057 | 49.693429 |
min | 1.000000 | 0.000000 | 1.000000 | 0.420000 | 0.000000 | 0.000000 | 0.000000 |
25% | 223.500000 | 0.000000 | 2.000000 | 20.125000 | 0.000000 | 0.000000 | 7.910400 |
50% | 446.000000 | 0.000000 | 3.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 |
75% | 668.500000 | 1.000000 | 3.000000 | 38.000000 | 1.000000 | 0.000000 | 31.000000 |
max | 891.000000 | 1.000000 | 3.000000 | 80.000000 | 8.000000 | 6.000000 | 512.329200 |
# And the same summary for object (string) columns: count/unique/top/freq.
train_df.describe(include='object')
Name | Sex | Ticket | Cabin | Embarked | |
---|---|---|---|---|---|
count | 891 | 891 | 891 | 204 | 889 |
unique | 891 | 2 | 681 | 147 | 3 |
top | Braund, Mr. Owen Harris | male | 347082 | G6 | S |
freq | 1 | 577 | 7 | 4 | 644 |
# Target distribution: how many passengers survived vs. did not.
ax = sns.countplot(x='Survived', data=train_df)
ax.set_title('Распределение выживших')
plt.show()

Дисбаланса классов не наблюдаем
# Target split across the key categorical features, one small figure each.
for feature, title in [
    ('Sex', 'Пол и выживание'),
    ('Pclass', 'Класс каюты и выживание'),
    ('Embarked', 'Порт посадки и выживание'),
]:
    plt.figure(figsize=(5, 4))
    sns.countplot(x=feature, hue='Survived', data=train_df)
    plt.title(title)
    plt.show()



Видно, что все признаки являются важными для таргета. В противном случае графики для разных признаков были бы одинаковыми.
# Box plots of the continuous features, split by the target.
for feature, title in [
    ('Age', 'Возраст и выживание (boxplot)'),
    ('Fare', 'Стоимость билета и выживание (boxplot)'),
]:
    plt.figure(figsize=(6, 5))
    sns.boxplot(x='Survived', y=feature, data=train_df)
    plt.title(title)
    plt.show()


Тоже видны различия в зависимости от таргета
# Feature lists: categorical vs. numeric columns.
categorical_cols = ['Sex', 'Pclass', 'Embarked', 'Cabin']
numeric_cols = ['Age', 'Fare', 'SibSp', 'Parch']

# Correlation heatmap over the numeric columns.
plt.figure(figsize=(10, 8))
corr_matrix = train_df.corr(numeric_only=True)
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f")
plt.title('Корреляция числовых признаков')
plt.show()

Визуально видно, что мультиколлинеарных признаков нет, но проверим с помощью функции
### Секретная функция со Stack Overflow для фильтрации признаков
def get_redundant_pairs(df):
    """Return all (col_i, col_j) pairs on or below the correlation-matrix
    diagonal: self-correlations plus the mirror copy of each pair.

    Dropping these labels from the unstacked correlation series leaves
    each unordered feature pair exactly once.
    """
    pairs_to_drop = set()
    cols = df.columns
    for i in range(df.shape[1]):
        # j <= i covers the diagonal and the lower triangle.
        for j in range(i + 1):
            pairs_to_drop.add((cols[i], cols[j]))
    return pairs_to_drop


def get_top_abs_correlations(df, n=5):
    """Return the n strongest absolute pairwise correlations in df,
    sorted descending, with each feature pair reported only once."""
    au_corr = df.corr().abs().unstack()
    labels_to_drop = get_redundant_pairs(df)
    au_corr = au_corr.drop(labels=labels_to_drop).sort_values(ascending=False)
    return au_corr.head(n)
# Report the strongest absolute pairwise correlations among numeric features.
print("Top Absolute Correlations")
print(get_top_abs_correlations(train_df[numeric_cols], 10))
Top Absolute Correlations
SibSp Parch 0.414838
Age SibSp 0.308247
Fare Parch 0.216225
Age Parch 0.189119
Fare SibSp 0.159651
Age Fare 0.096067
dtype: float64
Мультиколлинеарности нет - подтверждаем
# Count missing values per column.
train_df.isnull().sum()
PassengerId 0
Survived 0
Pclass 0
Name 0
Sex 0
Age 177
SibSp 0
Parch 0
Ticket 0
Fare 0
Cabin 687
Embarked 2
dtype: int64
# Only two rows lack 'Embarked', so we simply drop them.
# It is safe to do this now, before train and test are concatenated:
# the operation cannot touch any test rows.
train_df = train_df[train_df['Embarked'].notna()].copy()
# Re-check the missing-value counts.
train_df.isnull().sum()
PassengerId 0
Survived 0
Pclass 0
Name 0
Sex 0
Age 177
SibSp 0
Parch 0
Ticket 0
Fare 0
Cabin 687
Embarked 0
dtype: int64
После удаления можно объединять строки. Удаляли до объединения, чтобы не повредить тестовые данные.
Теперь можно готовить данные к объединению и обработке общего датафрейма.
Обязательно нужно всё сделать правильно, чтобы не испортить данные
# Flag each row with its origin so the combined frame can be split back later.
train_df['is_train'] = 1
test_df['is_train'] = 0
# Give the test frame a placeholder 'Survived' column so both frames
# share the same structure before concatenation.
test_df['Survived'] = np.nan
# Keep the test PassengerId values: the model never sees them,
# but the submission file requires them.
passenger_ids = test_df['PassengerId'].copy()
(колонка не важна для обучения, но требуется в итоговом файле решения)
# PassengerId carries no predictive signal — drop it before concatenation.
train_df = train_df.drop(columns='PassengerId')
test_df = test_df.drop(columns='PassengerId')
# Stack train on top of test so both get identical preprocessing.
full_df = pd.concat([train_df, test_df], axis=0).reset_index(drop=True)
# Impute missing ages with the median age of the *training* passengers only:
# the median is robust to outliers, and using train-only statistics avoids
# leaking information from the test set.
full_df['Age'] = full_df['Age'].fillna(train_df['Age'].median())
# Missing-value counts for the combined frame.
full_df.isnull().sum()
Survived 418
Pclass 0
Name 0
Sex 0
Age 0
SibSp 0
Parch 0
Ticket 0
Fare 1
Cabin 1014
Embarked 0
is_train 0
dtype: int64
# Inspect the data once more to decide how to treat the Fare column.
full_df.head(20)
Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | is_train | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S | 1 |
1 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | 1 |
2 | 1.0 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S | 1 |
3 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S | 1 |
4 | 0.0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S | 1 |
5 | 0.0 | 3 | Moran, Mr. James | male | 28.0 | 0 | 0 | 330877 | 8.4583 | NaN | Q | 1 |
6 | 0.0 | 1 | McCarthy, Mr. Timothy J | male | 54.0 | 0 | 0 | 17463 | 51.8625 | E46 | S | 1 |
7 | 0.0 | 3 | Palsson, Master. Gosta Leonard | male | 2.0 | 3 | 1 | 349909 | 21.0750 | NaN | S | 1 |
8 | 1.0 | 3 | Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg) | female | 27.0 | 0 | 2 | 347742 | 11.1333 | NaN | S | 1 |
9 | 1.0 | 2 | Nasser, Mrs. Nicholas (Adele Achem) | female | 14.0 | 1 | 0 | 237736 | 30.0708 | NaN | C | 1 |
10 | 1.0 | 3 | Sandstrom, Miss. Marguerite Rut | female | 4.0 | 1 | 1 | PP 9549 | 16.7000 | G6 | S | 1 |
11 | 1.0 | 1 | Bonnell, Miss. Elizabeth | female | 58.0 | 0 | 0 | 113783 | 26.5500 | C103 | S | 1 |
12 | 0.0 | 3 | Saundercock, Mr. William Henry | male | 20.0 | 0 | 0 | A/5. 2151 | 8.0500 | NaN | S | 1 |
13 | 0.0 | 3 | Andersson, Mr. Anders Johan | male | 39.0 | 1 | 5 | 347082 | 31.2750 | NaN | S | 1 |
14 | 0.0 | 3 | Vestrom, Miss. Hulda Amanda Adolfina | female | 14.0 | 0 | 0 | 350406 | 7.8542 | NaN | S | 1 |
15 | 1.0 | 2 | Hewlett, Mrs. (Mary D Kingcome) | female | 55.0 | 0 | 0 | 248706 | 16.0000 | NaN | S | 1 |
16 | 0.0 | 3 | Rice, Master. Eugene | male | 2.0 | 4 | 1 | 382652 | 29.1250 | NaN | Q | 1 |
17 | 1.0 | 2 | Williams, Mr. Charles Eugene | male | 28.0 | 0 | 0 | 244373 | 13.0000 | NaN | S | 1 |
18 | 0.0 | 3 | Vander Planke, Mrs. Julius (Emelia Maria Vande... | female | 31.0 | 1 | 0 | 345763 | 18.0000 | NaN | S | 1 |
19 | 1.0 | 3 | Masselmani, Mrs. Fatima | female | 28.0 | 0 | 0 | 2649 | 7.2250 | NaN | C | 1 |
# A single Fare value is missing — impute it with the train-only median.
# Dropping the row is not an option after concatenation: it could be
# a test passenger.
fare_median = train_df['Fare'].median()
full_df['Fare'] = full_df['Fare'].fillna(fare_median)
# Verify the imputation.
full_df.isnull().sum()
Survived 418
Pclass 0
Name 0
Sex 0
Age 0
SibSp 0
Parch 0
Ticket 0
Fare 0
Cabin 1014
Embarked 0
is_train 0
dtype: int64
Три четверти данных по колонке Cabin в тестовых данных являются NaN. Скорее всего, это пассажиры второго или третьего класса, у которых просто не было собственной кабины. Совсем избавляться от этой колонки, наверное, не стоит — и не обязательно. Вместо этого мы закодируем её как наличие или отсутствие каюты.
# Replace the sparse Cabin column with a binary flag: cabin recorded or not.
full_df['Has_Cabin'] = full_df['Cabin'].notna().astype(int)
full_df = full_df.drop(columns='Cabin')
# Inspect the updated frame.
full_df
Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Embarked | is_train | Has_Cabin | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | S | 1 | 0 |
1 | 1.0 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C | 1 | 1 |
2 | 1.0 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | S | 1 | 0 |
3 | 1.0 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | S | 1 | 1 |
4 | 0.0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | S | 1 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1302 | NaN | 3 | Spector, Mr. Woolf | male | 28.0 | 0 | 0 | A.5. 3236 | 8.0500 | S | 0 | 0 |
1303 | NaN | 1 | Oliva y Ocana, Dona. Fermina | female | 39.0 | 0 | 0 | PC 17758 | 108.9000 | C | 0 | 1 |
1304 | NaN | 3 | Saether, Mr. Simon Sivertsen | male | 38.5 | 0 | 0 | SOTON/O.Q. 3101262 | 7.2500 | S | 0 | 0 |
1305 | NaN | 3 | Ware, Mr. Frederick | male | 28.0 | 0 | 0 | 359309 | 8.0500 | S | 0 | 0 |
1306 | NaN | 3 | Peter, Master. Michael J | male | 28.0 | 1 | 1 | 2668 | 22.3583 | C | 0 | 0 |
1307 rows × 12 columns
Начинаем избавляться от ненужных, не несущих полезной информации для обучения модели, признаков
# Name and Ticket are effectively unique identifiers — drop them.
full_df = full_df.drop(columns=['Name','Ticket'])
# Inspect the result.
full_df
Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | is_train | Has_Cabin | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 3 | male | 22.0 | 1 | 0 | 7.2500 | S | 1 | 0 |
1 | 1.0 | 1 | female | 38.0 | 1 | 0 | 71.2833 | C | 1 | 1 |
2 | 1.0 | 3 | female | 26.0 | 0 | 0 | 7.9250 | S | 1 | 0 |
3 | 1.0 | 1 | female | 35.0 | 1 | 0 | 53.1000 | S | 1 | 1 |
4 | 0.0 | 3 | male | 35.0 | 0 | 0 | 8.0500 | S | 1 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1302 | NaN | 3 | male | 28.0 | 0 | 0 | 8.0500 | S | 0 | 0 |
1303 | NaN | 1 | female | 39.0 | 0 | 0 | 108.9000 | C | 0 | 1 |
1304 | NaN | 3 | male | 38.5 | 0 | 0 | 7.2500 | S | 0 | 0 |
1305 | NaN | 3 | male | 28.0 | 0 | 0 | 8.0500 | S | 0 | 0 |
1306 | NaN | 3 | male | 28.0 | 1 | 1 | 22.3583 | C | 0 | 0 |
1307 rows × 10 columns
В целом не стесняемся часто смотреть и проверять результат. В процессе постоянного отсмотра данных может прийти идея, которая улучшит качество модели
Теперь закодируем колонки, которые являются объектами (object), как категории (category). Это особенно важно, если мы собираемся использовать модель CatBoost — она умеет напрямую работать с категориальными признаками и не требует их one-hot-кодирования.
CatBoost сам обработает эти признаки, если они будут иметь тип category, поэтому просто приведём нужные колонки к этому типу.
# Cast the string columns to pandas 'category' dtype — CatBoost consumes
# categorical features natively, no one-hot encoding required.
for col in ('Sex', 'Embarked'):
    full_df[col] = full_df[col].astype('category')
# Check the result.
full_df.describe(include='all')
Survived | Pclass | Sex | Age | SibSp | Parch | Fare | Embarked | is_train | Has_Cabin | |
---|---|---|---|---|---|---|---|---|---|---|
count | 889.000000 | 1307.000000 | 1307 | 1307.000000 | 1307.000000 | 1307.000000 | 1307.000000 | 1307 | 1307.000000 | 1307.000000 |
unique | NaN | NaN | 2 | NaN | NaN | NaN | NaN | 3 | NaN | NaN |
top | NaN | NaN | male | NaN | NaN | NaN | NaN | S | NaN | NaN |
freq | NaN | NaN | 843 | NaN | NaN | NaN | NaN | 914 | NaN | NaN |
mean | 0.382452 | 2.296863 | NaN | 29.471821 | 0.499617 | 0.385616 | 33.209595 | NaN | 0.680184 | 0.224178 |
std | 0.486260 | 0.836942 | NaN | 12.881592 | 1.042273 | 0.866092 | 51.748768 | NaN | 0.466584 | 0.417199 |
min | 0.000000 | 1.000000 | NaN | 0.170000 | 0.000000 | 0.000000 | 0.000000 | NaN | 0.000000 | 0.000000 |
25% | 0.000000 | 2.000000 | NaN | 22.000000 | 0.000000 | 0.000000 | 7.895800 | NaN | 0.000000 | 0.000000 |
50% | 0.000000 | 3.000000 | NaN | 28.000000 | 0.000000 | 0.000000 | 14.454200 | NaN | 1.000000 | 0.000000 |
75% | 1.000000 | 3.000000 | NaN | 35.000000 | 1.000000 | 0.000000 | 31.275000 | NaN | 1.000000 | 0.000000 |
max | 1.000000 | 3.000000 | NaN | 80.000000 | 8.000000 | 9.000000 | 512.329200 | NaN | 1.000000 | 1.000000 |
Кроме того, колонка Pclass изначально имеет тип int, но на самом деле это категориальный признак (класс обслуживания: 1, 2 или 3). Если оставить её как числовую, модель может ошибочно посчитать, что класс 3 «больше» и важнее, чем класс 2, а тот — важнее, чем класс 1. Чтобы избежать этого, мы также приведём Pclass к категориальному типу.
# Pclass looks numeric but is really categorical (service class 1/2/3);
# cast it to 'category' so the model does not assume a numeric ordering.
full_df['Pclass'] = full_df['Pclass'].astype('category')
Обработка данных завершена и теперь разделяем данные обратно:
# Split the combined frame back into train/test using the origin flag.
is_train_mask = full_df['is_train'] == 1
X_train = full_df[is_train_mask].drop(['is_train', 'Survived'], axis=1)
y_train = full_df[is_train_mask]['Survived']
X_test = full_df[~is_train_mask].drop(['is_train', 'Survived'], axis=1)
# Sanity-check the shapes.
print(X_train.shape, y_train.shape, X_test.shape)
(889, 8) (889,) (418, 8)
Начинаем обучение модели
# Hold out 20% of the training rows for validation, stratified on the target
# so both splits keep the same survived/died ratio.
X_train_split, X_valid, y_train_split, y_valid = train_test_split(
    X_train,
    y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train,
)
# Collect the columns we cast to 'category' — CatBoost needs their names.
cat_features = X_train.select_dtypes(include='category').columns.tolist()
# Check the list.
cat_features
['Pclass', 'Sex', 'Embarked']
# Baseline model with moderate hyperparameters; no tuning yet.
baseline_params = {
    'iterations': 1000,
    'learning_rate': 0.05,
    'depth': 6,
    'eval_metric': 'Accuracy',
    'random_seed': 42,
    'early_stopping_rounds': 50,
    'verbose': 100,
}
model = CatBoostClassifier(**baseline_params)
model.fit(
    X_train_split,
    y_train_split,
    eval_set=(X_valid, y_valid),
    cat_features=cat_features,
)
0: learn: 0.8227848 test: 0.7977528 best: 0.7977528 (0) total: 160ms remaining: 2m 39s
Stopped by overfitting detector (50 iterations wait)
bestTest = 0.8314606742
bestIteration = 32
Shrink model to first 33 iterations.
# Accuracy on the held-out validation set.
y_pred = model.predict(X_valid)
acc = accuracy_score(y_valid, y_pred)
print(f"Validation Accuracy: {acc:.4f}")
Validation Accuracy: 0.8315
Доля правильных ответов 83,15%


# Predict on the test set.
test_preds = model.predict(X_test)
# Inspect the predictions.
test_preds
array([0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0.,
0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0.,
0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0.,
0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0.,
0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0.,
0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0.,
1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0.,
0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1.,
0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1.,
0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1.,
0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.,
0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1.,
0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
1., 1., 1., 1., 1., 0., 1., 0., 0., 0.])
# Assemble the Kaggle submission: PassengerId plus the integer prediction.
submission = pd.DataFrame()
submission['PassengerId'] = passenger_ids
submission['Survived'] = test_preds.astype(int)
submission.to_csv('../submissions/submission.csv', index=False)
print("✅ Submission файл сохранён как submission.csv")
✅ Submission файл сохранён как submission.csv
# Preview the submission frame.
submission
PassengerId | Survived | |
---|---|---|
0 | 892 | 0 |
1 | 893 | 0 |
2 | 894 | 0 |
3 | 895 | 0 |
4 | 896 | 0 |
... | ... | ... |
413 | 1305 | 0 |
414 | 1306 | 1 |
415 | 1307 | 0 |
416 | 1308 | 0 |
417 | 1309 | 0 |
418 rows × 2 columns

Результат Топ 2'400
Теперь попробуем улучшить качество модели с помощью подбора гиперпараметров
Буду использовать случайный подбор параметров с помощью RandomizedSearchCV
# Import the randomized hyperparameter search.
from sklearn.model_selection import RandomizedSearchCV

# Fresh estimator for the search. It must be created BEFORE
# RandomizedSearchCV is constructed: previously the search captured the
# already-fitted baseline model, and this fresh instance was never used.
# random_state is fixed for reproducibility.
model = CatBoostClassifier(silent=True, random_state=42)

# Hyperparameter grid to sample from.
param_grid = {
    'depth': [4, 6, 8, 10],              # max tree depth (deeper = more complex model)
    'learning_rate': [0.01, 0.05, 0.1],  # smaller = slower learning, possibly more accurate
    'iterations': [300, 500, 1000],      # number of boosting trees
    'l2_leaf_reg': [1, 3, 5, 7, 9],      # L2 regularization — guards against overfitting
    'border_count': [32, 64, 128],       # bins for numeric-feature discretization
}

# Randomized search with cross-validation.
random_search = RandomizedSearchCV(
    estimator=model,
    param_distributions=param_grid,
    n_iter=45,            # how many random combinations to try
    scoring='accuracy',   # metric to maximize
    cv=10,                # number of cross-validation folds
    verbose=2,            # log the search progress
    n_jobs=-1,            # use every available CPU core
    random_state=42,      # reproducible sampling of parameter combinations
)

# Parameters forwarded to every CatBoost fit inside the search.
fit_params = {
    "eval_set": [(X_valid, y_valid)],  # validation data for early stopping
    "early_stopping_rounds": 100,      # stop if no improvement for 100 iterations
    "cat_features": cat_features,      # columns CatBoost treats as categorical natively
    "verbose": 1                       # show training progress
}

# Run the search.
random_search.fit(X_train_split, y_train_split, **fit_params)
Fitting 10 folds for each of 45 candidates, totalling 450 fits
0: learn: 0.7988748 test: 0.7752809 best: 0.7752809 (0) total: 18.8ms remaining: 5.63s
1: learn: 0.8016878 test: 0.7808989 best: 0.7808989 (1) total: 39.8ms remaining: 5.93s
2: learn: 0.8101266 test: 0.7921348 best: 0.7921348 (2) total: 56.4ms remaining: 5.58s
3: learn: 0.8045007 test: 0.7865169 best: 0.7921348 (2) total: 76.9ms remaining: 5.69s
4: learn: 0.8030942 test: 0.7865169 best: 0.7921348 (2) total: 97.1ms remaining: 5.73s
5: learn: 0.8087201 test: 0.7977528 best: 0.7977528 (5) total: 118ms remaining: 5.78s
6: learn: 0.8087201 test: 0.7977528 best: 0.7977528 (5) total: 139ms remaining: 5.82s
7: learn: 0.8101266 test: 0.8033708 best: 0.8033708 (7) total: 160ms remaining: 5.85s
8: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 181ms remaining: 5.86s
9: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 201ms remaining: 5.82s
10: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 220ms remaining: 5.77s
11: learn: 0.8115331 test: 0.7977528 best: 0.8033708 (7) total: 241ms remaining: 5.78s
12: learn: 0.8171589 test: 0.7977528 best: 0.8033708 (7) total: 262ms remaining: 5.78s
13: learn: 0.8185654 test: 0.7977528 best: 0.8033708 (7) total: 293ms remaining: 5.98s
14: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 322ms remaining: 6.12s
15: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 348ms remaining: 6.17s
16: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 369ms remaining: 6.14s
17: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 385ms remaining: 6.03s
18: learn: 0.8199719 test: 0.8033708 best: 0.8033708 (7) total: 407ms remaining: 6.03s
19: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 430ms remaining: 6.02s
20: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 452ms remaining: 6s
21: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 473ms remaining: 5.98s
22: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 495ms remaining: 5.96s
23: learn: 0.8270042 test: 0.8033708 best: 0.8033708 (7) total: 501ms remaining: 5.76s
24: learn: 0.8270042 test: 0.8033708 best: 0.8033708 (7) total: 521ms remaining: 5.73s
25: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 536ms remaining: 5.65s
26: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 558ms remaining: 5.64s
27: learn: 0.8241913 test: 0.7977528 best: 0.8033708 (7) total: 576ms remaining: 5.6s
28: learn: 0.8255977 test: 0.7977528 best: 0.8033708 (7) total: 596ms remaining: 5.57s
29: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 615ms remaining: 5.54s
30: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 635ms remaining: 5.51s
31: learn: 0.8255977 test: 0.8089888 best: 0.8089888 (31) total: 655ms remaining: 5.48s
32: learn: 0.8284107 test: 0.8089888 best: 0.8089888 (31) total: 674ms remaining: 5.45s
33: learn: 0.8298172 test: 0.8146067 best: 0.8146067 (33) total: 694ms remaining: 5.43s
34: learn: 0.8340366 test: 0.8146067 best: 0.8146067 (33) total: 706ms remaining: 5.35s
35: learn: 0.8354430 test: 0.8146067 best: 0.8146067 (33) total: 728ms remaining: 5.34s
36: learn: 0.8354430 test: 0.8146067 best: 0.8146067 (33) total: 749ms remaining: 5.32s
37: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 764ms remaining: 5.27s
38: learn: 0.8382560 test: 0.8089888 best: 0.8146067 (33) total: 780ms remaining: 5.22s
39: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 799ms remaining: 5.19s
40: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 820ms remaining: 5.18s
41: learn: 0.8382560 test: 0.8089888 best: 0.8146067 (33) total: 840ms remaining: 5.16s
42: learn: 0.8382560 test: 0.8202247 best: 0.8202247 (42) total: 862ms remaining: 5.15s
43: learn: 0.8410689 test: 0.8146067 best: 0.8202247 (42) total: 884ms remaining: 5.14s
44: learn: 0.8396624 test: 0.8146067 best: 0.8202247 (42) total: 907ms remaining: 5.14s
45: learn: 0.8438819 test: 0.8258427 best: 0.8258427 (45) total: 930ms remaining: 5.14s
46: learn: 0.8466948 test: 0.8258427 best: 0.8258427 (45) total: 953ms remaining: 5.13s
47: learn: 0.8466948 test: 0.8258427 best: 0.8258427 (45) total: 976ms remaining: 5.12s
48: learn: 0.8481013 test: 0.8258427 best: 0.8258427 (45) total: 999ms remaining: 5.12s
49: learn: 0.8452883 test: 0.8314607 best: 0.8314607 (49) total: 1.02s remaining: 5.11s
50: learn: 0.8438819 test: 0.8314607 best: 0.8314607 (49) total: 1.04s remaining: 5.09s
51: learn: 0.8438819 test: 0.8314607 best: 0.8314607 (49) total: 1.06s remaining: 5.07s
52: learn: 0.8452883 test: 0.8370787 best: 0.8370787 (52) total: 1.08s remaining: 5.05s
53: learn: 0.8424754 test: 0.8370787 best: 0.8370787 (52) total: 1.1s remaining: 5.04s
54: learn: 0.8396624 test: 0.8370787 best: 0.8370787 (52) total: 1.13s remaining: 5.01s
55: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.15s remaining: 4.99s
56: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.17s remaining: 4.97s
57: learn: 0.8382560 test: 0.8314607 best: 0.8370787 (52) total: 1.18s remaining: 4.94s
58: learn: 0.8382560 test: 0.8314607 best: 0.8370787 (52) total: 1.2s remaining: 4.89s
59: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.22s remaining: 4.87s
60: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.24s remaining: 4.85s
61: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.27s remaining: 4.89s
62: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.29s remaining: 4.87s
63: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.32s remaining: 4.86s
64: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.34s remaining: 4.85s
65: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.36s remaining: 4.84s
66: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.39s remaining: 4.82s
67: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.41s remaining: 4.8s
68: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.42s remaining: 4.75s
69: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.44s remaining: 4.74s
70: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.46s remaining: 4.72s
71: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.49s remaining: 4.7s
72: learn: 0.8438819 test: 0.8314607 best: 0.8370787 (52) total: 1.51s remaining: 4.69s
73: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.52s remaining: 4.66s
74: learn: 0.8382560 test: 0.8258427 best: 0.8370787 (52) total: 1.55s remaining: 4.64s
75: learn: 0.8410689 test: 0.8258427 best: 0.8370787 (52) total: 1.57s remaining: 4.63s
76: learn: 0.8424754 test: 0.8202247 best: 0.8370787 (52) total: 1.59s remaining: 4.6s
77: learn: 0.8424754 test: 0.8202247 best: 0.8370787 (52) total: 1.61s remaining: 4.58s
78: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.63s remaining: 4.56s
79: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.65s remaining: 4.54s
80: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.67s remaining: 4.52s
81: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.69s remaining: 4.49s
82: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.71s remaining: 4.47s
83: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.73s remaining: 4.44s
84: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.75s remaining: 4.42s
85: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.77s remaining: 4.39s
86: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.79s remaining: 4.37s
87: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.8s remaining: 4.35s
88: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 1.83s remaining: 4.33s
89: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.85s remaining: 4.31s
90: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.87s remaining: 4.29s
91: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.89s remaining: 4.26s
92: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.91s remaining: 4.24s
93: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 1.92s remaining: 4.22s
94: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 1.93s remaining: 4.16s
95: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 1.95s remaining: 4.14s
96: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.97s remaining: 4.12s
97: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 1.99s remaining: 4.1s
98: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.01s remaining: 4.08s
99: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.02s remaining: 4.05s
100: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.04s remaining: 4.03s
101: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.06s remaining: 4.01s
102: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.09s remaining: 3.99s
103: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.11s remaining: 3.97s
104: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.13s remaining: 3.95s
105: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.15s remaining: 3.93s
106: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.2s remaining: 3.97s
107: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.24s remaining: 3.99s
108: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.28s remaining: 3.99s
109: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.3s remaining: 3.98s
110: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.32s remaining: 3.96s
111: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.35s remaining: 3.94s
112: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.37s remaining: 3.92s
113: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.39s remaining: 3.9s
114: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.41s remaining: 3.88s
115: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.43s remaining: 3.86s
116: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.45s remaining: 3.84s
117: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.48s remaining: 3.82s
118: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.5s remaining: 3.8s
119: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.52s remaining: 3.78s
120: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.54s remaining: 3.76s
121: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.56s remaining: 3.74s
122: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.59s remaining: 3.72s
123: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.61s remaining: 3.71s
124: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.63s remaining: 3.69s
125: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.65s remaining: 3.67s
126: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.68s remaining: 3.65s
127: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.7s remaining: 3.63s
128: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.72s remaining: 3.61s
129: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.73s remaining: 3.57s
130: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.75s remaining: 3.55s
131: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.77s remaining: 3.53s
132: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.79s remaining: 3.51s
133: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.8s remaining: 3.47s
134: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.83s remaining: 3.46s
135: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.85s remaining: 3.43s
136: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.87s remaining: 3.42s
137: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.89s remaining: 3.4s
138: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.91s remaining: 3.37s
139: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.94s remaining: 3.35s
140: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.96s remaining: 3.33s
141: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.98s remaining: 3.31s
142: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 3s remaining: 3.29s
143: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.02s remaining: 3.27s
144: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.04s remaining: 3.25s
145: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.06s remaining: 3.23s
146: learn: 0.8509142 test: 0.8146067 best: 0.8370787 (52) total: 3.08s remaining: 3.21s
147: learn: 0.8509142 test: 0.8146067 best: 0.8370787 (52) total: 3.1s remaining: 3.19s
148: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.12s remaining: 3.16s
149: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.14s remaining: 3.14s
150: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.16s remaining: 3.12s
151: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.18s remaining: 3.1s
152: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.2s remaining: 3.08s
Stopped by overfitting detector (100 iterations wait)
bestTest = 0.8370786517
bestIteration = 52
Shrink model to first 53 iterations.
RandomizedSearchCV(cv=10,
estimator=<catboost.core.CatBoostClassifier object at 0x000002E16600D400>,
n_iter=45, n_jobs=-1,
param_distributions={'border_count': [32, 64, 128],
'depth': [4, 6, 8, 10],
'iterations': [300, 500, 1000],
'l2_leaf_reg': [1, 3, 5, 7, 9],
'learning_rate': [0.01, 0.05, 0.1]},
scoring='accuracy', verbose=2)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomizedSearchCV
?Documentation for RandomizedSearchCViFitted
Parameters
<tr class="user-set">
<td><i class="copy-paste-icon"></i></td>
<td class="param">estimator </td>
<td class="value"><catboost.cor...002E16600D400></td>
</tr>
<tr class="user-set">
<td><i class="copy-paste-icon"></i></td>
<td class="param">param_distributions </td>
<td class="value">{'border_count': [32, 64, ...], 'depth': [4, 6, ...], 'iterations': [300, 500, ...], 'l2_leaf_reg': [1, 3, ...], ...}</td>
</tr>
<tr class="user-set">
<td><i class="copy-paste-icon"></i></td>
<td class="param">n_iter </td>
<td class="value">45</td>
</tr>
<tr class="user-set">
<td><i class="copy-paste-icon"></i></td>
<td class="param">scoring </td>
<td class="value">'accuracy'</td>
</tr>
<tr class="user-set">
<td><i class="copy-paste-icon"></i></td>
<td class="param">n_jobs </td>
<td class="value">-1</td>
</tr>
<tr class="default">
<td><i class="copy-paste-icon"></i></td>
<td class="param">refit </td>
<td class="value">True</td>
</tr>
<tr class="user-set">
<td><i class="copy-paste-icon"></i></td>
<td class="param">cv </td>
<td class="value">10</td>
</tr>
<tr class="user-set">
<td><i class="copy-paste-icon"></i></td>
<td class="param">verbose </td>
<td class="value">2</td>
</tr>
<tr class="default">
<td><i class="copy-paste-icon"></i></td>
<td class="param">pre_dispatch </td>
<td class="value">'2*n_jobs'</td>
</tr>
<tr class="default">
<td><i class="copy-paste-icon"></i></td>
<td class="param">random_state </td>
<td class="value">None</td>
</tr>
<tr class="default">
<td><i class="copy-paste-icon"></i></td>
<td class="param">error_score </td>
<td class="value">nan</td>
</tr>
<tr class="default">
<td><i class="copy-paste-icon"></i></td>
<td class="param">return_train_score </td>
<td class="value">False</td>
</tr>
</tbody>
</table>
</details>
</div>
</div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label fitted sk-toggleable"><input type="checkbox" id="sk-estimator-id-2" class="sk-toggleable__control sk-hidden--visually"><label class="sk-toggleable__label fitted sk-toggleable__label-arrow" for="sk-estimator-id-2"><div><div>best_estimator_: CatBoostClassifier</div></div></label><div data-param-prefix="best_estimator___" class="sk-toggleable__content fitted"><pre><catboost.core.CatBoostClassifier object at 0x000002E169667020></pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator fitted sk-toggleable"><input type="checkbox" id="sk-estimator-id-3" class="sk-toggleable__control sk-hidden--visually"><label class="sk-toggleable__label fitted sk-toggleable__label-arrow" for="sk-estimator-id-3"><div><div>CatBoostClassifier</div></div></label><div data-param-prefix="best_estimator___" class="sk-toggleable__content fitted"><pre><catboost.core.CatBoostClassifier object at 0x000002E169667020></pre></div></div></div></div></div></div></div></div></div></div>
# Show the best hyperparameters found by the search
random_search.best_params_
{'learning_rate': 0.05,
'l2_leaf_reg': 3,
'iterations': 300,
'depth': 4,
'border_count': 32}
# Show the best mean cross-validated accuracy score
random_search.best_score_
np.float64(0.82981220657277)
# Save best_params to a .txt file so they are not lost.
# Use "w" (overwrite) instead of "a": appending a JSON object to an existing
# file on every run concatenates objects and the file stops being valid JSON.
with open("best_params.txt", "w", encoding="utf-8") as f:
    json.dump(random_search.best_params_, f,
              indent=4, ensure_ascii=False)
Отступление
Дважды при малом early_stopping_rounds, равном 30, при n_iter, равном 15, и cv, равном 3,
модель показывала лучший accuracy, но при валидации на Kaggle показывала результаты хуже.
Увеличил early_stopping_rounds, n_iter и cv — и тогда
получилось улучшить итоговый результат.
# Retrieve the best model found by the search
# NOTE(review): with refit=True (the default), best_estimator_ is already
# fitted on the full training data — confirm the extra fit below is intended.
best_model = random_search.best_estimator_
# Re-fit the model with the best parameters (applies eval_set / early stopping)
best_model.fit(X_train_split, y_train_split, **fit_params)
0: learn: 0.7988748 test: 0.7752809 best: 0.7752809 (0) total: 21.9ms remaining: 6.54s
1: learn: 0.8016878 test: 0.7808989 best: 0.7808989 (1) total: 48.9ms remaining: 7.29s
2: learn: 0.8101266 test: 0.7921348 best: 0.7921348 (2) total: 74.5ms remaining: 7.37s
3: learn: 0.8045007 test: 0.7865169 best: 0.7921348 (2) total: 105ms remaining: 7.76s
4: learn: 0.8030942 test: 0.7865169 best: 0.7921348 (2) total: 133ms remaining: 7.85s
5: learn: 0.8087201 test: 0.7977528 best: 0.7977528 (5) total: 159ms remaining: 7.77s
6: learn: 0.8087201 test: 0.7977528 best: 0.7977528 (5) total: 184ms remaining: 7.7s
7: learn: 0.8101266 test: 0.8033708 best: 0.8033708 (7) total: 209ms remaining: 7.63s
8: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 234ms remaining: 7.57s
9: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 258ms remaining: 7.49s
10: learn: 0.8101266 test: 0.7977528 best: 0.8033708 (7) total: 286ms remaining: 7.52s
11: learn: 0.8115331 test: 0.7977528 best: 0.8033708 (7) total: 314ms remaining: 7.54s
12: learn: 0.8171589 test: 0.7977528 best: 0.8033708 (7) total: 341ms remaining: 7.54s
13: learn: 0.8185654 test: 0.7977528 best: 0.8033708 (7) total: 369ms remaining: 7.53s
14: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 392ms remaining: 7.45s
15: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 416ms remaining: 7.38s
16: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 438ms remaining: 7.29s
17: learn: 0.8185654 test: 0.8033708 best: 0.8033708 (7) total: 456ms remaining: 7.14s
18: learn: 0.8199719 test: 0.8033708 best: 0.8033708 (7) total: 479ms remaining: 7.08s
19: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 501ms remaining: 7.02s
20: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 524ms remaining: 6.96s
21: learn: 0.8227848 test: 0.8033708 best: 0.8033708 (7) total: 546ms remaining: 6.9s
22: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 576ms remaining: 6.93s
23: learn: 0.8270042 test: 0.8033708 best: 0.8033708 (7) total: 584ms remaining: 6.72s
24: learn: 0.8270042 test: 0.8033708 best: 0.8033708 (7) total: 607ms remaining: 6.68s
25: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 624ms remaining: 6.57s
26: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 646ms remaining: 6.53s
27: learn: 0.8241913 test: 0.7977528 best: 0.8033708 (7) total: 670ms remaining: 6.51s
28: learn: 0.8255977 test: 0.7977528 best: 0.8033708 (7) total: 692ms remaining: 6.47s
29: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 716ms remaining: 6.44s
30: learn: 0.8255977 test: 0.8033708 best: 0.8033708 (7) total: 739ms remaining: 6.41s
31: learn: 0.8255977 test: 0.8089888 best: 0.8089888 (31) total: 760ms remaining: 6.37s
32: learn: 0.8284107 test: 0.8089888 best: 0.8089888 (31) total: 784ms remaining: 6.34s
33: learn: 0.8298172 test: 0.8146067 best: 0.8146067 (33) total: 807ms remaining: 6.31s
34: learn: 0.8340366 test: 0.8146067 best: 0.8146067 (33) total: 820ms remaining: 6.21s
35: learn: 0.8354430 test: 0.8146067 best: 0.8146067 (33) total: 845ms remaining: 6.2s
36: learn: 0.8354430 test: 0.8146067 best: 0.8146067 (33) total: 868ms remaining: 6.17s
37: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 886ms remaining: 6.11s
38: learn: 0.8382560 test: 0.8089888 best: 0.8146067 (33) total: 908ms remaining: 6.08s
39: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 936ms remaining: 6.08s
40: learn: 0.8368495 test: 0.8089888 best: 0.8146067 (33) total: 959ms remaining: 6.05s
41: learn: 0.8382560 test: 0.8089888 best: 0.8146067 (33) total: 979ms remaining: 6.01s
42: learn: 0.8382560 test: 0.8202247 best: 0.8202247 (42) total: 1s remaining: 5.98s
43: learn: 0.8410689 test: 0.8146067 best: 0.8202247 (42) total: 1.02s remaining: 5.94s
44: learn: 0.8396624 test: 0.8146067 best: 0.8202247 (42) total: 1.04s remaining: 5.92s
45: learn: 0.8438819 test: 0.8258427 best: 0.8258427 (45) total: 1.07s remaining: 5.89s
46: learn: 0.8466948 test: 0.8258427 best: 0.8258427 (45) total: 1.09s remaining: 5.85s
47: learn: 0.8466948 test: 0.8258427 best: 0.8258427 (45) total: 1.11s remaining: 5.82s
48: learn: 0.8481013 test: 0.8258427 best: 0.8258427 (45) total: 1.13s remaining: 5.78s
49: learn: 0.8452883 test: 0.8314607 best: 0.8314607 (49) total: 1.15s remaining: 5.75s
50: learn: 0.8438819 test: 0.8314607 best: 0.8314607 (49) total: 1.17s remaining: 5.71s
51: learn: 0.8438819 test: 0.8314607 best: 0.8314607 (49) total: 1.19s remaining: 5.67s
52: learn: 0.8452883 test: 0.8370787 best: 0.8370787 (52) total: 1.21s remaining: 5.63s
53: learn: 0.8424754 test: 0.8370787 best: 0.8370787 (52) total: 1.23s remaining: 5.59s
54: learn: 0.8396624 test: 0.8370787 best: 0.8370787 (52) total: 1.25s remaining: 5.56s
55: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.27s remaining: 5.53s
56: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.29s remaining: 5.5s
57: learn: 0.8382560 test: 0.8314607 best: 0.8370787 (52) total: 1.31s remaining: 5.46s
58: learn: 0.8382560 test: 0.8314607 best: 0.8370787 (52) total: 1.32s remaining: 5.41s
59: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.34s remaining: 5.38s
60: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.36s remaining: 5.35s
61: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.38s remaining: 5.31s
62: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.4s remaining: 5.27s
63: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.42s remaining: 5.24s
64: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.44s remaining: 5.21s
65: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.46s remaining: 5.18s
66: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.48s remaining: 5.15s
67: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.5s remaining: 5.13s
68: learn: 0.8410689 test: 0.8314607 best: 0.8370787 (52) total: 1.51s remaining: 5.07s
69: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.53s remaining: 5.05s
70: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.56s remaining: 5.03s
71: learn: 0.8424754 test: 0.8314607 best: 0.8370787 (52) total: 1.58s remaining: 5s
72: learn: 0.8438819 test: 0.8314607 best: 0.8370787 (52) total: 1.6s remaining: 4.98s
73: learn: 0.8396624 test: 0.8314607 best: 0.8370787 (52) total: 1.61s remaining: 4.93s
74: learn: 0.8382560 test: 0.8258427 best: 0.8370787 (52) total: 1.64s remaining: 4.91s
75: learn: 0.8410689 test: 0.8258427 best: 0.8370787 (52) total: 1.66s remaining: 4.88s
76: learn: 0.8424754 test: 0.8202247 best: 0.8370787 (52) total: 1.68s remaining: 4.86s
77: learn: 0.8424754 test: 0.8202247 best: 0.8370787 (52) total: 1.7s remaining: 4.83s
78: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.72s remaining: 4.81s
79: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.74s remaining: 4.78s
80: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.76s remaining: 4.76s
81: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.78s remaining: 4.73s
82: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.8s remaining: 4.71s
83: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.82s remaining: 4.68s
84: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.84s remaining: 4.67s
85: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.87s remaining: 4.65s
86: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 1.89s remaining: 4.64s
87: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 1.92s remaining: 4.62s
88: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 1.94s remaining: 4.59s
89: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.96s remaining: 4.58s
90: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 1.98s remaining: 4.55s
91: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 2s remaining: 4.53s
92: learn: 0.8438819 test: 0.8146067 best: 0.8370787 (52) total: 2.02s remaining: 4.51s
93: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.05s remaining: 4.49s
94: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.05s remaining: 4.43s
95: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.08s remaining: 4.41s
96: learn: 0.8438819 test: 0.8202247 best: 0.8370787 (52) total: 2.1s remaining: 4.39s
97: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.12s remaining: 4.36s
98: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.14s remaining: 4.34s
99: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.16s remaining: 4.32s
100: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.19s remaining: 4.31s
101: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.21s remaining: 4.29s
102: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.23s remaining: 4.27s
103: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.25s remaining: 4.25s
104: learn: 0.8438819 test: 0.8258427 best: 0.8370787 (52) total: 2.28s remaining: 4.23s
105: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.3s remaining: 4.21s
106: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.32s remaining: 4.19s
107: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.35s remaining: 4.17s
108: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.37s remaining: 4.15s
109: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.39s remaining: 4.13s
110: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.41s remaining: 4.11s
111: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.44s remaining: 4.09s
112: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.46s remaining: 4.07s
113: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.48s remaining: 4.05s
114: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.5s remaining: 4.02s
115: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.52s remaining: 4s
116: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.54s remaining: 3.97s
117: learn: 0.8452883 test: 0.8202247 best: 0.8370787 (52) total: 2.56s remaining: 3.95s
118: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.58s remaining: 3.92s
119: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.6s remaining: 3.9s
120: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.62s remaining: 3.87s
121: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.64s remaining: 3.85s
122: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.66s remaining: 3.83s
123: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.68s remaining: 3.8s
124: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.7s remaining: 3.78s
125: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.72s remaining: 3.75s
126: learn: 0.8452883 test: 0.8146067 best: 0.8370787 (52) total: 2.74s remaining: 3.73s
127: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.76s remaining: 3.71s
128: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.78s remaining: 3.68s
129: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.79s remaining: 3.65s
130: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.81s remaining: 3.62s
131: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.82s remaining: 3.59s
132: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.84s remaining: 3.57s
133: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.85s remaining: 3.53s
134: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.87s remaining: 3.51s
135: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.89s remaining: 3.49s
136: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.91s remaining: 3.47s
137: learn: 0.8466948 test: 0.8146067 best: 0.8370787 (52) total: 2.94s remaining: 3.44s
138: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.96s remaining: 3.42s
139: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 2.97s remaining: 3.4s
140: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 3s remaining: 3.38s
141: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 3.02s remaining: 3.35s
142: learn: 0.8481013 test: 0.8146067 best: 0.8370787 (52) total: 3.03s remaining: 3.33s
143: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.05s remaining: 3.31s
144: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.07s remaining: 3.29s
145: learn: 0.8495077 test: 0.8146067 best: 0.8370787 (52) total: 3.1s remaining: 3.27s
146: learn: 0.8509142 test: 0.8146067 best: 0.8370787 (52) total: 3.12s remaining: 3.24s
147: learn: 0.8509142 test: 0.8146067 best: 0.8370787 (52) total: 3.14s remaining: 3.22s
148: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.16s remaining: 3.2s
149: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.18s remaining: 3.18s
150: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.2s remaining: 3.15s
151: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.22s remaining: 3.13s
152: learn: 0.8523207 test: 0.8146067 best: 0.8370787 (52) total: 3.24s remaining: 3.11s
Stopped by overfitting detector (100 iterations wait)
bestTest = 0.8370786517
bestIteration = 52
Shrink model to first 53 iterations.
# Evaluate quality of the tuned model on the validation set.
# The original cell reused a stale `y_pred` (predictions of the earlier,
# untuned model, printed 0.8315), while best_model actually scores higher —
# recompute predictions from best_model so the reported metric matches it.
y_pred = best_model.predict(X_valid)
acc = accuracy_score(y_valid, y_pred)
print(f"Validation Accuracy: {acc:.4f}")
Validation Accuracy: 0.8315
# Predict on the test set with the tuned model
best_test_preds = best_model.predict(X_test)

# Assemble the second submission file: passenger ids paired with
# integer survival predictions, then write it without the index column.
submission_V2 = pd.DataFrame(
    {
        'PassengerId': passenger_ids,
        'Survived': best_test_preds.astype(int),
    }
)
submission_V2.to_csv('submission_V2.csv', index=False)

print("✅ Submission файл сохранён как submission_V2.csv")
✅ Submission файл сохранён как submission_V2.csv
# Check the validation accuracy of the tuned model one more time
accuracy_score(y_valid, best_model.predict(X_valid))
0.8370786516853933


Смогли улучшить качество модели с помощью подбора гиперпараметров и отвоевать больше 500 мест в итоговом рейтинге