Как стать автором
Обновить

BERT для классификации русскоязычных текстов

Время на прочтение6 мин
Количество просмотров24K

Зачем

В интернете полно прекрасных статей про BERT. Но часто они слишком подробны для человека, который хочет просто дообучить модель для своей задачи. Данный туториал поможет максимально быстро и просто зафайнтюнить русскоязычный BERT для задачи классификации. Полный код и описание доступны в репозитории на github, есть возможность запустить все в google colab одной кнопкой.

Workflow

  1. Данные для обучения

  2. Модель

  3. Helpers

  4. Train

  5. Inference

Данные для обучения

Для обучения использовались очищенные данные русскоязычного твиттера из датасета RuTweetCorp. Данные размечены на 2 класса:

  • '0' - негативные

  • '1' - позитивные

Для упрощения работы используется кастомизированный класс Dataset:

from torch.utils.data import Dataset

class CustomDataset(Dataset):

  def __init__(self, texts, targets, tokenizer, max_len=512):
    self.texts = texts
    self.targets = targets
    self.tokenizer = tokenizer
    self.max_len = max_len

  def __len__(self):
    return len(self.texts)

  def __getitem__(self, idx):
    text = str(self.texts[idx])
    target = self.targets[idx]

    encoding = self.tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=self.max_len,
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )

    return {
      'text': text,
      'input_ids': encoding['input_ids'].flatten(),
      'attention_mask': encoding['attention_mask'].flatten(),
      'targets': torch.tensor(target, dtype=torch.long)
    }

Стандартный класс расширяется методами __init__, __len__, __getitem__. В методе __init__ инициализируем тексты, метки, максимальную дину текста в токенах, а так же токенайзер. Токенайзер загружаем из репозитория huggingface rubert-tiny. Для загрузки модели используем команду:

from transformers import BertTokenizer
tokenizer_path = 'cointegrated/rubert-tiny'
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

Метод len возвращает длину нашего датасета. Метод getitem возвращает словарь, который состоит из самого исходного текста, списка токенов, маски внимания, а также метки класса. Отдельно хочется остановить на настройках токенизатора с помощью метода .encode_plus(). В этом методе мы указываем токенизатору, что исходный текст нужно обрамлять служебными токенами add_special_tokens=True, а также дополнять полученные векторы до максимально длины padding='max_len'.

Модель

Используется русскоязычная модель BERT из репозитория huggingface rubert-tiny. Для загрузки модели используем команду:

from transformers import BertForSequenceClassification
model_path = 'cointegrated/rubert-tiny'
model = BertForSequenceClassification.from_pretrained(model_path)

Для классификации необходимо добавить полносвязный слой, количество входов которого — внутренняя размерность эмбеддинга сети, а выход - число классов для классификации. В нашем случае классификация у нас происходит на 2 класса, а внутреннюю размерность можно получить,выполнив следующую команду:

out_features = model.bert.encoder.layer[1].output.dense.out_features

В нашем случае размерность равна 312. Конфигурируем полносвязный слой:

model.classifier = torch.nn.Linear(312, 2)
Инициализация класса выглядит следующим образом:
class BertClassifier:

    def __init__(self, model_path, tokenizer_path, n_classes=2, epochs=1, model_save_path='/content/bert.pt'):
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model_save_path=model_save_path
        self.max_len = 512
        self.epochs = epochs
        self.out_features = self.model.bert.encoder.layer[1].output.dense.out_features
        self.model.classifier = torch.nn.Linear(self.out_features, n_classes)
        self.model.to(self.device)

Helpers

Для работы нам необходимо инициализировать вспомогательные элементы.

DataLoader

Используется для формирования батчей. В качестве входных параметров использует кастомный датасет, описанный ранее, а также количество сэмплов в батче.

from torch.utils.data DataLoader
train_set = CustomDataset(X_train, y_train, tokenizer)
train_loader = DataLoader(train_set, batch_size=2, shuffle=True)

Optimizer

Оптимизатор градиентного спуска. В качестве входных параметров передаем параметры нашей модели model.parameters(), а так же скорость обучения lr.

from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)

Scheduler

Планировщик, нужен для настройки параметров оптимизатора во время обучения. В качестве входных параметров передаем оптимизатор, а так же общее количество шагов для обучения, которое равно произведению количества батчей тренировочной выборки на количество эпох обучения:

from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=0,
                num_training_steps=len(train_loader) * epochs
            )

Loss

Функция потерь, считаем по ней ошибку модели:

loss_fn = torch.nn.CrossEntropyLoss()
Функция инициализации хэлперов:
def preparation(self, X_train, y_train, X_valid, y_valid):
    # create datasets
    self.train_set = CustomDataset(X_train, y_train, self.tokenizer)
    self.valid_set = CustomDataset(X_valid, y_valid, self.tokenizer)

    # create data loaders
    self.train_loader = DataLoader(self.train_set, batch_size=2, shuffle=True)
    self.valid_loader = DataLoader(self.valid_set, batch_size=2, shuffle=True)

    # helpers initialization
    self.optimizer = AdamW(self.model.parameters(), lr=2e-5, correct_bias=False)
    self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=0,
            num_training_steps=len(self.train_loader) * self.epochs
        )
    self.loss_fn = torch.nn.CrossEntropyLoss().to(self.device)

Train

Обучение для одной эпохи:
def fit(self):
    self.model = self.model.train()
    losses = []
    correct_predictions = 0

    for data in self.train_loader:
        input_ids = data["input_ids"].to(self.device)
        attention_mask = data["attention_mask"].to(self.device)
        targets = data["targets"].to(self.device)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask
            )

        preds = torch.argmax(outputs.logits, dim=1)
        loss = self.loss_fn(outputs.logits, targets)

        correct_predictions += torch.sum(preds == targets)

        losses.append(loss.item())
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()

    train_acc = correct_predictions.double() / len(self.train_set)
    train_loss = np.mean(losses)
    return train_acc, train_loss

Данные в цикле батчами генерируются с помощью DataLoader:

for data in self.train_loader:
    input_ids = data["input_ids"].to(self.device)
    attention_mask = data["attention_mask"].to(self.device)
    targets = data["targets"].to(self.device)

Батч подается в модель:

outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask
    )

На выходе получаем распределение вероятности по классам и значение ошибки:

preds = torch.argmax(outputs.logits, dim=1)
loss = self.loss_fn(outputs.logits, targets)

Делаем шаг на всех вспомогательных функциях:

  • loss.backward(): обратное распространение ошибки;

  • clip_grad_norm(): обрезаем градиенты для предотвращения "взрыва" градиентов;

  • optimizer.step(): шаг оптимизатора;

  • scheduler.step(): шаг планировщика;

  • optimizer.zero_grad(): обнуляем градиенты.

Код метода eval:
def eval(self):
    self.model = self.model.eval()
    losses = []
    correct_predictions = 0

    with torch.no_grad():
        for data in self.valid_loader:
            input_ids = data["input_ids"].to(self.device)
            attention_mask = data["attention_mask"].to(self.device)
            targets = data["targets"].to(self.device)

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
                )

            preds = torch.argmax(outputs.logits, dim=1)
            loss = self.loss_fn(outputs.logits, targets)
            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())
    
    val_acc = correct_predictions.double() / len(self.valid_set)
    val_loss = np.mean(losses)
    return val_acc, val_loss

Для обучения на нескольких эпохах используется метод train, в котором последовательно вызываются методы fit и eval.

Код метода train:
def train(self):
    best_accuracy = 0
    for epoch in range(self.epochs):
        print(f'Epoch {epoch + 1}/{self.epochs}')
        train_acc, train_loss = self.fit()
        print(f'Train loss {train_loss} accuracy {train_acc}')

        val_acc, val_loss = self.eval()
        print(f'Val loss {val_loss} accuracy {val_acc}')
        print('-' * 10)

        if val_acc > best_accuracy:
            torch.save(self.model, self.model_save_path)
            best_accuracy = val_acc

    self.model = torch.load(self.model_save_path)

Inference

Для предсказания класса для нового текста используется метод predict, который имеет смысл вызывать только после обучения модели. Метод работает следующим образом:

  • Токенизируется входной текст;

  • Токенизированный текст подается в модель;

  • На выходе получаем вероятности классов;

  • Возвращаем метку наиболее вероятного класса.

Код метода predict:
def predict(self, text):
    encoding = self.tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=self.max_len,
        return_token_type_ids=False,
        truncation=True,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )
    
    out = {
          'text': text,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten()
      }
    
    input_ids = out["input_ids"].to(self.device)
    attention_mask = out["attention_mask"].to(self.device)
    
    outputs = self.model(
        input_ids=input_ids.unsqueeze(0),
        attention_mask=attention_mask.unsqueeze(0)
    )
    
    prediction = torch.argmax(outputs.logits, dim=1).cpu().numpy()[0]

    return prediction

Ссылки

  1. Репозиторий на гитхаб

  2. RuTweetCorp

  3. ruBert tiny

Заключение

Хотелось максимально просто и кратко, но все равно получилось как-то объемно. Замечания, исправления и дополнения приветствуются!

Теги:
Хабы:
Всего голосов 10: ↑10 и ↓0+10
Комментарии6

Публикации