Как стать автором
Обновить
765.57
OTUS
Цифровые навыки от ведущих экспертов

Кастомные loss-функции в TensorFlow/Keras и PyTorch

Уровень сложностиПростой
Время на прочтение5 мин
Количество просмотров965

Привет, Хабр!

Стандартные loss‑функции, такие как MSE или CrossEntropy, хороши, но часто им не хватает гибкости для сложных задач. Допустим, есть тот же проект с огромным дисбалансом классов, или хочется внедрить специфическую регуляризацию прямо в функцию потерь. Стандартный функционал тут бессилен — тут на помощь приходят кастомные loss'ы.

Custom Loss Functions в TensorFlow/Keras

TensorFlow/Keras радуют удобным API, но за простоту приходится платить вниманием к деталям.

Focal Loss

Focal Loss помогает сместить фокус обучения на сложные примеры, снижая влияние легко классифицируемых данных:

import tensorflow as tf
from tensorflow.keras import backend as K

def focal_loss(gamma=2., alpha=0.25):
    """
    Реализация Focal Loss для задач с дисбалансом классов.
    :param gamma: фокусирующий параметр для усиления влияния сложных примеров.
    :param alpha: коэффициент балансировки классов.
    :return: функция потерь, принимающая (y_true, y_pred).
    """
    def focal_loss_fixed(y_true, y_pred):
        # Защита от log(0) – обрезаем значения предсказаний.
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        # Вычисляем кросс-энтропию для каждого примера.
        cross_entropy = -y_true * tf.math.log(y_pred)
        # Применяем вес для "тяжёлых" примеров.
        weight = alpha * tf.pow(1 - y_pred, gamma)
        loss = weight * cross_entropy
        # Усредняем по батчу и классам.
        return tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
    return focal_loss_fixed

# Пример использования Focal Loss:
if __name__ == "__main__":
    # Тестовые данные для отладки (да, я тоже люблю маленькие эксперименты)
    y_true = tf.constant([[1, 0], [0, 1]], dtype=tf.float32)
    y_pred = tf.constant([[0.9, 0.1], [0.2, 0.8]], dtype=tf.float32)
    
    loss_fn = focal_loss(gamma=2.0, alpha=0.25)
    loss_value = loss_fn(y_true, y_pred)
    print("Focal Loss:", loss_value.numpy())

Интеграция кастомного loss в модель Keras

Создадим простую CNN‑модель для распознавания изображений и подключим Focal Loss:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

def create_model(input_shape=(28, 28, 1), num_classes=10):
    model = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(num_classes, activation='softmax')
    ])
    return model

# Компилируем модель с кастомной функцией потерь
model = create_model()
model.compile(optimizer='adam', loss=focal_loss(gamma=2.0, alpha=0.25), metrics=['accuracy'])

# Создадим тестовые данные (набор из случайных изображений и меток)
import numpy as np
X_train = np.random.rand(100, 28, 28, 1)
y_train = tf.keras.utils.to_categorical(np.random.randint(0, 10, 100), num_classes=10)

print("Запускаем обучение модели с кастомным Focal Loss...")
model.fit(X_train, y_train, epochs=3, batch_size=16)

Модель обучается и градиенты сходятся.

Нюансы вычисления градиентов

Нельзя забывать — любые операции, выполняемые с numpy, ломают автоматическое вычисление градиентов. Пример плохой практики:

import numpy as np
import tensorflow as tf

def loss_with_numpy(y_true, y_pred):
    # Плохая практика: переводим тензоры в numpy и разрываем градиентный поток.
    y_true_np = y_true.numpy()  # Ой-ой, ошибка внутри GradientTape!
    y_pred_np = y_pred.numpy()
    loss_np = np.mean((y_true_np - y_pred_np) ** 2)
    return tf.constant(loss_np, dtype=tf.float32)

if __name__ == "__main__":
    x = tf.constant([[1.0], [2.0]])
    y_true = tf.constant([[1.5], [2.5]])
    
    with tf.GradientTape() as tape:
        tape.watch(x)
        y_pred = x * 2
        try:
            loss = loss_with_numpy(y_true, y_pred)
            grad = tape.gradient(loss, x)
            print("Gradient:", grad)
        except Exception as e:
            print("Ошибка при вычислении градиента:", e)

Оставайтесь в мире тензоров — TensorFlow умеет всё, что нужно, если вы не решите подмешать туда numpy.

Custom Loss Functions в PyTorch

Реализация кастомной loss через torch.autograd.Function

Начнем с простейшей реализации кастомной loss‑функции, которая считает квадратичную ошибку:

import torch

class CustomLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target):
        """
        Прямой проход: вычисляем MSE.
        """
        ctx.save_for_backward(input, target)
        loss = torch.mean((input - target) ** 2)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Обратный проход: аккуратно считаем градиенты.
        """
        input, target = ctx.saved_tensors
        grad_input = grad_output * 2 * (input - target) / input.numel()
        return grad_input, None

# Тестовый пример использования:
if __name__ == "__main__":
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = torch.tensor([1.5, 2.5, 3.5])
    
    loss = CustomLossFunction.apply(x, y)
    print("Custom Loss (PyTorch):", loss.item())
    
    loss.backward()
    print("Gradient (PyTorch):", x.grad)

Focal Loss в PyTorch

Focal Loss существует не только в TensorFlow. В PyTorch можно сделать не хуже:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Если inputs – логиты, используем sigmoid для преобразования
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Тестируем Focal Loss в PyTorch:
if __name__ == "__main__":
    inputs = torch.tensor([[0.2, -1.0], [1.5, 0.3]], requires_grad=True)
    targets = torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)
    
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
    loss = criterion(inputs, targets)
    print("Focal Loss (PyTorch):", loss.item())
    
    loss.backward()
    print("Gradients (Focal Loss):", inputs.grad)

Работа с эмбеддингами

Для задач, где нужно сравнивать схожесть объектов, подойдут Contrastive и Triplet Loss. Реализуем их в PyTorch.

Contrastive Loss

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Евклидова дистанция между эмбеддингами
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

# Пример использования Contrastive Loss:
if __name__ == "__main__":
    output1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    output2 = torch.tensor([[1.5, 2.5], [2.5, 3.5]], requires_grad=True)
    # label: 0 для похожих пар, 1 для непохожих.
    label = torch.tensor([0, 1], dtype=torch.float32)
    
    criterion = ContrastiveLoss(margin=1.0)
    loss = criterion(output1, output2, label)
    print("Contrastive Loss:", loss.item())
    loss.backward()

Triplet Loss

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_distance = F.pairwise_distance(anchor, positive, p=2)
        neg_distance = F.pairwise_distance(anchor, negative, p=2)
        losses = torch.relu(pos_distance - neg_distance + self.margin)
        return losses.mean()

# Пример использования Triplet Loss:
if __name__ == "__main__":
    anchor = torch.tensor([[1.0, 2.0], [2.0, 3.0]], requires_grad=True)
    positive = torch.tensor([[1.1, 2.1], [1.9, 2.9]], requires_grad=True)
    negative = torch.tensor([[3.0, 4.0], [4.0, 5.0]], requires_grad=True)
    
    criterion = TripletLoss(margin=1.0)
    loss = criterion(anchor, positive, negative)
    print("Triplet Loss:", loss.item())
    loss.backward()

Если вам хочется поделиться опытом — пишите в комментариях.

Все актуальные методы и инструменты DS и ML можно освоить на онлайн-курсах OTUS: в каталоге можно посмотреть список всех программ, а в календаре — записаться на открытые уроки.

Теги:
Хабы:
+7
Комментарии0

Публикации

Информация

Сайт
otus.ru
Дата регистрации
Дата основания
Численность
101–200 человек
Местоположение
Россия
Представитель
OTUS

Истории