Search
Write a publication
Pull to refresh

Titanic + CatBoost (Первое решение, первый Jupyter Notebook)

Level of difficultyEasy
Reading time32 min
Views583
#Импортируем все необходимые библиотеки

import pandas as pd
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
# 🔕 Suppress warnings so they don't clutter the notebook output
# (a blanket 'ignore' also hides deprecation notices — acceptable for a demo notebook)


import warnings
warnings.filterwarnings('ignore')

### Default plotting style for the whole notebook: grid, colors,
### fonts and figure size. Collected in one dict so the per-parameter
### assignments don't have to be repeated for every figure.

import matplotlib as mlp

# Apply all style overrides with a single update call.
mlp.rcParams.update({
    # Grid
    'axes.grid': True,
    'grid.color': '#D3D3D3',
    'grid.linestyle': '--',
    'grid.linewidth': 1,

    # Background colors
    'axes.facecolor': '#F9F9F9',    # light-grey plot area
    'figure.facecolor': '#FFFFFF',  # white canvas

    # Legend
    'legend.fontsize': 14,
    'legend.frameon': True,
    'legend.framealpha': 0.9,
    'legend.edgecolor': '#333333',

    # Default figure size
    'figure.figsize': (10, 6),

    # Fonts (swap for 'Arial', 'Roboto', etc. if preferred)
    'font.family': 'DejaVu Sans',
    'font.size': 16,

    # Axes spines
    'axes.edgecolor': '#333333',
    'axes.linewidth': 2,

    # Main text color
    'text.color': '#222222',
})

# Load the training split on its own...

train_df = pd.read_csv('../data/train.csv')

# ... and the test split separately

test_df = pd.read_csv('../data/test.csv')

# Peek at the first 10 rows of the training set

train_df.head(10)

PassengerId

Survived

Pclass

Name

Sex

Age

SibSp

Parch

Ticket

Fare

Cabin

Embarked

0

1

0

3

Braund, Mr. Owen Harris

male

22.0

1

0

A/5 21171

7.2500

NaN

S

1

2

1

1

Cumings, Mrs. John Bradley (Florence Briggs Th...

female

38.0

1

0

PC 17599

71.2833

C85

C

2

3

1

3

Heikkinen, Miss. Laina

female

26.0

0

0

STON/O2. 3101282

7.9250

NaN

S

3

4

1

1

Futrelle, Mrs. Jacques Heath (Lily May Peel)

female

35.0

1

0

113803

53.1000

C123

S

4

5

0

3

Allen, Mr. William Henry

male

35.0

0

0

373450

8.0500

NaN

S

5

6

0

3

Moran, Mr. James

male

NaN

0

0

330877

8.4583

NaN

Q

6

7

0

1

McCarthy, Mr. Timothy J

male

54.0

0

0

17463

51.8625

E46

S

7

8

0

3

Palsson, Master. Gosta Leonard

male

2.0

3

1

349909

21.0750

NaN

S

8

9

1

3

Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)

female

27.0

0

2

347742

11.1333

NaN

S

9

10

1

2

Nasser, Mrs. Nicholas (Adele Achem)

female

14.0

1

0

237736

30.0708

NaN

C

# Column dtypes and non-null counts for the training set

train_df.info()

RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB

Видно, что есть пропуски в признаке возраст. Очень много пропусков в признаке кабина

# Basic statistics for the numeric features (mean, quartiles, std, etc.)

train_df.describe()

PassengerId

Survived

Pclass

Age

SibSp

Parch

Fare

count

891.000000

891.000000

891.000000

714.000000

891.000000

891.000000

891.000000

mean

446.000000

0.383838

2.308642

29.699118

0.523008

0.381594

32.204208

std

257.353842

0.486592

0.836071

14.526497

1.102743

0.806057

49.693429

min

1.000000

0.000000

1.000000

0.420000

0.000000

0.000000

0.000000

25%

223.500000

0.000000

2.000000

20.125000

0.000000

0.000000

7.910400

50%

446.000000

0.000000

3.000000

28.000000

0.000000

0.000000

14.454200

75%

668.500000

1.000000

3.000000

38.000000

1.000000

0.000000

31.000000

max

891.000000

1.000000

3.000000

80.000000

8.000000

6.000000

512.329200

# And statistics for the object (string) columns: counts, unique values, mode

train_df.describe(include='object')

Name

Sex

Ticket

Cabin

Embarked

count

891

891

891

204

889

unique

891

2

681

147

3

top

Braund, Mr. Owen Harris

male

347082

G6

S

freq

1

577

7

4

644

# Target distribution: counts of survived vs. not survived

sns.countplot(x='Survived', data=train_df)
plt.title('Распределение выживших')
plt.show()

Дисбаланса классов не наблюдаем

# Target distribution broken down by each categorical feature.
# Plot titles are user-facing and stay exactly as in the original (Russian).
for column, title in (
    ('Sex', 'Пол и выживание'),
    ('Pclass', 'Класс каюты и выживание'),
    ('Embarked', 'Порт посадки и выживание'),
):
    plt.figure(figsize=(5,4))
    sns.countplot(x=column, hue='Survived', data=train_df)
    plt.title(title)
    plt.show()

png
png
png
png

Видно, что все признаки являются важными для таргета. В противном случае графики для разных признаков были бы одинаковыми.

# Box plots of the two continuous features (Age, Fare) against the target.
for column, title in (
    ('Age', 'Возраст и выживание (boxplot)'),
    ('Fare', 'Стоимость билета и выживание (boxplot)'),
):
    plt.figure(figsize=(6,5))
    sns.boxplot(x='Survived', y=column, data=train_df)
    plt.title(title)
    plt.show()

png
png
png
png

Тоже видны различия в зависимости от таргета

# Feature lists: categorical vs. numeric columns

categorical_cols = ['Sex', 'Pclass', 'Embarked', 'Cabin']
numeric_cols = ['Age', 'Fare', 'SibSp', 'Parch']
# Correlation heatmap over the numeric columns
# (numeric_only=True skips the object columns)

plt.figure(figsize=(10,8))
sns.heatmap(train_df.corr(numeric_only=True), annot=True, cmap='coolwarm', fmt=".2f")
plt.title('Корреляция числовых признаков')
plt.show()

png
png

Визуально видно, что мультиколлинеарных признаков нет, но проверим с помощью функции

### Helper pair (adapted from a Stack Overflow answer) that lists the
### strongest absolute pairwise correlations without duplicates.

def get_redundant_pairs(df):
    """Return the set of (col_i, col_j) label pairs with j <= i.

    These correspond to the diagonal and lower triangle of the
    correlation matrix — mirror images of the upper triangle — and
    are dropped so each unordered pair is reported exactly once.
    """
    cols = df.columns
    return {
        (cols[i], cols[j])
        for i in range(df.shape[1])
        for j in range(i + 1)
    }

def get_top_abs_correlations(df, n=5):
    """Return the n largest absolute pairwise correlations in df."""
    pairs = df.corr().abs().unstack()
    # Keep only one copy of each unordered pair (the upper triangle).
    pairs = pairs.drop(labels=get_redundant_pairs(df))
    return pairs.sort_values(ascending=False).head(n)

# Show the strongest absolute correlations among the numeric features
print("Top Absolute Correlations")
print(get_top_abs_correlations(train_df[numeric_cols], 10))
Top Absolute Correlations
SibSp  Parch    0.414838
Age    SibSp    0.308247
Fare   Parch    0.216225
Age    Parch    0.189119
Fare   SibSp    0.159651
Age    Fare     0.096067
dtype: float64

Мультиколлинеарности нет - подтверждаем

# Count missing values per column

train_df.isnull().sum()

PassengerId      0
Survived         0
Pclass           0
Name             0
Sex              0
Age            177
SibSp            0
Parch            0
Ticket           0
Fare             0
Cabin          687
Embarked         2
dtype: int64
# Drop the rows with a missing "Embarked" value — there are only two.
# This is safe to do now, BEFORE train and test are concatenated,
# because it only ever removes training rows.

train_df = train_df.dropna(subset=['Embarked']).copy()

# Re-check the missing-value counts

train_df.isnull().sum()

PassengerId      0
Survived         0
Pclass           0
Name             0
Sex              0
Age            177
SibSp            0
Parch            0
Ticket           0
Fare             0
Cabin          687
Embarked         0
dtype: int64

После удаления можно объединять строки. Удаляли до объединения, чтобы не повредить тестовые данные.

Теперь можно готовить данные к объединению и обработке общего датафрейма.

Обязательно нужно всё сделать правильно, чтобы не испортить данные

# Add a flag column so the combined frame can be split back correctly later

train_df['is_train'] = 1
test_df['is_train'] = 0

# Add a dummy `Survived` column to the test set so both frames share one schema

test_df['Survived'] = np.nan

# Keep the test PassengerId values — required for the submission file

passenger_ids = test_df['PassengerId'].copy()

(эта колонка не нужна для обучения модели, но обязательна в итоговом файле решения)

# Drop PassengerId before concatenating — it carries no signal for the model

train_df = train_df.drop(columns=['PassengerId'])
test_df = test_df.drop(columns=['PassengerId'])
# Concatenate train and test so all preprocessing is applied identically to both

full_df = pd.concat([train_df, test_df], axis=0).reset_index(drop=True)

# Impute missing Age values with the median age computed on the
# TRAINING rows only (avoids leaking test statistics).
# The median is less sensitive to outliers than the mean.

full_df['Age'] = full_df['Age'].fillna(train_df['Age'].median())

# Check missing-value counts again, now on the combined frame

full_df.isnull().sum()

Survived     418
Pclass         0
Name           0
Sex            0
Age            0
SibSp          0
Parch          0
Ticket         0
Fare           1
Cabin       1014
Embarked       0
is_train       0
dtype: int64
# Look at the data again to decide how to handle the Fare feature

full_df.head(20)

Survived

Pclass

Name

Sex

Age

SibSp

Parch

Ticket

Fare

Cabin

Embarked

is_train

0

0.0

3

Braund, Mr. Owen Harris

male

22.0

1

0

A/5 21171

7.2500

NaN

S

1

1

1.0

1

Cumings, Mrs. John Bradley (Florence Briggs Th...

female

38.0

1

0

PC 17599

71.2833

C85

C

1

2

1.0

3

Heikkinen, Miss. Laina

female

26.0

0

0

STON/O2. 3101282

7.9250

NaN

S

1

3

1.0

1

Futrelle, Mrs. Jacques Heath (Lily May Peel)

female

35.0

1

0

113803

53.1000

C123

S

1

4

0.0

3

Allen, Mr. William Henry

male

35.0

0

0

373450

8.0500

NaN

S

1

5

0.0

3

Moran, Mr. James

male

28.0

0

0

330877

8.4583

NaN

Q

1

6

0.0

1

McCarthy, Mr. Timothy J

male

54.0

0

0

17463

51.8625

E46

S

1

7

0.0

3

Palsson, Master. Gosta Leonard

male

2.0

3

1

349909

21.0750

NaN

S

1

8

1.0

3

Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)

female

27.0

0

2

347742

11.1333

NaN

S

1

9

1.0

2

Nasser, Mrs. Nicholas (Adele Achem)

female

14.0

1

0

237736

30.0708

NaN

C

1

10

1.0

3

Sandstrom, Miss. Marguerite Rut

female

4.0

1

1

PP 9549

16.7000

G6

S

1

11

1.0

1

Bonnell, Miss. Elizabeth

female

58.0

0

0

113783

26.5500

C103

S

1

12

0.0

3

Saundercock, Mr. William Henry

male

20.0

0

0

A/5. 2151

8.0500

NaN

S

1

13

0.0

3

Andersson, Mr. Anders Johan

male

39.0

1

5

347082

31.2750

NaN

S

1

14

0.0

3

Vestrom, Miss. Hulda Amanda Adolfina

female

14.0

0

0

350406

7.8542

NaN

S

1

15

1.0

2

Hewlett, Mrs. (Mary D Kingcome)

female

55.0

0

0

248706

16.0000

NaN

S

1

16

0.0

3

Rice, Master. Eugene

male

2.0

4

1

382652

29.1250

NaN

Q

1

17

1.0

2

Williams, Mr. Charles Eugene

male

28.0

0

0

244373

13.0000

NaN

S

1

18

0.0

3

Vander Planke, Mrs. Julius (Emelia Maria Vande...

female

31.0

1

0

345763

18.0000

NaN

S

1

19

1.0

3

Masselmani, Mrs. Fatima

female

28.0

0

0

2649

7.2250

NaN

C

1

# Impute the single missing Fare with the training-set median.
# Dropping rows is not an option after the merge — it could remove a test row.

full_df['Fare'] = full_df['Fare'].fillna(train_df['Fare'].median())

# Verify

full_df.isnull().sum()

Survived     418
Pclass         0
Name           0
Sex            0
Age            0
SibSp          0
Parch          0
Ticket         0
Fare           0
Cabin       1014
Embarked       0
is_train       0
dtype: int64

Три четверти данных по колонке Cabin в тестовых данных являются NaN. Скорее всего, это пассажиры второго или третьего класса, у которых просто не было собственной кабины. Совсем избавляться от этой колонки, наверное, не стоит — и не обязательно. Вместо этого мы закодируем её как наличие или отсутствие палубы.

# Replace the mostly-missing Cabin column with a binary flag:
# 1 when a cabin was recorded for the passenger, 0 otherwise.

full_df['Has_Cabin'] = full_df['Cabin'].notna().astype(int)

# The raw Cabin column is no longer needed — drop it

full_df = full_df.drop(columns='Cabin')

# Inspect the modified frame

full_df

Survived

Pclass

Name

Sex

Age

SibSp

Parch

Ticket

Fare

Embarked

is_train

Has_Cabin

0

0.0

3

Braund, Mr. Owen Harris

male

22.0

1

0

A/5 21171

7.2500

S

1

0

1

1.0

1

Cumings, Mrs. John Bradley (Florence Briggs Th...

female

38.0

1

0

PC 17599

71.2833

C

1

1

2

1.0

3

Heikkinen, Miss. Laina

female

26.0

0

0

STON/O2. 3101282

7.9250

S

1

0

3

1.0

1

Futrelle, Mrs. Jacques Heath (Lily May Peel)

female

35.0

1

0

113803

53.1000

S

1

1

4

0.0

3

Allen, Mr. William Henry

male

35.0

0

0

373450

8.0500

S

1

0

...

...

...

...

...

...

...

...

...

...

...

...

...

1302

NaN

3

Spector, Mr. Woolf

male

28.0

0

0

A.5. 3236

8.0500

S

0

0

1303

NaN

1

Oliva y Ocana, Dona. Fermina

female

39.0

0

0

PC 17758

108.9000

C

0

1

1304

NaN

3

Saether, Mr. Simon Sivertsen

male

38.5

0

0

SOTON/O.Q. 3101262

7.2500

S

0

0

1305

NaN

3

Ware, Mr. Frederick

male

28.0

0

0

359309

8.0500

S

0

0

1306

NaN

3

Peter, Master. Michael J

male

28.0

1

1

2668

22.3583

C

0

0

1307 rows × 12 columns

Начинаем избавляться от ненужных, не несущих полезной информации для обучения модели, признаков

# Name and ticket number carry no straightforward signal for the model — drop them

full_df = full_df.drop(columns=['Name','Ticket'])

# Check the result

full_df

Survived

Pclass

Sex

Age

SibSp

Parch

Fare

Embarked

is_train

Has_Cabin

0

0.0

3

male

22.0

1

0

7.2500

S

1

0

1

1.0

1

female

38.0

1

0

71.2833

C

1

1

2

1.0

3

female

26.0

0

0

7.9250

S

1

0

3

1.0

1

female

35.0

1

0

53.1000

S

1

1

4

0.0

3

male

35.0

0

0

8.0500

S

1

0

...

...

...

...

...

...

...

...

...

...

...

1302

NaN

3

male

28.0

0

0

8.0500

S

0

0

1303

NaN

1

female

39.0

0

0

108.9000

C

0

1

1304

NaN

3

male

38.5

0

0

7.2500

S

0

0

1305

NaN

3

male

28.0

0

0

8.0500

S

0

0

1306

NaN

3

male

28.0

1

1

22.3583

C

0

0

1307 rows × 10 columns

В целом не стесняемся часто смотреть и проверять результат. В процессе постоянного просмотра данных может прийти идея, которая улучшит качество модели.

Теперь закодируем колонки, которые являются объектами (object), как категории (category). Это особенно важно, если мы собираемся использовать модель CatBoost — она умеет напрямую работать с категориальными признаками и не требует их one-hot-кодирования.

CatBoost сам обработает эти признаки, если они будут иметь тип category, поэтому просто приведём нужные колонки к этому типу.

# Cast Sex and Embarked to pandas 'category' dtype so CatBoost can
# consume them natively, without one-hot encoding.

for column in ('Sex', 'Embarked'):
    full_df[column] = full_df[column].astype('category')

# Verify dtypes via a full summary

full_df.describe(include='all')

Survived

Pclass

Sex

Age

SibSp

Parch

Fare

Embarked

is_train

Has_Cabin

count

889.000000

1307.000000

1307

1307.000000

1307.000000

1307.000000

1307.000000

1307

1307.000000

1307.000000

unique

NaN

NaN

2

NaN

NaN

NaN

NaN

3

NaN

NaN

top

NaN

NaN

male

NaN

NaN

NaN

NaN

S

NaN

NaN

freq

NaN

NaN

843

NaN

NaN

NaN

NaN

914

NaN

NaN

mean

0.382452

2.296863

NaN

29.471821

0.499617

0.385616

33.209595

NaN

0.680184

0.224178

std

0.486260

0.836942

NaN

12.881592

1.042273

0.866092

51.748768

NaN

0.466584

0.417199

min

0.000000

1.000000

NaN

0.170000

0.000000

0.000000

0.000000

NaN

0.000000

0.000000

25%

0.000000

2.000000

NaN

22.000000

0.000000

0.000000

7.895800

NaN

0.000000

0.000000

50%

0.000000

3.000000

NaN

28.000000

0.000000

0.000000

14.454200

NaN

1.000000

0.000000

75%

1.000000

3.000000

NaN

35.000000

1.000000

0.000000

31.275000

NaN

1.000000

0.000000

max

1.000000

3.000000

NaN

80.000000

8.000000

9.000000

512.329200

NaN

1.000000

1.000000

Кроме того, колонка Pclass изначально имеет тип int, но на самом деле это категориальный признак (класс обслуживания: 1, 2 или 3). Если оставить её как числовую, модель может ошибочно посчитать, что класс 3 «больше» и важнее, чем класс 2, а тот — важнее, чем класс 1. Чтобы избежать этого, мы также приведём Pclass к категориальному типу.

# Pclass looks ordinal but is really categorical (service class 1/2/3),
# so cast it to 'category' too — avoids a spurious numeric ordering.

full_df['Pclass'] = full_df['Pclass'].astype('category')

Обработка данных завершена и теперь разделяем данные обратно:

# Split the combined frame back into train/test using the flag column.

train_mask = full_df['is_train'] == 1

X_train = full_df.loc[train_mask].drop(columns=['is_train', 'Survived'])
y_train = full_df.loc[train_mask, 'Survived']

X_test = full_df.loc[~train_mask].drop(columns=['is_train', 'Survived'])

# Sanity-check the shapes

print(X_train.shape, y_train.shape, X_test.shape)

(889, 8) (889,) (418, 8)

Начинаем обучение модели

# Training setup
# First, hold out a stratified validation split for early stopping / evaluation

X_train_split, X_valid, y_train_split, y_valid = train_test_split(
    X_train, y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train
)

# Collect the 'category'-typed columns for CatBoost's cat_features argument

cat_features = X_train.select_dtypes(include='category').columns.tolist()

# Verify

cat_features

['Pclass', 'Sex', 'Embarked']
# Train a baseline model with middle-of-the-road parameters
# (no hyperparameter search yet)

model = CatBoostClassifier(
    iterations=1000,
    learning_rate=0.05,
    depth=6,
    eval_metric='Accuracy',
    random_seed=42,
    early_stopping_rounds=50,
    verbose=100
)

model.fit(
    X_train_split, y_train_split,
    eval_set=(X_valid, y_valid),
    cat_features=cat_features
)

0:	learn: 0.8227848	test: 0.7977528	best: 0.7977528 (0)	total: 160ms	remaining: 2m 39s
Stopped by overfitting detector  (50 iterations wait)

bestTest = 0.8314606742
bestIteration = 32

Shrink model to first 33 iterations.
# Evaluate on the held-out validation split
y_pred = model.predict(X_valid)
acc = accuracy_score(y_valid, y_pred)
print(f"Validation Accuracy: {acc:.4f}")

Validation Accuracy: 0.8315

Доля правильных ответов 83,15%

image.png
image.png
image.png
image.png
# Predict on the test set

test_preds = model.predict(X_test)


# Inspect the predictions

test_preds

array([0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0.,
       0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0.,
       0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0.,
       0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0.,
       0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0.,
       1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0.,
       0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1.,
       0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0., 1.,
       0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1.,
       0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
       1., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0.,
       0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
       1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.,
       0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1.,
       0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
       0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1.,
       0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
       1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
       1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
       1., 1., 1., 1., 1., 0., 1., 0., 0., 0.])
# Build submission.csv in the two-column format Kaggle expects

submission = pd.DataFrame({
    'PassengerId': passenger_ids,
    'Survived': test_preds.astype(int)
})

submission.to_csv('../submissions/submission.csv', index=False)
print("✅ Submission файл сохранён как submission.csv")

✅ Submission файл сохранён как submission.csv
# Inspect the submission frame

submission


PassengerId

Survived

0

892

0

1

893

0

2

894

0

3

895

0

4

896

0

...

...

...

413

1305

0

414

1306

1

415

1307

0

416

1308

0

417

1309

0

418 rows × 2 columns

image.png
image.png

Результат Топ 2'400

Теперь попробуем улучшить качество модели с помощью подбора гиперпараметров

Буду использовать случайный подбор параметров с помощью RandomizedSearchCV

# Импортируем модуль

from sklearn.model_selection import RandomizedSearchCV

#Сетка гиперпараметров

param_grid = {
    'depth': [4, 6, 8, 10],              # Максимальная глубина дерева (чем глубже — тем сложнее модель)
    'learning_rate': [0.01, 0.05, 0.1],  # Скорость обучения (маленькое значение = медленнее обучение, но может быть точнее) 
    'iterations': [300, 500, 1000],      # Количество деревьев (итераций бустинга)
    'l2_leaf_reg': [1, 3, 5, 7, 9],      # L2-регуляризация — предотвращает переобучение
    'border_count': [32, 64, 128]        # Количество бинов для дискретизации числовых признаков
}

#Randomized Search с кросс-валидацией

random_search = RandomizedSearchCV(
    estimator=model,
    param_distributions=param_grid,
    n_iter=45,                   # Сколько случайных комбинаций попробовать
    scoring='accuracy',          # Метрика качества, которую нужно максимизировать
    cv=10,                       # Количество фолдов (разбиений) для кросс-валидации
    verbose=2,                   # Показывать процесс обучения в терминале
    n_jobs=-1                    # Использовать все доступные ядра CPU для ускорения
)

# Создаём экземпляр модели

model = CatBoostClassifier(silent=True, random_state=42) # random state фиксированный 

# Фиксируем некоторые параметры для модели

fit_params = {
    "eval_set": [(X_valid, y_valid)], # Набор валидационных данных (для контроля переобучения и использования early stopping)
    "early_stopping_rounds": 100,     # Если метрика не улучшается в течение 100 итераций — обучение остановится
    "cat_features": cat_features,     # Указываем, какие признаки являются категориальными (CatBoost работает с ними нативно)
    "verbose": 1                      # Показывать прогресс обучения во время тренировки
}

# Launch the hyperparameter search.
# NOTE(review): fit_params forwards the same external eval_set to every CV
# fold, so early stopping evaluates on data outside each fold's own split —
# confirm this leakage is intended.

random_search.fit(X_train_split, y_train_split, **fit_params)

Fitting 10 folds for each of 45 candidates, totalling 450 fits
0:	learn: 0.7988748	test: 0.7752809	best: 0.7752809 (0)	total: 18.8ms	remaining: 5.63s
1:	learn: 0.8016878	test: 0.7808989	best: 0.7808989 (1)	total: 39.8ms	remaining: 5.93s
2:	learn: 0.8101266	test: 0.7921348	best: 0.7921348 (2)	total: 56.4ms	remaining: 5.58s
3:	learn: 0.8045007	test: 0.7865169	best: 0.7921348 (2)	total: 76.9ms	remaining: 5.69s
4:	learn: 0.8030942	test: 0.7865169	best: 0.7921348 (2)	total: 97.1ms	remaining: 5.73s
5:	learn: 0.8087201	test: 0.7977528	best: 0.7977528 (5)	total: 118ms	remaining: 5.78s
6:	learn: 0.8087201	test: 0.7977528	best: 0.7977528 (5)	total: 139ms	remaining: 5.82s
7:	learn: 0.8101266	test: 0.8033708	best: 0.8033708 (7)	total: 160ms	remaining: 5.85s
8:	learn: 0.8101266	test: 0.7977528	best: 0.8033708 (7)	total: 181ms	remaining: 5.86s
9:	learn: 0.8101266	test: 0.7977528	best: 0.8033708 (7)	total: 201ms	remaining: 5.82s
10:	learn: 0.8101266	test: 0.7977528	best: 0.8033708 (7)	total: 220ms	remaining: 5.77s
11:	learn: 0.8115331	test: 0.7977528	best: 0.8033708 (7)	total: 241ms	remaining: 5.78s
12:	learn: 0.8171589	test: 0.7977528	best: 0.8033708 (7)	total: 262ms	remaining: 5.78s
13:	learn: 0.8185654	test: 0.7977528	best: 0.8033708 (7)	total: 293ms	remaining: 5.98s
14:	learn: 0.8185654	test: 0.8033708	best: 0.8033708 (7)	total: 322ms	remaining: 6.12s
15:	learn: 0.8185654	test: 0.8033708	best: 0.8033708 (7)	total: 348ms	remaining: 6.17s
16:	learn: 0.8185654	test: 0.8033708	best: 0.8033708 (7)	total: 369ms	remaining: 6.14s
17:	learn: 0.8185654	test: 0.8033708	best: 0.8033708 (7)	total: 385ms	remaining: 6.03s
18:	learn: 0.8199719	test: 0.8033708	best: 0.8033708 (7)	total: 407ms	remaining: 6.03s
19:	learn: 0.8227848	test: 0.8033708	best: 0.8033708 (7)	total: 430ms	remaining: 6.02s
20:	learn: 0.8227848	test: 0.8033708	best: 0.8033708 (7)	total: 452ms	remaining: 6s
21:	learn: 0.8227848	test: 0.8033708	best: 0.8033708 (7)	total: 473ms	remaining: 5.98s
22:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 495ms	remaining: 5.96s
23:	learn: 0.8270042	test: 0.8033708	best: 0.8033708 (7)	total: 501ms	remaining: 5.76s
24:	learn: 0.8270042	test: 0.8033708	best: 0.8033708 (7)	total: 521ms	remaining: 5.73s
25:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 536ms	remaining: 5.65s
26:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 558ms	remaining: 5.64s
27:	learn: 0.8241913	test: 0.7977528	best: 0.8033708 (7)	total: 576ms	remaining: 5.6s
28:	learn: 0.8255977	test: 0.7977528	best: 0.8033708 (7)	total: 596ms	remaining: 5.57s
29:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 615ms	remaining: 5.54s
30:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 635ms	remaining: 5.51s
31:	learn: 0.8255977	test: 0.8089888	best: 0.8089888 (31)	total: 655ms	remaining: 5.48s
32:	learn: 0.8284107	test: 0.8089888	best: 0.8089888 (31)	total: 674ms	remaining: 5.45s
33:	learn: 0.8298172	test: 0.8146067	best: 0.8146067 (33)	total: 694ms	remaining: 5.43s
34:	learn: 0.8340366	test: 0.8146067	best: 0.8146067 (33)	total: 706ms	remaining: 5.35s
35:	learn: 0.8354430	test: 0.8146067	best: 0.8146067 (33)	total: 728ms	remaining: 5.34s
36:	learn: 0.8354430	test: 0.8146067	best: 0.8146067 (33)	total: 749ms	remaining: 5.32s
37:	learn: 0.8368495	test: 0.8089888	best: 0.8146067 (33)	total: 764ms	remaining: 5.27s
38:	learn: 0.8382560	test: 0.8089888	best: 0.8146067 (33)	total: 780ms	remaining: 5.22s
39:	learn: 0.8368495	test: 0.8089888	best: 0.8146067 (33)	total: 799ms	remaining: 5.19s
40:	learn: 0.8368495	test: 0.8089888	best: 0.8146067 (33)	total: 820ms	remaining: 5.18s
41:	learn: 0.8382560	test: 0.8089888	best: 0.8146067 (33)	total: 840ms	remaining: 5.16s
42:	learn: 0.8382560	test: 0.8202247	best: 0.8202247 (42)	total: 862ms	remaining: 5.15s
43:	learn: 0.8410689	test: 0.8146067	best: 0.8202247 (42)	total: 884ms	remaining: 5.14s
44:	learn: 0.8396624	test: 0.8146067	best: 0.8202247 (42)	total: 907ms	remaining: 5.14s
45:	learn: 0.8438819	test: 0.8258427	best: 0.8258427 (45)	total: 930ms	remaining: 5.14s
46:	learn: 0.8466948	test: 0.8258427	best: 0.8258427 (45)	total: 953ms	remaining: 5.13s
47:	learn: 0.8466948	test: 0.8258427	best: 0.8258427 (45)	total: 976ms	remaining: 5.12s
48:	learn: 0.8481013	test: 0.8258427	best: 0.8258427 (45)	total: 999ms	remaining: 5.12s
49:	learn: 0.8452883	test: 0.8314607	best: 0.8314607 (49)	total: 1.02s	remaining: 5.11s
50:	learn: 0.8438819	test: 0.8314607	best: 0.8314607 (49)	total: 1.04s	remaining: 5.09s
51:	learn: 0.8438819	test: 0.8314607	best: 0.8314607 (49)	total: 1.06s	remaining: 5.07s
52:	learn: 0.8452883	test: 0.8370787	best: 0.8370787 (52)	total: 1.08s	remaining: 5.05s
53:	learn: 0.8424754	test: 0.8370787	best: 0.8370787 (52)	total: 1.1s	remaining: 5.04s
54:	learn: 0.8396624	test: 0.8370787	best: 0.8370787 (52)	total: 1.13s	remaining: 5.01s
55:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.15s	remaining: 4.99s
56:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.17s	remaining: 4.97s
57:	learn: 0.8382560	test: 0.8314607	best: 0.8370787 (52)	total: 1.18s	remaining: 4.94s
58:	learn: 0.8382560	test: 0.8314607	best: 0.8370787 (52)	total: 1.2s	remaining: 4.89s
59:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.22s	remaining: 4.87s
60:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.24s	remaining: 4.85s
61:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.27s	remaining: 4.89s
62:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.29s	remaining: 4.87s
63:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.32s	remaining: 4.86s
64:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.34s	remaining: 4.85s
65:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.36s	remaining: 4.84s
66:	learn: 0.8410689	test: 0.8314607	best: 0.8370787 (52)	total: 1.39s	remaining: 4.82s
67:	learn: 0.8410689	test: 0.8314607	best: 0.8370787 (52)	total: 1.41s	remaining: 4.8s
68:	learn: 0.8410689	test: 0.8314607	best: 0.8370787 (52)	total: 1.42s	remaining: 4.75s
69:	learn: 0.8424754	test: 0.8314607	best: 0.8370787 (52)	total: 1.44s	remaining: 4.74s
70:	learn: 0.8424754	test: 0.8314607	best: 0.8370787 (52)	total: 1.46s	remaining: 4.72s
71:	learn: 0.8424754	test: 0.8314607	best: 0.8370787 (52)	total: 1.49s	remaining: 4.7s
72:	learn: 0.8438819	test: 0.8314607	best: 0.8370787 (52)	total: 1.51s	remaining: 4.69s
73:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.52s	remaining: 4.66s
74:	learn: 0.8382560	test: 0.8258427	best: 0.8370787 (52)	total: 1.55s	remaining: 4.64s
75:	learn: 0.8410689	test: 0.8258427	best: 0.8370787 (52)	total: 1.57s	remaining: 4.63s
76:	learn: 0.8424754	test: 0.8202247	best: 0.8370787 (52)	total: 1.59s	remaining: 4.6s
77:	learn: 0.8424754	test: 0.8202247	best: 0.8370787 (52)	total: 1.61s	remaining: 4.58s
78:	learn: 0.8438819	test: 0.8202247	best: 0.8370787 (52)	total: 1.63s	remaining: 4.56s
79:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.65s	remaining: 4.54s
80:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.67s	remaining: 4.52s
81:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.69s	remaining: 4.49s
82:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.71s	remaining: 4.47s
83:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.73s	remaining: 4.44s
84:	learn: 0.8438819	test: 0.8202247	best: 0.8370787 (52)	total: 1.75s	remaining: 4.42s
85:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.77s	remaining: 4.39s
86:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.79s	remaining: 4.37s
87:	learn: 0.8438819	test: 0.8202247	best: 0.8370787 (52)	total: 1.8s	remaining: 4.35s
88:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 1.83s	remaining: 4.33s
89:	learn: 0.8438819	test: 0.8146067	best: 0.8370787 (52)	total: 1.85s	remaining: 4.31s
90:	learn: 0.8438819	test: 0.8146067	best: 0.8370787 (52)	total: 1.87s	remaining: 4.29s
91:	learn: 0.8438819	test: 0.8146067	best: 0.8370787 (52)	total: 1.89s	remaining: 4.26s
92:	learn: 0.8438819	test: 0.8146067	best: 0.8370787 (52)	total: 1.91s	remaining: 4.24s
93:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 1.92s	remaining: 4.22s
94:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 1.93s	remaining: 4.16s
95:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 1.95s	remaining: 4.14s
96:	learn: 0.8438819	test: 0.8202247	best: 0.8370787 (52)	total: 1.97s	remaining: 4.12s
97:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 1.99s	remaining: 4.1s
98:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.01s	remaining: 4.08s
99:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.02s	remaining: 4.05s
100:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.04s	remaining: 4.03s
101:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.06s	remaining: 4.01s
102:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.09s	remaining: 3.99s
103:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.11s	remaining: 3.97s
104:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.13s	remaining: 3.95s
105:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 2.15s	remaining: 3.93s
106:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 2.2s	remaining: 3.97s
107:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 2.24s	remaining: 3.99s
108:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.28s	remaining: 3.99s
109:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.3s	remaining: 3.98s
110:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.32s	remaining: 3.96s
111:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.35s	remaining: 3.94s
112:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.37s	remaining: 3.92s
113:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.39s	remaining: 3.9s
114:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.41s	remaining: 3.88s
115:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.43s	remaining: 3.86s
116:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.45s	remaining: 3.84s
117:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 2.48s	remaining: 3.82s
118:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.5s	remaining: 3.8s
119:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.52s	remaining: 3.78s
120:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.54s	remaining: 3.76s
121:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.56s	remaining: 3.74s
122:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.59s	remaining: 3.72s
123:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.61s	remaining: 3.71s
124:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.63s	remaining: 3.69s
125:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.65s	remaining: 3.67s
126:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.68s	remaining: 3.65s
127:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.7s	remaining: 3.63s
128:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.72s	remaining: 3.61s
129:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.73s	remaining: 3.57s
130:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.75s	remaining: 3.55s
131:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.77s	remaining: 3.53s
132:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.79s	remaining: 3.51s
133:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.8s	remaining: 3.47s
134:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.83s	remaining: 3.46s
135:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.85s	remaining: 3.43s
136:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.87s	remaining: 3.42s
137:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.89s	remaining: 3.4s
138:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 2.91s	remaining: 3.37s
139:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 2.94s	remaining: 3.35s
140:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 2.96s	remaining: 3.33s
141:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 2.98s	remaining: 3.31s
142:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 3s	remaining: 3.29s
143:	learn: 0.8495077	test: 0.8146067	best: 0.8370787 (52)	total: 3.02s	remaining: 3.27s
144:	learn: 0.8495077	test: 0.8146067	best: 0.8370787 (52)	total: 3.04s	remaining: 3.25s
145:	learn: 0.8495077	test: 0.8146067	best: 0.8370787 (52)	total: 3.06s	remaining: 3.23s
146:	learn: 0.8509142	test: 0.8146067	best: 0.8370787 (52)	total: 3.08s	remaining: 3.21s
147:	learn: 0.8509142	test: 0.8146067	best: 0.8370787 (52)	total: 3.1s	remaining: 3.19s
148:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.12s	remaining: 3.16s
149:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.14s	remaining: 3.14s
150:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.16s	remaining: 3.12s
151:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.18s	remaining: 3.1s
152:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.2s	remaining: 3.08s
Stopped by overfitting detector  (100 iterations wait)

bestTest = 0.8370786517
bestIteration = 52

Shrink model to first 53 iterations.
RandomizedSearchCV(cv=10,
                   estimator=<catboost.core.CatBoostClassifier object at 0x000002E16600D400>,
                   n_iter=45, n_jobs=-1,
                   param_distributions={'border_count': [32, 64, 128],
                                        'depth': [4, 6, 8, 10],
                                        'iterations': [300, 500, 1000],
                                        'l2_leaf_reg': [1, 3, 5, 7, 9],
                                        'learning_rate': [0.01, 0.05, 0.1]},
                   scoring='accuracy', verbose=2)

In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

RandomizedSearchCV

?Documentation for RandomizedSearchCViFitted

Parameters
    <tr class="user-set">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">estimator&nbsp;</td>
        <td class="value">&lt;catboost.cor...002E16600D400&gt;</td>
    </tr>


    <tr class="user-set">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">param_distributions&nbsp;</td>
        <td class="value">{'border_count': [32, 64, ...], 'depth': [4, 6, ...], 'iterations': [300, 500, ...], 'l2_leaf_reg': [1, 3, ...], ...}</td>
    </tr>


    <tr class="user-set">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">n_iter&nbsp;</td>
        <td class="value">45</td>
    </tr>


    <tr class="user-set">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">scoring&nbsp;</td>
        <td class="value">'accuracy'</td>
    </tr>


    <tr class="user-set">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">n_jobs&nbsp;</td>
        <td class="value">-1</td>
    </tr>


    <tr class="default">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">refit&nbsp;</td>
        <td class="value">True</td>
    </tr>


    <tr class="user-set">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">cv&nbsp;</td>
        <td class="value">10</td>
    </tr>


    <tr class="user-set">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">verbose&nbsp;</td>
        <td class="value">2</td>
    </tr>


    <tr class="default">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">pre_dispatch&nbsp;</td>
        <td class="value">'2*n_jobs'</td>
    </tr>


    <tr class="default">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">random_state&nbsp;</td>
        <td class="value">None</td>
    </tr>


    <tr class="default">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">error_score&nbsp;</td>
        <td class="value">nan</td>
    </tr>


    <tr class="default">
        <td><i class="copy-paste-icon"></i></td>
        <td class="param">return_train_score&nbsp;</td>
        <td class="value">False</td>
    </tr>

              </tbody>
            </table>
        </details>
    </div>
</div></div></div><div class="sk-parallel"><div class="sk-parallel-item"><div class="sk-item"><div class="sk-label-container"><div class="sk-label fitted sk-toggleable"><input type="checkbox" id="sk-estimator-id-2" class="sk-toggleable__control sk-hidden--visually"><label class="sk-toggleable__label fitted sk-toggleable__label-arrow" for="sk-estimator-id-2"><div><div>best_estimator_: CatBoostClassifier</div></div></label><div data-param-prefix="best_estimator___" class="sk-toggleable__content fitted"><pre>&lt;catboost.core.CatBoostClassifier object at 0x000002E169667020&gt;</pre></div></div></div><div class="sk-serial"><div class="sk-item"><div class="sk-estimator fitted sk-toggleable"><input type="checkbox" id="sk-estimator-id-3" class="sk-toggleable__control sk-hidden--visually"><label class="sk-toggleable__label fitted sk-toggleable__label-arrow" for="sk-estimator-id-3"><div><div>CatBoostClassifier</div></div></label><div data-param-prefix="best_estimator___" class="sk-toggleable__content fitted"><pre>&lt;catboost.core.CatBoostClassifier object at 0x000002E169667020&gt;</pre></div></div></div></div></div></div></div></div></div></div>
# Display the best hyperparameter combination found by the search.

random_search.best_params_

{'learning_rate': 0.05,
 'l2_leaf_reg': 3,
 'iterations': 300,
 'depth': 4,
 'border_count': 32}
# Display the best mean cross-validated accuracy from the search.

random_search.best_score_

np.float64(0.82981220657277)
# Persist best_params so they survive kernel restarts.
# Append mode is deliberate (keeps results of earlier runs); the trailing
# newline is the FIX — without it consecutive json.dump calls fuse the
# entries together as "}{" and the file stops being parseable per entry.
with open("best_params.txt", "a", encoding="utf-8") as f:
    json.dump(random_search.best_params_, f, indent=4)
    f.write("\n")
    

Отступление

Дважды при малом early_stopping_rounds, равном 30, n_iter, равном 15, и cv, равном 3,

модель показывала лучший accuracy, но при валидации на Kaggle давала результаты хуже.

Увеличил early_stopping_rounds, n_iter и cv — и тогда

получилось улучшить итоговый результат.

# Take the best estimator found (refit by the search when refit=True).

best_model = random_search.best_estimator_

# Re-train the best model with the fixed fit parameters.
# NOTE(review): with refit=True, RandomizedSearchCV has already refit
# best_estimator_ on this data, so this second fit is likely redundant —
# confirm it is needed.

best_model.fit(X_train_split, y_train_split, **fit_params)

0:	learn: 0.7988748	test: 0.7752809	best: 0.7752809 (0)	total: 21.9ms	remaining: 6.54s
1:	learn: 0.8016878	test: 0.7808989	best: 0.7808989 (1)	total: 48.9ms	remaining: 7.29s
2:	learn: 0.8101266	test: 0.7921348	best: 0.7921348 (2)	total: 74.5ms	remaining: 7.37s
3:	learn: 0.8045007	test: 0.7865169	best: 0.7921348 (2)	total: 105ms	remaining: 7.76s
4:	learn: 0.8030942	test: 0.7865169	best: 0.7921348 (2)	total: 133ms	remaining: 7.85s
5:	learn: 0.8087201	test: 0.7977528	best: 0.7977528 (5)	total: 159ms	remaining: 7.77s
6:	learn: 0.8087201	test: 0.7977528	best: 0.7977528 (5)	total: 184ms	remaining: 7.7s
7:	learn: 0.8101266	test: 0.8033708	best: 0.8033708 (7)	total: 209ms	remaining: 7.63s
8:	learn: 0.8101266	test: 0.7977528	best: 0.8033708 (7)	total: 234ms	remaining: 7.57s
9:	learn: 0.8101266	test: 0.7977528	best: 0.8033708 (7)	total: 258ms	remaining: 7.49s
10:	learn: 0.8101266	test: 0.7977528	best: 0.8033708 (7)	total: 286ms	remaining: 7.52s
11:	learn: 0.8115331	test: 0.7977528	best: 0.8033708 (7)	total: 314ms	remaining: 7.54s
12:	learn: 0.8171589	test: 0.7977528	best: 0.8033708 (7)	total: 341ms	remaining: 7.54s
13:	learn: 0.8185654	test: 0.7977528	best: 0.8033708 (7)	total: 369ms	remaining: 7.53s
14:	learn: 0.8185654	test: 0.8033708	best: 0.8033708 (7)	total: 392ms	remaining: 7.45s
15:	learn: 0.8185654	test: 0.8033708	best: 0.8033708 (7)	total: 416ms	remaining: 7.38s
16:	learn: 0.8185654	test: 0.8033708	best: 0.8033708 (7)	total: 438ms	remaining: 7.29s
17:	learn: 0.8185654	test: 0.8033708	best: 0.8033708 (7)	total: 456ms	remaining: 7.14s
18:	learn: 0.8199719	test: 0.8033708	best: 0.8033708 (7)	total: 479ms	remaining: 7.08s
19:	learn: 0.8227848	test: 0.8033708	best: 0.8033708 (7)	total: 501ms	remaining: 7.02s
20:	learn: 0.8227848	test: 0.8033708	best: 0.8033708 (7)	total: 524ms	remaining: 6.96s
21:	learn: 0.8227848	test: 0.8033708	best: 0.8033708 (7)	total: 546ms	remaining: 6.9s
22:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 576ms	remaining: 6.93s
23:	learn: 0.8270042	test: 0.8033708	best: 0.8033708 (7)	total: 584ms	remaining: 6.72s
24:	learn: 0.8270042	test: 0.8033708	best: 0.8033708 (7)	total: 607ms	remaining: 6.68s
25:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 624ms	remaining: 6.57s
26:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 646ms	remaining: 6.53s
27:	learn: 0.8241913	test: 0.7977528	best: 0.8033708 (7)	total: 670ms	remaining: 6.51s
28:	learn: 0.8255977	test: 0.7977528	best: 0.8033708 (7)	total: 692ms	remaining: 6.47s
29:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 716ms	remaining: 6.44s
30:	learn: 0.8255977	test: 0.8033708	best: 0.8033708 (7)	total: 739ms	remaining: 6.41s
31:	learn: 0.8255977	test: 0.8089888	best: 0.8089888 (31)	total: 760ms	remaining: 6.37s
32:	learn: 0.8284107	test: 0.8089888	best: 0.8089888 (31)	total: 784ms	remaining: 6.34s
33:	learn: 0.8298172	test: 0.8146067	best: 0.8146067 (33)	total: 807ms	remaining: 6.31s
34:	learn: 0.8340366	test: 0.8146067	best: 0.8146067 (33)	total: 820ms	remaining: 6.21s
35:	learn: 0.8354430	test: 0.8146067	best: 0.8146067 (33)	total: 845ms	remaining: 6.2s
36:	learn: 0.8354430	test: 0.8146067	best: 0.8146067 (33)	total: 868ms	remaining: 6.17s
37:	learn: 0.8368495	test: 0.8089888	best: 0.8146067 (33)	total: 886ms	remaining: 6.11s
38:	learn: 0.8382560	test: 0.8089888	best: 0.8146067 (33)	total: 908ms	remaining: 6.08s
39:	learn: 0.8368495	test: 0.8089888	best: 0.8146067 (33)	total: 936ms	remaining: 6.08s
40:	learn: 0.8368495	test: 0.8089888	best: 0.8146067 (33)	total: 959ms	remaining: 6.05s
41:	learn: 0.8382560	test: 0.8089888	best: 0.8146067 (33)	total: 979ms	remaining: 6.01s
42:	learn: 0.8382560	test: 0.8202247	best: 0.8202247 (42)	total: 1s	remaining: 5.98s
43:	learn: 0.8410689	test: 0.8146067	best: 0.8202247 (42)	total: 1.02s	remaining: 5.94s
44:	learn: 0.8396624	test: 0.8146067	best: 0.8202247 (42)	total: 1.04s	remaining: 5.92s
45:	learn: 0.8438819	test: 0.8258427	best: 0.8258427 (45)	total: 1.07s	remaining: 5.89s
46:	learn: 0.8466948	test: 0.8258427	best: 0.8258427 (45)	total: 1.09s	remaining: 5.85s
47:	learn: 0.8466948	test: 0.8258427	best: 0.8258427 (45)	total: 1.11s	remaining: 5.82s
48:	learn: 0.8481013	test: 0.8258427	best: 0.8258427 (45)	total: 1.13s	remaining: 5.78s
49:	learn: 0.8452883	test: 0.8314607	best: 0.8314607 (49)	total: 1.15s	remaining: 5.75s
50:	learn: 0.8438819	test: 0.8314607	best: 0.8314607 (49)	total: 1.17s	remaining: 5.71s
51:	learn: 0.8438819	test: 0.8314607	best: 0.8314607 (49)	total: 1.19s	remaining: 5.67s
52:	learn: 0.8452883	test: 0.8370787	best: 0.8370787 (52)	total: 1.21s	remaining: 5.63s
53:	learn: 0.8424754	test: 0.8370787	best: 0.8370787 (52)	total: 1.23s	remaining: 5.59s
54:	learn: 0.8396624	test: 0.8370787	best: 0.8370787 (52)	total: 1.25s	remaining: 5.56s
55:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.27s	remaining: 5.53s
56:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.29s	remaining: 5.5s
57:	learn: 0.8382560	test: 0.8314607	best: 0.8370787 (52)	total: 1.31s	remaining: 5.46s
58:	learn: 0.8382560	test: 0.8314607	best: 0.8370787 (52)	total: 1.32s	remaining: 5.41s
59:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.34s	remaining: 5.38s
60:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.36s	remaining: 5.35s
61:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.38s	remaining: 5.31s
62:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.4s	remaining: 5.27s
63:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.42s	remaining: 5.24s
64:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.44s	remaining: 5.21s
65:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.46s	remaining: 5.18s
66:	learn: 0.8410689	test: 0.8314607	best: 0.8370787 (52)	total: 1.48s	remaining: 5.15s
67:	learn: 0.8410689	test: 0.8314607	best: 0.8370787 (52)	total: 1.5s	remaining: 5.13s
68:	learn: 0.8410689	test: 0.8314607	best: 0.8370787 (52)	total: 1.51s	remaining: 5.07s
69:	learn: 0.8424754	test: 0.8314607	best: 0.8370787 (52)	total: 1.53s	remaining: 5.05s
70:	learn: 0.8424754	test: 0.8314607	best: 0.8370787 (52)	total: 1.56s	remaining: 5.03s
71:	learn: 0.8424754	test: 0.8314607	best: 0.8370787 (52)	total: 1.58s	remaining: 5s
72:	learn: 0.8438819	test: 0.8314607	best: 0.8370787 (52)	total: 1.6s	remaining: 4.98s
73:	learn: 0.8396624	test: 0.8314607	best: 0.8370787 (52)	total: 1.61s	remaining: 4.93s
74:	learn: 0.8382560	test: 0.8258427	best: 0.8370787 (52)	total: 1.64s	remaining: 4.91s
75:	learn: 0.8410689	test: 0.8258427	best: 0.8370787 (52)	total: 1.66s	remaining: 4.88s
76:	learn: 0.8424754	test: 0.8202247	best: 0.8370787 (52)	total: 1.68s	remaining: 4.86s
77:	learn: 0.8424754	test: 0.8202247	best: 0.8370787 (52)	total: 1.7s	remaining: 4.83s
78:	learn: 0.8438819	test: 0.8202247	best: 0.8370787 (52)	total: 1.72s	remaining: 4.81s
79:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.74s	remaining: 4.78s
80:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.76s	remaining: 4.76s
81:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.78s	remaining: 4.73s
82:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.8s	remaining: 4.71s
83:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.82s	remaining: 4.68s
84:	learn: 0.8438819	test: 0.8202247	best: 0.8370787 (52)	total: 1.84s	remaining: 4.67s
85:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.87s	remaining: 4.65s
86:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 1.89s	remaining: 4.64s
87:	learn: 0.8438819	test: 0.8202247	best: 0.8370787 (52)	total: 1.92s	remaining: 4.62s
88:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 1.94s	remaining: 4.59s
89:	learn: 0.8438819	test: 0.8146067	best: 0.8370787 (52)	total: 1.96s	remaining: 4.58s
90:	learn: 0.8438819	test: 0.8146067	best: 0.8370787 (52)	total: 1.98s	remaining: 4.55s
91:	learn: 0.8438819	test: 0.8146067	best: 0.8370787 (52)	total: 2s	remaining: 4.53s
92:	learn: 0.8438819	test: 0.8146067	best: 0.8370787 (52)	total: 2.02s	remaining: 4.51s
93:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.05s	remaining: 4.49s
94:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.05s	remaining: 4.43s
95:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.08s	remaining: 4.41s
96:	learn: 0.8438819	test: 0.8202247	best: 0.8370787 (52)	total: 2.1s	remaining: 4.39s
97:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.12s	remaining: 4.36s
98:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.14s	remaining: 4.34s
99:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.16s	remaining: 4.32s
100:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.19s	remaining: 4.31s
101:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.21s	remaining: 4.29s
102:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.23s	remaining: 4.27s
103:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.25s	remaining: 4.25s
104:	learn: 0.8438819	test: 0.8258427	best: 0.8370787 (52)	total: 2.28s	remaining: 4.23s
105:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 2.3s	remaining: 4.21s
106:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 2.32s	remaining: 4.19s
107:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 2.35s	remaining: 4.17s
108:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.37s	remaining: 4.15s
109:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.39s	remaining: 4.13s
110:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.41s	remaining: 4.11s
111:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.44s	remaining: 4.09s
112:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.46s	remaining: 4.07s
113:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.48s	remaining: 4.05s
114:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.5s	remaining: 4.02s
115:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.52s	remaining: 4s
116:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.54s	remaining: 3.97s
117:	learn: 0.8452883	test: 0.8202247	best: 0.8370787 (52)	total: 2.56s	remaining: 3.95s
118:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.58s	remaining: 3.92s
119:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.6s	remaining: 3.9s
120:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.62s	remaining: 3.87s
121:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.64s	remaining: 3.85s
122:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.66s	remaining: 3.83s
123:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.68s	remaining: 3.8s
124:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.7s	remaining: 3.78s
125:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.72s	remaining: 3.75s
126:	learn: 0.8452883	test: 0.8146067	best: 0.8370787 (52)	total: 2.74s	remaining: 3.73s
127:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.76s	remaining: 3.71s
128:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.78s	remaining: 3.68s
129:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.79s	remaining: 3.65s
130:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.81s	remaining: 3.62s
131:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.82s	remaining: 3.59s
132:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.84s	remaining: 3.57s
133:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.85s	remaining: 3.53s
134:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.87s	remaining: 3.51s
135:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.89s	remaining: 3.49s
136:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.91s	remaining: 3.47s
137:	learn: 0.8466948	test: 0.8146067	best: 0.8370787 (52)	total: 2.94s	remaining: 3.44s
138:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 2.96s	remaining: 3.42s
139:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 2.97s	remaining: 3.4s
140:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 3s	remaining: 3.38s
141:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 3.02s	remaining: 3.35s
142:	learn: 0.8481013	test: 0.8146067	best: 0.8370787 (52)	total: 3.03s	remaining: 3.33s
143:	learn: 0.8495077	test: 0.8146067	best: 0.8370787 (52)	total: 3.05s	remaining: 3.31s
144:	learn: 0.8495077	test: 0.8146067	best: 0.8370787 (52)	total: 3.07s	remaining: 3.29s
145:	learn: 0.8495077	test: 0.8146067	best: 0.8370787 (52)	total: 3.1s	remaining: 3.27s
146:	learn: 0.8509142	test: 0.8146067	best: 0.8370787 (52)	total: 3.12s	remaining: 3.24s
147:	learn: 0.8509142	test: 0.8146067	best: 0.8370787 (52)	total: 3.14s	remaining: 3.22s
148:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.16s	remaining: 3.2s
149:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.18s	remaining: 3.18s
150:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.2s	remaining: 3.15s
151:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.22s	remaining: 3.13s
152:	learn: 0.8523207	test: 0.8146067	best: 0.8370787 (52)	total: 3.24s	remaining: 3.11s
Stopped by overfitting detector  (100 iterations wait)

bestTest = 0.8370786517
bestIteration = 52

Shrink model to first 53 iterations.
# Evaluate the tuned model on the validation split.
# BUG FIX: `y_pred` was left over from the earlier, untuned model, so the
# printed accuracy (0.8315) did not match the tuned model's score on the
# same split (0.8371, recomputed below). Predict with best_model first.
y_pred = best_model.predict(X_valid)
acc = accuracy_score(y_valid, y_pred)
print(f"Validation Accuracy: {acc:.4f}")

Validation Accuracy: 0.8315
# Predict survival on the Kaggle test set with the tuned model.
best_test_preds = best_model.predict(X_test)

# Assemble the submission table column by column:
# PassengerId plus the predicted Survived flag cast to int.
submission_V2 = pd.DataFrame()
submission_V2['PassengerId'] = passenger_ids
submission_V2['Survived'] = best_test_preds.astype(int)

# Write the file Kaggle expects (no index column).
submission_V2.to_csv('submission_V2.csv', index=False)
print("✅ Submission файл сохранён как submission_V2.csv")

✅ Submission файл сохранён как submission_V2.csv
# Re-check the tuned model's accuracy on the validation split.

accuracy_score(y_valid, best_model.predict(X_valid))

0.8370786516853933
image.png
image.png
image.png
image.png

Смогли улучшить качество модели с помощью подбора гиперпараметров и отвоевать больше 500 мест в итоговом рейтинге

Tags:
Hubs:
-3
Comments0

Articles