Дерево решений CART (Classification and Regressoin Tree) — алгоритм классификации и регрессии, основанный на бинарном дереве и являющийся фундаментальным компонентом случайного леса и бустингов, которые входят в число самых мощных алгоритмов машинного обучения на сегодняшний день. Деревья также могут быть не бинарными в зависимости от реализации. К другим популярным реализациям решающего дерева относятся следующие: ID3, C4.5, C5.0.
Ноутбук с данными алгоритмами можно загрузить на Kaggle (eng) и GitHub (rus).
Структура дерева решений
Решающее дерево состоит из следующих компонентов: корневой узел, ветви (левая и правая), решающие и листовые (терминальные) узлы. Корневой и решающие узлы представляют из себя вопросы с пороговым значением для разделения тренировочного набора на части (левая и правая), а листья являются конечными прогнозами: среднее значений в листе для регрессии и статистическая мода для классификации.
Каждый листовой узел соответствует определённой прямоугольной области на графике границ решений между двумя признаками. Если на графике соседние участки имеют одинаковое значение, то они автоматически объединяются и представляются как одна большая область.
Выбор наилучшего разбиения
Выбор наилучшего разбиения при построении решающего узла в дереве напоминает игру, в которой нужно угадать знаменитость, задавая вопросы, на которые можно лишь услышать ответ "да" либо "нет". Логично, что для быстрого поиска правильного ответа необходимо задавать вопросы, которые исключат наибольшее количество неверных вариантов, например, вопрос про "пол" позволит исключить сразу же половину вариантов, в то время как вопрос про "возраст" будет менее информативным. Проще говоря, выбор наилучшего вопроса заключается в поиске признака, определённое значение которого лучше всего отделяет правильный ответ от неправильных.
Показатель того, насколько хорошо вопрос в решающем узле позволяет отделить верный ответ от неверных, называется мерой загрязнённости узла. В случае классификации для оценки качества разбиения узла используются следующие критерии:
Неопределённость (загрязнённость) Джини — мера разнообразия в распределении вероятностей классов. Если все элементы в узле принадлежат к одному классу, то неопределённость Джини равна 0, а в случае равномерного распределения классов в узле неопределённость Джини равна 0.5.
Энтропия Шеннона — мера неопределённости или беспорядка классов в узле. Она характеризует количество информации, которое необходимо для описания состояния системы: чем выше значение энтропии, тем менее упорядочена система и наоборот.
Ошибка классификации — величина, отображающая долю неправильно классифицированных элементов в узле: чем меньше данное значение, тем меньше загрязнённость в узле.
В данном случае — это доля k-го класса среди обучающих образцов в i-ом узле.
На практике чаще всего используются неопределённость Джини и энтропия Шеннона за счёт большей информативности. Как видно из графика для случая бинарной классификации (где P+ — вероятность принадлежности к классу "+"), график удвоенной неопределённости Джини очень схож с графиком энтропии Шеннона: в первом случае будут получаться чуть менее сбалансированные деревья, однако при работе с большими датасетами неопределённость Джини более предпочтительна за счёт меньшей вычислительной сложности.
Код для отрисовки графика
import numpy as np
import matplotlib.pyplot as plt
def gini(probas):
return np.array([1- (p ** 2 + (1-p) ** 2) for p in probas])
def entropy(probas):
return np.array([-1 * (p * np.log2(p) + (1-p) * np.log2(1-p)) for p in probas])
def misclass_error_rate(probas):
return np.array([1 - max([p, 1-p]) for p in probas])
probas = np.linspace(0, 1, 250)
plt.plot(probas, entropy(probas), label="Shannon's entropy")
plt.plot(probas, 2 * gini(probas), label="Gini impurity x 2")
plt.plot(probas, 2 * misclass_error_rate(probas), label="Misclass error x 2")
plt.plot(probas, gini(probas), label="Gini impurity")
plt.plot(probas, misclass_error_rate(probas), label="Misclass error")
plt.title("Splitting criteria from P+ (binary classification case)")
plt.xlabel("P+")
plt.ylabel("Impurity")
plt.legend();
В случае регрессии для оценки качества разбиения узла чаще всего используется среднеквадратичная ошибка, но также могут быть использованы Friedman MSE и MAE.
Функция потерь
Так как же в конечном счёте происходит выбор наилучшего разбиения? После выбора одного из критериев оценки качества разбиения узла (например, неопределённость Джини или MSE), для всех уникальных значений признака берутся их пороговые значения, отсортированные по возрастанию и представленные как среднее арифметическое между соседними значениями. Далее обучающий набор разделяется на 2 поднабора (узла): всё что меньше либо равно текущего порогового значения идёт в левый поднабор, а всё что больше — в правый. Для полученных поднаборов рассчитываются загрязнённости на основе выбранного критерия, после чего их взвешенная сумма представляется как функция потерь, значение которой будет соответствовать пороговому значению признака. Порог с наименьшим значением функции потерь в обучающем наборе (поднаборе) будет наилучшим разбиением.
Функции потерь будут иметь следующий вид:
для классификации:
для регрессии:
где .
В случае с энтропией используется немного иной подход: рассчитывается так называемый информационный прирост — разница энтропий родительского и дочерних узлов. Порог с максимальным информационным приростом в обучающем наборе/поднаборе будет соответствовать наилучшему разбиению. :
где — условие (вопрос) для разбиения поднабора .
Для наглядности рассмотрим следующий пример. Допустим, у нас есть 4 кота и 4 собаки, а также нам известны их некоторые визуальные признаки: "рост", "уши" (висячие и заострённые стоячие) и "усы" (наличие либо отсутствие). Как в дереве решений строится корневой узел для классификации собак и котов? Очень просто: сначала для каждого признака животные разделяются на 2 группы согласно вопросу, после чего для каждой из групп рассчитывается неопределённость Джини. Пороговое значение признака с наименьшим значением функции потерь (взвешенной суммой неопределённостей) будет использоваться для разбиения корневого (решающего) узла. В данном случае признак "рост" имеет наименьшую загрязнённость и вопрос в решающем узле будет выглядеть как "рост ⩽ 20 см".
Стоит добавить, что в случае с категориальными признаками происходит их бинаризация и вопросы выглядят следующим образом: "уши ⩽ 0.5", "усы ⩽ 0.5". Когда категориальные признаки могут принимать более 2 значений, применяются другие виды кодирования, например label или one-hot encoding.
Принцип работы дерева решений (CART)
Алгоритм строится следующим образом:
1) создаётся корневой узел на основе наилучшего разбиения;
2) тренировочный набор разбивается на 2 поднабора: всё что соответствует условию разбиения отправляется в левый узел, остальное — в правый;
3) далее рекурсивно для каждого тренировочного поднабора повторяются шаги 1-2 пока не будет достигнут один из основных критериев останова: максимальная глубина, максимальное количество листьев, минимальное количество наблюдений в листе или минимальное снижение загрязнения в узле.
Регуляризация дерева решений
Выращенная без ограничений древовидная структура в той или иной степени будет склонна к переобучению и для решения данной проблемы используется 2 подхода: pre-pruning (ограничение роста дерева во время построения любым из критериев останова) и post-pruning (отсечение лишних ветвей после полного построения). Второй подход является более деликатным так как позволяет получить несимметричную и более точную древовидную структуру, оставляя лишь самые информативные решающие узлы.
Существует 2 типа post-puning'а:
Top-down pruning — метод, при котором проверка и обрезка наименее информативных ветвей начинается с корневого узла. Данный метод обладает относительно низкой вычислительной сложностью, однако, как и в случае с pre-pruning'ом, его главным недостатком также является возможность недообучения за счёт удаления ветвей, которые могли потенциально содержать информативные узлы. К самым известным видам данного прунинга относятся следующие:
Pessimistic Error Pruning (PEP), когда обрезаются ветви с наибольшей ожидаемой ошибкой, порог которой устанавливается заранее;
Critical Value Pruning (CVP), когда обрезаются ветви, информативность которых меньше определённого критического значения.
Bottom-up pruning — метод, при котором проверка и обрезка наименее информативных ветвей начинается с листьев. В данном случае получаются более точные деревья за счёт полного обхода снизу-вверх и оценки каждого решающего узла, однако это приводит к увеличению вычислительной сложности. Самыми популярными видами данного прунинга являются следующие:
Minimum Error Pruning (MEP), когда происходит поиск дерева с наименьшей ожидаемой ошибкой на отложенной выборке;
Reduced Error Pruning (REP), когда решающие узлы удаляются до тех пор, пока не падает точность, измеренная на отложенной выборке;
Cost-complexity pruning (CCP), когда строится серия поддеревьев через удаление слабейших узлов в каждом из них с помощью коэффициента, рассчитанного как разность ошибки корневого узла поддерева и общей ошибки его листьев, а выбор наилучшего поддерева производится на тестовом наборе или с помощью k-fold кросс-валидации.
Minimal cost-complexity pruning
В реализации scikit-learn для деревьев решений используется модификация cost-complexity pruning, которая работает следующим образом:
1) сначала строится полное дерево без ограничений;
2) далее абсолютно для всех узлов в дереве рассчитывается ошибка на основе взвешенной загрязнённости в случае классификации или взвешенной MSE в случае регрессии;
3) для каждого поддерева в дереве подсчитывается совокупная ошибка его листьев;
4) для каждого поддерева в дереве рассчитывается коэффициент альфа, представленный как разность ошибки корневого узла поддерева и совокупная ошибка его листьев;
5) поддерево с наименьшим удаляется и становится листовым узлом, а сам коэффициент хранится в массиве cost_complety_pruning_path и соответствует новому обрезанному дереву;
6) шаги 2-5 рекурсивно повторяются для каждого поддерева до тех пор, пока обрезка не дойдёт до корневого узла.
Если задавать определённое значение изначально, то данный коэффициент применится к каждому поддереву и в итоге останется поддерево с наименьшей ошибкой среди всех поддеревьев, а выбор наилучшего из cost_complety_pruning_path для получения самого точного поддерева производится на тестовом наборе или с помощью k-fold кросс-валидации.
Формулы для расчётов
Регуляризация дерева:
Эффективный :
Ошибка решающего или листового узлов для классификации:
Ошибка решающего или листового узлов для регрессии:
Совокупная ошибка листьев в дереве/поддереве:
T — число терминальных (листовых) узлов.
Импорт необходимых библиотек
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_linnerud
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, mean_absolute_percentage_error
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
from mlxtend.plotting import plot_decision_regions
from copy import deepcopy
from pprint import pprint
Реализация на Python с нуля
В оригинальном дереве для создания узлов и хранения в них информации используется отдельный класс, но в данном случае дерево и вся информация об узлах хранятся в словаре с ключами в строчном формате. Такое изменение используется для возможности вывода полученного дерева и сравнения с реализацией scikit-learn.
class DecisionTreeCART:
def __init__(self, max_depth=100, min_samples=2, ccp_alpha=0.0, regression=False):
self.max_depth = max_depth
self.min_samples = min_samples
self.ccp_alpha = ccp_alpha
self.regression = regression
self.tree = None
self._y_type = None
self._num_all_samples = None
def _set_df_type(self, X, y, dtype):
X = X.astype(dtype)
y = y.astype(dtype) if self.regression else y
self._y_dtype = y.dtype
return X, y
@staticmethod
def _purity(y):
unique_classes = np.unique(y)
return unique_classes.size == 1
@staticmethod
def _is_leaf_node(node):
return not isinstance(node, dict) # if a node/tree is a leaf
def _leaf_node(self, y):
class_index = 0
return np.mean(y) if self.regression else y.mode()[class_index]
def _split_df(self, X, y, feature, threshold):
feature_values = X[feature]
left_indexes = X[feature_values <= threshold].index
right_indexes = X[feature_values > threshold].index
sizes = np.array([left_indexes.size, right_indexes.size])
return self._leaf_node(y) if any(sizes == 0) else left_indexes, right_indexes
@staticmethod
def _gini_impurity(y):
_, counts_classes = np.unique(y, return_counts=True)
squared_probabilities = np.square(counts_classes / y.size)
gini_impurity = 1 - sum(squared_probabilities)
return gini_impurity
@staticmethod
def _mse(y):
mse = np.mean((y - y.mean()) ** 2)
return mse
@staticmethod
def _cost_function(left_df, right_df, method):
total_df_size = left_df.size + right_df.size
p_left_df = left_df.size / total_df_size
p_right_df = right_df.size / total_df_size
J_left = method(left_df)
J_right = method(right_df)
J = p_left_df*J_left + p_right_df*J_right
return J # weighted Gini impurity or weighted mse (depends on a method)
def _node_error_rate(self, y, method):
if self._num_all_samples is None:
self._num_all_samples = y.size # num samples of all dataframe
current_num_samples = y.size
return current_num_samples / self._num_all_samples * method(y)
def _best_split(self, X, y):
features = X.columns
min_cost_function = np.inf
best_feature, best_threshold = None, None
method = self._mse if self.regression else self._gini_impurity
for feature in features:
unique_feature_values = np.unique(X[feature])
for i in range(1, len(unique_feature_values)):
current_value = unique_feature_values[i]
previous_value = unique_feature_values[i-1]
threshold = (current_value + previous_value) / 2
left_indexes, right_indexes = self._split_df(X, y, feature, threshold)
left_labels, right_labels = y.loc[left_indexes], y.loc[right_indexes]
current_J = self._cost_function(left_labels, right_labels, method)
if current_J <= min_cost_function:
min_cost_function = current_J
best_feature = feature
best_threshold = threshold
return best_feature, best_threshold
def _stopping_conditions(self, y, depth, n_samples):
return self._purity(y), depth == self.max_depth, n_samples < self.min_samples
def _grow_tree(self, X, y, depth=0):
current_num_samples = y.size
X, y = self._set_df_type(X, y, np.float128)
method = self._mse if self.regression else self._gini_impurity
if any(self._stopping_conditions(y, depth, current_num_samples)):
RTi = self._node_error_rate(y, method) # leaf node error rate
leaf_node = f'{self._leaf_node(y)} | error_rate {RTi}'
return leaf_node
Rt = self._node_error_rate(y, method) # decision node error rate
best_feature, best_threshold = self._best_split(X, y)
decision_node = f'{best_feature} <= {best_threshold} | ' \
f'as_leaf {self._leaf_node(y)} error_rate {Rt}'
left_indexes, right_indexes = self._split_df(X, y, best_feature, best_threshold)
left_X, right_X = X.loc[left_indexes], X.loc[right_indexes]
left_labels, right_labels = y.loc[left_indexes], y.loc[right_indexes]
# recursive part
tree = {decision_node: []}
left_subtree = self._grow_tree(left_X, left_labels, depth+1)
right_subtree = self._grow_tree(right_X, right_labels, depth+1)
if left_subtree == right_subtree:
tree = left_subtree
else:
tree[decision_node].extend([left_subtree, right_subtree])
return tree
def _tree_error_rate_info(self, tree, error_rates_list):
if self._is_leaf_node(tree):
*_, leaf_error_rate = tree.split()
error_rates_list.append(np.float128(leaf_error_rate))
else:
decision_node = next(iter(tree))
left_subtree, right_subtree = tree[decision_node]
self._tree_error_rate_info(left_subtree, error_rates_list)
self._tree_error_rate_info(right_subtree, error_rates_list)
RT = sum(error_rates_list) # total leaf error rate of a tree
num_leaf_nodes = len(error_rates_list)
return RT, num_leaf_nodes
@staticmethod
def _ccp_alpha_eff(decision_node_Rt, leaf_nodes_RTt, num_leafs):
return (decision_node_Rt - leaf_nodes_RTt) / (num_leafs - 1)
def _find_weakest_node(self, tree, weakest_node_info):
if self._is_leaf_node(tree):
return tree
decision_node = next(iter(tree))
left_subtree, right_subtree = tree[decision_node]
*_, decision_node_error_rate = decision_node.split()
Rt = np.float128(decision_node_error_rate)
RTt, num_leaf_nodes = self._tree_error_rate_info(tree, [])
ccp_alpha = self._ccp_alpha_eff(Rt, RTt, num_leaf_nodes)
decision_node_index, min_ccp_alpha_index = 0, 1
if ccp_alpha <= weakest_node_info[min_ccp_alpha_index]:
weakest_node_info[decision_node_index] = decision_node
weakest_node_info[min_ccp_alpha_index] = ccp_alpha
self._find_weakest_node(left_subtree, weakest_node_info)
self._find_weakest_node(right_subtree, weakest_node_info)
return weakest_node_info
def _prune_tree(self, tree, weakest_node):
if self._is_leaf_node(tree):
return tree
decision_node = next(iter(tree))
left_subtree, right_subtree = tree[decision_node]
left_subtree_index, right_subtree_index = 0, 1
_, leaf_node = weakest_node.split('as_leaf ')
if weakest_node is decision_node:
tree = weakest_node
if weakest_node in left_subtree:
tree[decision_node][left_subtree_index] = leaf_node
if weakest_node in right_subtree:
tree[decision_node][right_subtree_index] = leaf_node
self._prune_tree(left_subtree, weakest_node)
self._prune_tree(right_subtree, weakest_node)
return tree
def cost_complexity_pruning_path(self, X: pd.DataFrame, y: pd.Series):
tree = self._grow_tree(X, y) # grow a full tree
tree_error_rate, _ = self._tree_error_rate_info(tree, [])
error_rates = [tree_error_rate]
ccp_alpha_list = [0.0]
while not self._is_leaf_node(tree):
initial_node = [None, np.inf]
weakest_node, ccp_alpha = self._find_weakest_node(tree, initial_node)
tree = self._prune_tree(tree, weakest_node)
tree_error_rate, _ = self._tree_error_rate_info(tree, [])
error_rates.append(tree_error_rate)
ccp_alpha_list.append(ccp_alpha)
return np.array(ccp_alpha_list), np.array(error_rates)
def _ccp_tree_error_rate(self, tree_error_rate, num_leaf_nodes):
return tree_error_rate + self.ccp_alpha*num_leaf_nodes # regularization
def _optimal_tree(self, X, y):
tree = self._grow_tree(X, y) # grow a full tree
min_RT_alpha, final_tree = np.inf, None
while not self._is_leaf_node(tree):
RT, num_leaf_nodes = self._tree_error_rate_info(tree, [])
current_RT_alpha = self._ccp_tree_error_rate(RT, num_leaf_nodes)
if current_RT_alpha <= min_RT_alpha:
min_RT_alpha = current_RT_alpha
final_tree = deepcopy(tree)
initial_node = [None, np.inf]
weakest_node, _ = self._find_weakest_node(tree, initial_node)
tree = self._prune_tree(tree, weakest_node)
return final_tree
def fit(self, X: pd.DataFrame, y: pd.Series):
self.tree = self._optimal_tree(X, y)
def _traverse_tree(self, sample, tree):
if self._is_leaf_node(tree):
leaf, *_ = tree.split()
return leaf
decision_node = next(iter(tree)) # dict key
left_node, right_node = tree[decision_node]
feature, other = decision_node.split(' <=')
threshold, *_ = other.split()
feature_value = sample[feature]
if np.float128(feature_value) <= np.float128(threshold):
next_node = self._traverse_tree(sample, left_node) # left_node
else:
next_node = self._traverse_tree(sample, right_node) # right_node
return next_node
def predict(self, samples: pd.DataFrame):
# apply traverse_tree method for each row in a dataframe
results = samples.apply(self._traverse_tree, args=(self.tree,), axis=1)
return np.array(results.astype(self._y_dtype))
Код для отрисовки графиков
def tree_plot(sklearn_tree, Xa_train):
plt.figure(figsize=(12, 18)) # customize according to the size of your tree
plot_tree(sklearn_tree, feature_names=Xa_train.columns, filled=True, precision=6)
plt.show()
def tree_scores_plot(estimator, ccp_alphas, train_data, test_data, metric, labels):
train_scores, test_scores = [], []
X_train, y_train = train_data
X_test, y_test = test_data
x_label, y_label = labels
for ccp_alpha_i in ccp_alphas:
estimator.ccp_alpha = ccp_alpha_i
estimator.fit(X_train, y_train)
train_pred_res = estimator.predict(X_train)
test_pred_res = estimator.predict(X_test)
train_score = metric(y_train, train_pred_res)
test_score = metric(y_test, test_pred_res)
train_scores.append(train_score)
test_scores.append(test_score)
fig, ax = plt.subplots()
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_title(f"{y_label} vs {x_label} for training and testing sets")
ax.plot(ccp_alphas, train_scores, marker="o", label="train", drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker="o", label="test", drawstyle="steps-post")
ax.legend()
plt.show()
def decision_boundary_plot(X, y, X_train, y_train, clf, feature_indexes, title=None):
if y.dtype != 'int':
y = pd.Series(LabelEncoder().fit_transform(y))
y_train = pd.Series(LabelEncoder().fit_transform(y_train))
feature1_name, feature2_name = X.columns[feature_indexes]
X_feature_columns = X.values[:, feature_indexes]
X_train_feature_columns = X_train.values[:, feature_indexes]
clf.fit(X_train_feature_columns, y_train.values)
plot_decision_regions(X=X_feature_columns, y=y.values, clf=clf)
plt.xlabel(feature1_name)
plt.ylabel(feature2_name)
plt.title(title)
Загрузка датасетов
Для обучения моделей будет использован Iris dataset, где необходимо верно определить типы цветков на основе их признаков. В случае регрессии используется load_linnerud dataset из scikit-learn.
df_path = "/content/drive/MyDrive/iris.csv"
iris = pd.read_csv(df_path)
X1, y1 = iris.iloc[:, :-1], iris.iloc[:, -1]
X1_train, X1_test, y1_train, y1_test = train_test_split(X1, y1, test_size=0.3, random_state=0)
print(iris)
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
.. ... ... ... ... ...
145 6.7 3.0 5.2 2.3 virginica
146 6.3 2.5 5.0 1.9 virginica
147 6.5 3.0 5.2 2.0 virginica
148 6.2 3.4 5.4 2.3 virginica
149 5.9 3.0 5.1 1.8 virginica
[150 rows x 5 columns]
X2, y2 = load_linnerud(return_X_y=True, as_frame=True)
y2 = y2['Pulse']
X2_train, X2_test, y2_train, y2_test = train_test_split(X2, y2, test_size=0.2, random_state=0)
print(X2, y2, sep='\n')
Chins Situps Jumps
0 5.0 162.0 60.0
1 2.0 110.0 60.0
2 12.0 101.0 101.0
3 12.0 105.0 37.0
4 13.0 155.0 58.0
5 4.0 101.0 42.0
6 8.0 101.0 38.0
7 6.0 125.0 40.0
8 15.0 200.0 40.0
9 17.0 251.0 250.0
10 17.0 120.0 38.0
11 13.0 210.0 115.0
12 14.0 215.0 105.0
13 1.0 50.0 50.0
14 6.0 70.0 31.0
15 12.0 210.0 120.0
16 4.0 60.0 25.0
17 11.0 230.0 80.0
18 15.0 225.0 73.0
19 2.0 110.0 43.0
0 50.0
1 52.0
2 58.0
3 62.0
4 46.0
5 56.0
6 56.0
7 60.0
8 74.0
9 56.0
10 50.0
11 52.0
12 64.0
13 50.0
14 46.0
15 62.0
16 54.0
17 52.0
18 54.0
19 68.0
Name: Pulse, dtype: float64
Обучение моделей и оценка полученных результатов
В случае классификации дерево на данных ирис показало высокую точность. После прунинга точность не увеличилась, зато удалось найти более оптимальное дерево (alpha=0.0143) с такой же точностью и меньшим количеством узлов.
А вот в случае регрессии удалось получить прирост в плане точности, подобрав alpha=3.613, которое создаёт дерево с минимальной ошибкой на тестовом наборе. Все результаты приведены ниже.
Классификация до прунинга
tree_classifier = DecisionTreeCART()
tree_classifier.fit(X1_train, y1_train)
clf_ccp_alphas, _ = tree_classifier.cost_complexity_pruning_path(X1_train, y1_train)
clf_ccp_alphas = clf_ccp_alphas[:-1]
sk_tree_classifier = DecisionTreeClassifier(random_state=0)
sk_tree_classifier.fit(X1_train, y1_train)
sk_clf_path = sk_tree_classifier.cost_complexity_pruning_path(X1_train, y1_train)
sk_clf_ccp_alphas = sk_clf_path.ccp_alphas[:-1]
sk_clf_estimator = DecisionTreeClassifier(random_state=0)
train1_data, test1_data = [X1_train, y1_train], [X1_test, y1_test]
metric = accuracy_score
labels = ['Alpha', 'Accuracy']
pprint(tree_classifier.tree, width=180)
tree_plot(sk_tree_classifier, X1_train)
print(f'tree alphas: {clf_ccp_alphas}', f'sklearn alphas: {sk_clf_ccp_alphas}', sep='\n')
tree_scores_plot(sk_clf_estimator, clf_ccp_alphas, train1_data, test1_data, metric, labels)
{'petal_width <= 0.75 | as_leaf virginica error_rate 0.6643083900226757': ['setosa | error_rate 0.0',
{'petal_length <= 4.95 | as_leaf virginica error_rate 0.33480885311871234': [{'petal_width <= 1.65 | as_leaf versicolor error_rate 0.0521008403361345': ['versicolor '
'| '
'error_rate '
'0.0',
{'sepal_width <= 3.1 | as_leaf virginica error_rate 0.014285714285714287': ['virginica '
'| '
'error_rate '
'0.0',
'versicolor '
'| '
'error_rate '
'0.0']}]},
{'petal_width <= 1.75 | as_leaf virginica error_rate 0.018532818532818515': [{'petal_width <= 1.65 | as_leaf virginica error_rate 0.014285714285714287': ['virginica '
'| '
'error_rate '
'0.0',
'versicolor '
'| '
'error_rate '
'0.0']},
'virginica '
'| '
'error_rate '
'0.0']}]}]}
tree alphas: [0. 0.00926641 0.01428571 0.03781513 0.26417519]
sklearn alphas: [0. 0.00926641 0.01428571 0.03781513 0.26417519]
Классификация после прунинга
tree_clf_prediction = tree_classifier.predict(X1_test)
tree_clf_accuracy = accuracy_score(y1_test, tree_clf_prediction)
sk_tree_clf_prediction = sk_tree_classifier.predict(X1_test)
sk_clf_accuracy = accuracy_score(y1_test, sk_tree_clf_prediction)
best_clf_ccp_alpha = 0.0143 # based on a plot
best_tree_classifier = DecisionTreeCART(ccp_alpha=best_clf_ccp_alpha)
best_tree_classifier.fit(X1_train, y1_train)
best_tree_clf_prediction = best_tree_classifier.predict(X1_test)
best_tree_clf_accuracy = accuracy_score(y1_test, best_tree_clf_prediction)
best_sk_tree_classifier = DecisionTreeClassifier(random_state=0, ccp_alpha=best_clf_ccp_alpha)
best_sk_tree_classifier.fit(X1_train, y1_train)
best_sk_tree_clf_prediction = best_sk_tree_classifier.predict(X1_test)
best_sk_clf_accuracy = accuracy_score(y1_test, best_sk_tree_clf_prediction)
print('tree prediction', tree_clf_prediction, ' ', sep='\n')
print('sklearn prediction', sk_tree_clf_prediction, ' ', sep='\n')
print('best tree prediction', best_tree_clf_prediction, ' ', sep='\n')
print('best sklearn prediction', best_sk_tree_clf_prediction, ' ', sep='\n')
pprint(best_tree_classifier.tree, width=180)
tree_plot(best_sk_tree_classifier, X1_train)
print(f'our tree pruning accuracy: before {tree_clf_accuracy} -> after {best_tree_clf_accuracy}')
print(f'sklearn tree pruning accuracy: before {sk_clf_accuracy} -> after {best_sk_clf_accuracy}')
tree prediction
['virginica' 'versicolor' 'setosa' 'virginica' 'setosa' 'virginica'
'setosa' 'versicolor' 'versicolor' 'versicolor' 'virginica' 'versicolor'
'versicolor' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'
'setosa' 'setosa' 'virginica' 'versicolor' 'setosa' 'setosa' 'virginica'
'setosa' 'setosa' 'versicolor' 'versicolor' 'setosa' 'virginica'
'versicolor' 'setosa' 'virginica' 'virginica' 'versicolor' 'setosa'
'virginica' 'versicolor' 'versicolor' 'virginica' 'setosa' 'virginica'
'setosa' 'setosa']
sklearn prediction
['virginica' 'versicolor' 'setosa' 'virginica' 'setosa' 'virginica'
'setosa' 'versicolor' 'versicolor' 'versicolor' 'virginica' 'versicolor'
'versicolor' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'
'setosa' 'setosa' 'virginica' 'versicolor' 'setosa' 'setosa' 'virginica'
'setosa' 'setosa' 'versicolor' 'versicolor' 'setosa' 'virginica'
'versicolor' 'setosa' 'virginica' 'virginica' 'versicolor' 'setosa'
'virginica' 'versicolor' 'versicolor' 'virginica' 'setosa' 'virginica'
'setosa' 'setosa']
best tree prediction
['virginica' 'versicolor' 'setosa' 'virginica' 'setosa' 'virginica'
'setosa' 'versicolor' 'versicolor' 'versicolor' 'virginica' 'versicolor'
'versicolor' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'
'setosa' 'setosa' 'virginica' 'versicolor' 'setosa' 'setosa' 'virginica'
'setosa' 'setosa' 'versicolor' 'versicolor' 'setosa' 'virginica'
'versicolor' 'setosa' 'virginica' 'virginica' 'versicolor' 'setosa'
'virginica' 'versicolor' 'versicolor' 'virginica' 'setosa' 'virginica'
'setosa' 'setosa']
best sklearn prediction
['virginica' 'versicolor' 'setosa' 'virginica' 'setosa' 'virginica'
'setosa' 'versicolor' 'versicolor' 'versicolor' 'virginica' 'versicolor'
'versicolor' 'versicolor' 'versicolor' 'setosa' 'versicolor' 'versicolor'
'setosa' 'setosa' 'virginica' 'versicolor' 'setosa' 'setosa' 'virginica'
'setosa' 'setosa' 'versicolor' 'versicolor' 'setosa' 'virginica'
'versicolor' 'setosa' 'virginica' 'virginica' 'versicolor' 'setosa'
'virginica' 'versicolor' 'versicolor' 'virginica' 'setosa' 'virginica'
'setosa' 'setosa']
{'petal_width <= 0.75 | as_leaf virginica error_rate 0.6643083900226757': ['setosa | error_rate 0.0',
{'petal_length <= 4.95 | as_leaf virginica error_rate 0.33480885311871234': [{'petal_width <= 1.65 | as_leaf versicolor error_rate 0.0521008403361345': ['versicolor '
'| '
'error_rate '
'0.0',
'virginica '
'error_rate '
'0.014285714285714287']},
'virginica error_rate '
'0.018532818532818515']}]}
our tree pruning accuracy: before 0.9777777777777777 -> after 0.9777777777777777
sklearn tree pruning accuracy: before 0.9777777777777777 -> after 0.9777777777777777
Визуализация решающих границ до и после прунинга
feature_indexes = [2, 3]
title1 = 'Classification tree surface before pruning'
decision_boundary_plot(X1, y1, X1_train, y1_train, sk_tree_classifier, feature_indexes, title1)
feature_indexes = [2, 3]
title2 = 'Classification tree surface after pruning'
decision_boundary_plot(X1, y1, X1_train, y1_train, best_sk_tree_classifier, feature_indexes, title2)
feature_indexes = [2, 3]
plt.figure(figsize=(10, 15))
for i, alpha in enumerate(clf_ccp_alphas):
sk_tree_clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha)
plt.subplot(3, 2, i + 1)
plt.subplots_adjust(hspace=0.5)
title = f'ccp_alpha = {alpha}'
decision_boundary_plot(X1, y1, X1_train, y1_train, sk_tree_clf, feature_indexes, title)
Регрессия до прунинга
tree_regressor = DecisionTreeCART(regression=True)
tree_regressor.fit(X2_train, y2_train)
reg_ccp_alphas, _ = tree_regressor.cost_complexity_pruning_path(X2_train, y2_train)
reg_ccp_alphas = reg_ccp_alphas[:-1]
sk_tree_regressor = DecisionTreeRegressor(random_state=0)
sk_tree_regressor.fit(X2_train, y2_train)
sk_reg_path = sk_tree_regressor.cost_complexity_pruning_path(X2_train, y2_train)
sk_reg_ccp_alphas = sk_reg_path.ccp_alphas[:-1]
reg_estimator = DecisionTreeCART(regression=True)
sk_reg_estimator = DecisionTreeRegressor(random_state=0)
train2_data, test2_data = [X2_train, y2_train], [X2_test, y2_test]
metric = mean_absolute_percentage_error
labels = ['Alpha', 'MAPE']
pprint(tree_regressor.tree)
tree_plot(sk_tree_regressor, X2_train)
print(f'CART alphas: {reg_ccp_alphas}')
tree_scores_plot(reg_estimator, reg_ccp_alphas, train2_data, test2_data, metric, labels)
print(f'sklearn_alphas: {sk_reg_ccp_alphas}')
tree_scores_plot(sk_reg_estimator, sk_reg_ccp_alphas, train2_data, test2_data, metric, labels)
{'Jumps <= 90.5 | as_leaf 54.625 error_rate 29.359375': [{'Jumps <= 46.0 | as_leaf 52.90909090909091 error_rate 17.181818181818183': [{'Jumps <= 34.0 | as_leaf 54.857142857142854 error_rate 11.428571428571429': [{'Jumps <= 28.0 | as_leaf 50.0 error_rate 2.0': ['54.0 '
'| '
'error_rate '
'0.0',
'46.0 '
'| '
'error_rate '
'0.0']},
{'Chins <= 14.5 | as_leaf 56.8 error_rate 5.3': [{'Situps <= 103.0 | as_leaf 58.5 error_rate 1.6875': ['56.0 '
'| '
'error_rate '
'0.0',
{'Jumps <= 38.5 | as_leaf 61.0 error_rate 0.125': ['62.0 '
'| '
'error_rate '
'0.0',
'60.0 '
'| '
'error_rate '
'0.0']}]},
'50.0 '
'| '
'error_rate '
'0.0']}]},
{'Chins <= 12.0 | as_leaf 49.5 error_rate 1.1875': [{'Jumps <= 70.0 | as_leaf 50.666666666666664 error_rate 0.16666666666666666': ['50.0 '
'| '
'error_rate '
'0.0',
'52.0 '
'| '
'error_rate '
'0.0']},
'46.0 '
'| '
'error_rate '
'0.0']}]},
{'Jumps <= 110.0 | as_leaf 58.4 error_rate 5.7': [{'Jumps <= 103.0 | as_leaf 61.0 error_rate 1.125': ['58.0 '
'| '
'error_rate '
'0.0',
'64.0 '
'| '
'error_rate '
'0.0']},
{'Chins <= 12.5 | as_leaf 56.666666666666664 error_rate 3.1666666666666665': ['62.0 '
'| '
'error_rate '
'0.0',
{'Jumps <= 182.5 | as_leaf 54.0 error_rate 0.5': ['52.0 '
'| '
'error_rate '
'0.0',
'56.0 '
'| '
'error_rate '
'0.0']}]}]}]}
CART alphas: [0. 0.125 0.16666667 0.5 1.02083333 1.125
1.5625 2. 2.0375 3.6125 4.12857143 4.56574675]
sklearn_alphas: [0. 0.125 0.16666667 0.5 1.02083333 1.125
1.5625 2. 2.0375 3.6125 4.12857143 4.56574675]
Регрессия после прунинга
tree_reg_prediction = tree_regressor.predict(X2_test)
tree_reg_error = mean_absolute_percentage_error(y2_test, tree_reg_prediction)
sk_tree_reg_prediction = sk_tree_regressor.predict(X2_test)
sk_reg_error= mean_absolute_percentage_error(y2_test, sk_tree_reg_prediction)
best_reg_ccp_alpha = 3.613 # based on a plot
best_tree_regressor = DecisionTreeCART(ccp_alpha=best_reg_ccp_alpha, regression=True)
best_tree_regressor.fit(X2_train, y2_train)
best_tree_reg_prediction = best_tree_regressor.predict(X2_test)
lowest_tree_reg_error = mean_absolute_percentage_error(y2_test, best_tree_reg_prediction)
best_sk_tree_regressor = DecisionTreeRegressor(random_state=0, ccp_alpha=best_reg_ccp_alpha)
best_sk_tree_regressor.fit(X2_train, y2_train)
best_sk_tree_reg_prediction = best_sk_tree_regressor.predict(X2_test)
lowest_sk_reg_error = mean_absolute_percentage_error(y2_test, best_sk_tree_reg_prediction)
print('tree prediction', tree_reg_prediction, ' ', sep='\n')
print('sklearn prediction', sk_tree_reg_prediction, ' ', sep='\n')
print('best tree prediction', best_tree_reg_prediction, ' ', sep='\n')
print('best sklearn prediction', best_sk_tree_reg_prediction, ' ', sep='\n')
pprint(best_tree_regressor.tree)
tree_plot(best_sk_tree_regressor, X2_train)
print(f'tree error: before {tree_reg_error} -> after pruning {lowest_tree_reg_error}')
print(f'sklearn tree error: before {sk_reg_error} -> after pruning {lowest_sk_reg_error}')
tree prediction
[46. 50. 60. 50.]
sklearn prediction
[46. 50. 60. 50.]
best tree prediction
[49.5 49.5 56.8 56.8]
best sklearn prediction
[49.5 49.5 56.8 56.8]
{'Jumps <= 90.5 | as_leaf 54.625 error_rate 29.359375': [{'Jumps <= 46.0 | as_leaf 52.90909090909091 error_rate 17.181818181818183': [{'Jumps <= 34.0 | as_leaf 54.857142857142854 error_rate 11.428571428571429': ['50.0 '
'error_rate '
'2.0',
'56.8 '
'error_rate '
'5.3']},
'49.5 '
'error_rate '
'1.1875']},
'58.4 error_rate 5.7']}
tree error: before 0.1571452674393851 -> after pruning 0.1321371427989075
sklearn tree error: before 0.1571452674393851 -> after pruning 0.13213714279890754
Преимущества и недостатки дерева решений
Преимущества:
простота в интерпретации и визуализации;
неплохая работа с нелинейными зависимостями в данных;
не требуется особой подготовки тренировочного набора;
относительно высокая скорость обучения и прогнозирования.
Недостатки:
поиск оптимального дерева является NP-полной задачей;
нестабильность работы даже при небольшом изменении данных;
возможность переобучения из-за чувствительности к шуму и выбросам в данных.
Дополнительные источники
Статья «The CART Decision Tree for Mining Data Streams», Leszek Rutkowskia, Maciej Jaworskia, Lena Pietruczuka, Piotr Dudaa.
Документация:
Лекции: один, два, три, четыре.
Пошаговое построение дерева решений: один, два, три.
Pruning: