Pull to refresh
VK
Building the Internet

Повышаем интерпретируемость SHAP-графиков

Level of difficultyMedium
Reading time8 min
Views6.6K

Привет, Хабр! В моей работе часто возникают задачи на исследование влияния факторов, на которые мы можем оказывать продуктовое влияние, на целевые метрики сообществ ВКонтакте. Один из возможных способов решения подобных задач — обучение ML‑моделей и последующий анализ значимости признаков в них. Базовым подходом видится использование графиков из библиотеки shap. Однако наиболее популярным является summary_plot, хотя он и повышает интерпретируемость модели, но отвечает не на все возникающие вопросы.

Меня зовут Сергей Королёв, я продуктовый аналитик в бизнес‑юните СМБ в VK, занимаюсь улучшением опыта предпринимателей на нашей платформе. В этой статье я представлю свое решение по кастомизации shap.dependence_plot для простого восприятия графиков влияния факторов на целевую метрику.

Недостатки summary_plot

Рассмотрим summary_plot на примере датасета о выживаемости пассажиров «Титаника». В качестве модели используем catboost с минимальной предобработкой признаков и оставим только количественные и категориальные признаки с ограниченным количеством категорий.

Код обработки данных и обучения модели:

titanic_train = pd.read_csv('titanic/train.csv')

titanic_train['Age'] = titanic_train.Age.fillna(titanic_train.Age.dropna().median())
titanic_train['Embarked'] = titanic_train.Embarked.fillna(titanic_train.Embarked.dropna().mode()[0])

X = titanic_train.drop(columns=['Survived', 'Cabin', 'Ticket', 'PassengerId', 'Name'])
y = titanic_train.Survived

scale_pos_weight = int(y.value_counts()[0] / y.value_counts()[1])
model = CatBoostClassifier(
    subsample=0.66, rsm=0.5, depth=3, cat_features=['Sex', 'Embarked'],
    random_seed=42, verbose=False, scale_pos_weight=scale_pos_weight
)
model.fit(X, y)

Построенная модель обучена на признаках:

  • Pclass — класс, которым путешествовал пассажир;

  • Sex — пол пассажира;

  • Age — возраст пассажира;

  • SibSp — количество братьев и сестер или супругов на борту;

  • Parch — количество родителей или детей на борту;

  • Fare — стоимость билета;

  • Embarked — порт посадки.

Их значимость для модели можно отобразить с помощью кода:

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
shap.summary_plot(shap_values, X);

По графику можно понять, что пол пассажира и класс являются наиболее важными признаками. И заметить, что пассажиры с большими значениями класса, соответствующими более низким классам обслуживания, по мнению модели имели меньше шансов выжить. Но есть ряд вопросов, которые важны при анализе влияния признаков на целевую переменную, и которые нельзя понять по этому графику:

  1. Какие значения категориальных признаков и как влияют на решения, принимаемые моделью?

  2. Какие абсолютные значения количественных признаков являются высокими, а какие низкими?

  3. Сколько объектов имеют определённое значение признака?

  4. Есть ли доминирующие признаки, или их влияние на решение модели сопоставимо?

На все эти вопросы можно ответить, кастомизировав другой тип графиков, доступный в библиотеке SHAP — dependence_plot.

Обход недостатков summary_plot

Влияние конкретных значений категориальных и количественных признаков на решения, принимаемые моделью, можно увидеть, используя dependence_plot, который строится для каждого признака отдельно. На графиках по оси Х выводится значение признака, а по оси Y — значение SHAP для него.

Код dependence plot:

shap.dependence_plot('Sex', shap_values, X, interaction_index=None)
shap.dependence_plot('Age', shap_values, X, interaction_index=None)

Для категориальных признаков график выглядит так:

Для количественных переменных:

Для наглядности можно добавить сетку и ярче выделить нулевую линию, отделяющую позитивное влияние от негативного.

Чтобы по графику можно было понять, какие значения признака встречаются в данных чаще, для количественных признаков можно использовать график типа scatter из библиотеки shap, который совмещает dependence_plot с гистограммой распределения признака:

shap_explainer = explainer(X)
shap.plots.scatter(shap_explainer[:, 'Age']);

Однако такой вид графиков не поддерживает категориальные переменные. Поэтому более универсальным решением будет использование параметра прозрачности alpha для dependence_plot, за счет снижения которого более часто встречающиеся значения будут отображаться на графике интенсивнее. Код dependence plot с настройкой прозрачности и сеткой:

fig, ax = plt.subplots();
ax.grid();
plt.axhline(y=0, color='grey', linewidth=2.5);
shap.dependence_plot('Age', shap_values, X, interaction_index=None, alpha=0.2, ax=ax);

Последним важным для меня недостатком summary_plot стало то, что он по умолчанию сортируется по среднему вкладу признаков в решения модели, но не отображает этот вклад. Чтобы оценить его, можно воспользоваться отдельным графиком типа bar.

shap.plots.bar(shap_explainer)

Либо можно собрать dependence_plot по всем признакам на одном графике и вынести на них относительную значимость, рассчитанную по матрице с SHAP values.

После устранения с помощью dependence_plot недостатков, характерных для summary_plot, возникает ряд дополнительных проблем, снижающих удобство пользования графиком:

  1. Некоторые количественные признаки, с которыми мне приходится работать на практике, могут иметь распределение с длинным правым хвостом, близкое к логнормальному. Или вовсе иметь большую часть нулевых значений и длинный правый хвост. Поэтому хочется иметь возможность отсекать выбросы на графике, чтобы смотреть на основные зависимости в удобном масштабе. Для отсечения выбросов я использовал порог в 1,5 межквартильных размаха от первого и третьего квартилей. Либо, если межквартильный размах для признака равен 0, то по 2,5 % значений с каждой стороны.

  2. Категориальные признаки также могут содержать много категорий, одновременное отображение которых на одном графике сделает его нечитаемым. Как правило, среди этих категорий будет несколько частых и много редких. Поэтому хочется иметь фильтрацию, оставляющую ограниченное количество категорий с наибольшим количеством наблюдений.

При исследовании влияния признаков на целевую переменную также может возникнуть желание смотреть на те наблюдения, где ML-модель в своих предсказаниях близка к реальному значению целевой переменной. Поэтому полезно добавить отсечение задаваемого процента наименее точных предсказаний.

Для удобства использования все описанные идеи можно реализовать в рамках одной функции.

Итоговая функция отрисовки значимости признаков
def plot_shap_feature_importances(
    model, X_test, y_test, continuous_target, predictions_trashold, cat_features, 
    alpha, plots_folder, target_col
):
    '''
    :param model: Обученая ML-модель (на деревьях)
    :param X_test: Тестовый датафрейм с предикторами
    :param y_test: Тестовый датафрейм с таргетом
    :param continuous_target: Флаг непрерывной целевой переменной (регрессионная модель)
    :param predictions_trashold: Доля откидываемых наблюдений с наибольшими ошибками предсказания
    :param cat_features: Список категориальных фичей
    :param alpha: Непрозрачность заливки точек
    :param plots_folder: Путь к папке сохранения графиков
    :param target_col: Название целевой переменной для сохранения графиков
    '''

    # Отбрасываем заданный процент не точных предсказаний
    if continuous_target:
        y_preds = model.predict(X_test)
    else:
        y_preds = model.predict_proba(X_test)[:, 1]
    errors = np.abs(y_test - y_preds)
    mask = errors <= np.quantile(errors, 1 - predictions_trashold)
    X_test_for_plot = X_test[mask].copy()
 
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test_for_plot)
 
    # Очистка чтобы график не накладывался на предыдущий
    plt.close('all')
    shap.summary_plot(shap_values, X_test_for_plot, show=False);
    plt.savefig(f'{plots_folder}/shap_summary_{target_col}.png', bbox_inches='tight');
 
    # Смотрим относительное влияние фичей
    vals = np.abs(shap_values).mean(0)
    fi = pd.DataFrame(
        list(zip(X_test_for_plot.columns, vals)),
        columns=['features', 'importance']
    )
    fi.sort_values(by=['importance'], ascending=False, inplace=True)
    fi['importance'] = fi.importance / fi.importance.sum()
 
    # Отрисовываем отдельные графики
    subplots_number = len(X_test_for_plot.columns)
    nrows = int(np.ceil(subplots_number / 2))
    fig, axes = plt.subplots(
        nrows=nrows, ncols=[1, 2][int(subplots_number > 1)], 
        figsize=(12, nrows * 5), layout='constrained'
    )
 
    for idx, col in enumerate(zip(fi.features.to_list(), fi.importance.to_list())):
 
        if col[0] in cat_features:
            # Для категориальных фичей отрисовывем топ-10 категорий
            important_cats = list(X_test[col[0]].value_counts().index)
            important_cats = important_cats[:min(10, len(important_cats))]
            temp_df = X_test_for_plot[X_test_for_plot[col[0]].isin(important_cats)].copy()
        else:
            # Для количественных смотрим базовую стратегию отсечения по межквартильному интервалу
            q1 = X_test[col[0]].replace(np.inf, np.nan).dropna().quantile(0.25)
            q3 = X_test[col[0]].replace(np.inf, np.nan).dropna().quantile(0.75)
            iqr = q3 - q1
            min_trashold = max(
                q1 - 1.5 * iqr,
                X_test[col[0]].replace(np.inf, np.nan).dropna().min()
            )
            max_trashold = min(
                q3 + 1.5 * iqr,
                X_test[col[0]].replace(np.inf, np.nan).dropna().max()
            )
 
            if X_test[col[0]].nunique() <= 50:
                # Если в количественной фиче не более 50 значений, оставляем их все
                temp_df = X_test_for_plot.copy()
            elif iqr > 0:
                # Если есть межквартильный размах, отсекаем лежащее за перделеами 1.5 его величин
                temp_df = X_test_for_plot[
                    (X_test_for_plot[col[0]] >= min_trashold) &
                    (X_test_for_plot[col[0]] <= max_trashold)
                ].copy()
            else:
                # Если межквартильный размах равен 0, то отсекаем по 2.5% с каждой стороны
                q025 = X_test[col[0]].replace(np.inf, np.nan).dropna().quantile(0.025)
                q975 = X_test[col[0]].replace(np.inf, np.nan).dropna().quantile(0.975)
                temp_df = X_test_for_plot[
                    (X_test_for_plot[col[0]] >= q025) &
                    (X_test_for_plot[col[0]] <= q975)
                ].copy()
 
        # Смотрим SHAP только для релевантных наблюдений
        shap_values = explainer.shap_values(temp_df)
 
        # Определяем позицию графика для отрисовки текущей фичи
        if nrows > 1:
            ax = axes[idx // 2, idx % 2]
        elif subplots_number > 1:
            ax = axes[idx % 2]
        else:
            ax = axes
 
        # Отрисовываем график по фиче
        shap.dependence_plot(
            col[0], shap_values, temp_df, ax=ax,
            interaction_index=None, show=False, alpha=alpha
        )
        ax.grid();
        xlims = ax.get_xlim()
        ax.hlines(y=0, xmin=xlims[0], xmax=xlims[1], color='grey', linewidth=2.5)
        ax.set_title(f'{col[0]} ({col[1]:.1%})');
        if col in cat_features:
            ax.tick_params(axis='x', labelsize=6, labelrotation=90)
 
    # Удаляем не используемый график
    if (nrows > 1) and (subplots_number % 2 == 1):
        fig.delaxes(axes[nrows - 1][1])
     
    # Сохраняем полученный график
    plt.savefig(
        f'{plots_folder}/shap_dependencies_{target_col}.png', bbox_inches='tight'
    );

В качестве входных аргументов эта функция принимает:

  • обученную модель; 

  • тестовый набор предикторов и значения целевой переменной в тестовой выборке;

  • флаг непрерывности переменной для определения способа оценки ошибок предсказания; 

  • долю откидываемых наименее точных предсказаний; 

  • список категориальных признаков для определения способа отсечения выбросов по каждому из предикторов; 

  • параметр прозрачности; 

  • а также папку для сохранения графиков и общий модификатор названия для них.

По результатам работы данная функция сохраняет summary_plot и собранный график состоящий из dependence_plot для каждой фичи в указанную папку под указанным названием.

До и после

С графика с одной визуализацией мы начинали анализировать влияния признаков на целевую переменную.

И пришли к графику, объединяющему несколько подграфиков, на котором каждый признак вынесен отдельно. А каждый из подграфиков показывает влияние конкретных значений признака на решение модели и относительную значимость признака для решений модели. 

Заключение

После создания этой функции я стал использовать её при отрисовке графиков для анализа влияния признаков на целевые метрики в обученных ML-моделях. Её основные преимущества для меня по сравнению со стандартным summary_plot:

  • По одному графику можно понять важность признаков и их влияние на решения модели.

  • Проще визуально воспринимать пороговые значения, с которых начинается позитивное влияние на целевую переменную.

  • Проще донести интерпретацию модели до менеджеров.

Tags:
Hubs:
Total votes 43: ↑43 and ↓0+52
Comments0

Articles

Information

Website
vk.com
Registered
Founded
Employees
5,001–10,000 employees
Location
Россия
Representative
Миша Берггрен