Привет, Хабр!

В мире медицинского Machine Learning сейчас доминируют англоязычные открытые решения (базирующиеся в основном на датасетах вроде MIMIC-CXR или CheXpert). Если вы хотите развернуть локальную мультимодальную (Vision-Language) модель, которая будет генерировать медицинские репорты по рентгену на русском языке, вы столкнетесь с полным вакуумом.

В этой статье я расскажу о своем пет-проекте: как я с нуля собрал и обучил архитектуру VisionEncoderDecoder, используя "глаза" от Google и "мозг" от Сбера, как решал проблемы с датасетами на Kaggle и почему Seq2SeqTrainer от Hugging Face крашится при сохранении чекпоинтов.

<cut /> 1. Выбор архитектуры: хирургия конфигураций

Обучать мультимодальную модель с нуля — задача для корпораций. Я пошел по пути объединения двух предобученных моделей с помощью VisionEncoderDecoderModel от Hugging Face.

  • Encoder (Зрение): google/vit-base-patch16-224-in21k. Отличный экстрактор визуальных фичей, разбивающий картинку на патчи.

  • Decoder (Текст): ai-forever/rugpt3small_based_on_gpt2. Компактная русскоязычная генеративная модель.

Проблема: ruGPT-3 — это классическая каузальная языковая модель. У неё нет слоев кросс-внимания (Cross-Attention), чтобы принимать скрытые состояния от энкодера.

Решение: Пришлось модифицировать конфигурацию на лету перед сборкой модели:

code Python

downloadcontent_copy

expand_less

from transformers import AutoConfig, AutoModelForCausalLM, AutoModel

# Загружаем визуальный энкодер
encoder = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k")

# Взламываем конфиг текстового декодера
decoder_config = AutoConfig.from_pretrained("ai-forever/rugpt3small_based_on_gpt2")
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True # Заставляем модель отрастить новые слои

decoder = AutoModelForCausalLM.from_pretrained("ai-forever/rugpt3small_based_on_gpt2", config=decoder_config)

После сборки Hugging Face инициализирует пустые веса для новых слоев кросс-внимания. Именно их, а также проекционные слои, нам и предстояло обучить.

2. Data Engineering: как Kaggle прячет файлы

Где взять русскоязычный датасет пар "Рентген -> Текст"? Нигде.
Я взял открытый американский датасет Indiana University Chest X-Ray (IU X-Ray) (около 7500 снимков и репортов). Прямо в ноутбуке Kaggle я поднял пайплайн на базе Helsinki-NLP/opus-mt-en-ru и батчами перевел медицинские заключения (раздел findings) на русский язык.

Но при написании кастомного PyTorch Dataset вылез классический Kaggle-баг. Стандартные скрипты маппинга видели только 4 картинки-превьюшки из 7000. Платформа раскидала реальные файлы по скрытым директориям.
Пришлось писать хардкорный Deep Scan Mapping:

code Python

downloadcontent_copy

expand_less

# Тотальное сканирование диска Kaggle
file_map = {}
for root, dirs, files in os.walk('/kaggle/input'):
    for file in files:
        if file.lower().endswith(('.png', '.jpg', '.jpeg')):
            base_name = os.path.splitext(file)[0]
            file_map[base_name] = os.path.join(root, file)

Только после привязки CSV-идентификаторов к этому словарю file_map конвейер данных заработал стабильно.

3. Обучение: обход крашей и защита от таймаута

Обучение проходило на 2x NVIDIA T4. Чтобы влезть в память, я использовал смешанную точность (fp16=True) и gradient_accumulation_steps=4 (при batch_size=4 на устройство мы получали виртуальный батч 32).

Грабли Hugging Face:
На 100-м шаге обучение стабильно падало с ошибкой внутри savecheckpoint. Оказалось, что Seq2SeqTrainer в последних версиях transformers сходит с ума, пытаясь автоматически сохранить композитную модель с неродными генеративными параметрами в процессе обучения.
Фикс: Отключить промежуточные сохранения (save_strategy="no") и перенести параметры генерации в отдельный объект GenerationConfig перед финальным save_pretrained().

А чтобы Kaggle не убил сессию по таймауту неактивности (пока я пил кофе 2.5 часа, ожидая завершения 15 эпох), в консоль браузера был заботливо отправлен скрипт:

code JavaScript

downloadcontent_copy

expand_less

function KeepClicking(){ document.querySelector("body").click() }
setInterval(KeepClicking, 60000);

4. Результаты и Инференс

Модель училась всего 15 эпох, но результаты меня приятно удивили. Она научилась профессиональному медицинскому сленгу и уверенно распознает базовые паттерны: чистые легкие, пневмоторакс, кардиомегалию.

Конечно, из-за маленького датасета присутствуют галлюцинации (может "найти" катетер там, где его нет), но архитектура (Vision + Language) доказала свою жизнеспособность.

Инференс (как это запустить):

code Python

downloadcontent_copy

expand_less

import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image

model_id = "livadies/Russian-Radiologist-ruGPT-ViT"
model = VisionEncoderDecoderModel.from_pretrained(model_id)
feature_extractor = ViTImageProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

image = Image.open("xray.jpg").convert("RGB")
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values, max_length=128, num_beams=4)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

Для тех, кто хочет потыкать модель руками без написания кода, я развернул Gradio-приложение на бесплатных серверах Hugging Face:
👉 Live Demo: AI-Radiologist-RU Space (предупреждение: работает на CPU, ответ генерируется 10-15 секунд).

Весь исходный код, пайплайны подготовки данных и ноутбуки доступны в моем репозитории:
👉 GitHub: Multimodal-XRay-Analyzer-RU

Буду рад код-ревью, советам по улучшению датасета и конструктивной критике в комментариях!

(Дисклеймер: Модель создана исключительно в исследовательских целях и не является заменой реального врача).