
Привет, Хабр!
В мире медицинского Machine Learning сейчас доминируют англоязычные открытые решения (базирующиеся в основном на датасетах вроде MIMIC-CXR или CheXpert). Если вы хотите развернуть локальную мультимодальную (Vision-Language) модель, которая будет генерировать медицинские репорты по рентгену на русском языке, вы столкнетесь с полным вакуумом.
В этой статье я расскажу о своем пет-проекте: как я с нуля собрал и обучил архитектуру VisionEncoderDecoder, используя "глаза" от Google и "мозг" от Сбера, как решал проблемы с датасетами на Kaggle и почему Seq2SeqTrainer от Hugging Face крашится при сохранении чекпоинтов.
<cut /> 1. Выбор архитектуры: хирургия конфигураций
Обучать мультимодальную модель с нуля — задача для корпораций. Я пошел по пути объединения двух предобученных моделей с помощью VisionEncoderDecoderModel от Hugging Face.
Encoder (Зрение): google/vit-base-patch16-224-in21k. Отличный экстрактор визуальных фичей, разбивающий картинку на патчи.
Decoder (Текст): ai-forever/rugpt3small_based_on_gpt2. Компактная русскоязычная генеративная модель.
Проблема: ruGPT-3 — это классическая каузальная языковая модель. У неё нет слоев кросс-внимания (Cross-Attention), чтобы принимать скрытые состояния от энкодера.
Решение: Пришлось модифицировать конфигурацию на лету перед сборкой модели:
code Python
downloadcontent_copy
expand_less
from transformers import AutoConfig, AutoModelForCausalLM, AutoModel # Загружаем визуальный энкодер encoder = AutoModel.from_pretrained("google/vit-base-patch16-224-in21k") # Взламываем конфиг текстового декодера decoder_config = AutoConfig.from_pretrained("ai-forever/rugpt3small_based_on_gpt2") decoder_config.is_decoder = True decoder_config.add_cross_attention = True # Заставляем модель отрастить новые слои decoder = AutoModelForCausalLM.from_pretrained("ai-forever/rugpt3small_based_on_gpt2", config=decoder_config)
После сборки Hugging Face инициализирует пустые веса для новых слоев кросс-внимания. Именно их, а также проекционные слои, нам и предстояло обучить.
2. Data Engineering: как Kaggle прячет файлы
Где взять русскоязычный датасет пар "Рентген -> Текст"? Нигде.
Я взял открытый американский датасет Indiana University Chest X-Ray (IU X-Ray) (около 7500 снимков и репортов). Прямо в ноутбуке Kaggle я поднял пайплайн на базе Helsinki-NLP/opus-mt-en-ru и батчами перевел медицинские заключения (раздел findings) на русский язык.
Но при написании кастомного PyTorch Dataset вылез классический Kaggle-баг. Стандартные скрипты маппинга видели только 4 картинки-превьюшки из 7000. Платформа раскидала реальные файлы по скрытым директориям.
Пришлось писать хардкорный Deep Scan Mapping:
code Python
downloadcontent_copy
expand_less
# Тотальное сканирование диска Kaggle file_map = {} for root, dirs, files in os.walk('/kaggle/input'): for file in files: if file.lower().endswith(('.png', '.jpg', '.jpeg')): base_name = os.path.splitext(file)[0] file_map[base_name] = os.path.join(root, file)
Только после привязки CSV-идентификаторов к этому словарю file_map конвейер данных заработал стабильно.
3. Обучение: обход крашей и защита от таймаута
Обучение проходило на 2x NVIDIA T4. Чтобы влезть в память, я использовал смешанную точность (fp16=True) и gradient_accumulation_steps=4 (при batch_size=4 на устройство мы получали виртуальный батч 32).
Грабли Hugging Face:
На 100-м шаге обучение стабильно падало с ошибкой внутри savecheckpoint. Оказалось, что Seq2SeqTrainer в последних версиях transformers сходит с ума, пытаясь автоматически сохранить композитную модель с неродными генеративными параметрами в процессе обучения.
Фикс: Отключить промежуточные сохранения (save_strategy="no") и перенести параметры генерации в отдельный объект GenerationConfig перед финальным save_pretrained().
А чтобы Kaggle не убил сессию по таймауту неактивности (пока я пил кофе 2.5 часа, ожидая завершения 15 эпох), в консоль браузера был заботливо отправлен скрипт:
code JavaScript
downloadcontent_copy
expand_less
function KeepClicking(){ document.querySelector("body").click() } setInterval(KeepClicking, 60000);
4. Результаты и Инференс
Модель училась всего 15 эпох, но результаты меня приятно удивили. Она научилась профессиональному медицинскому сленгу и уверенно распознает базовые паттерны: чистые легкие, пневмоторакс, кардиомегалию.
Конечно, из-за маленького датасета присутствуют галлюцинации (может "найти" катетер там, где его нет), но архитектура (Vision + Language) доказала свою жизнеспособность.
Инференс (как это запустить):
code Python
downloadcontent_copy
expand_less
import torch from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer from PIL import Image model_id = "livadies/Russian-Radiologist-ruGPT-ViT" model = VisionEncoderDecoderModel.from_pretrained(model_id) feature_extractor = ViTImageProcessor.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) image = Image.open("xray.jpg").convert("RGB") pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values generated_ids = model.generate(pixel_values, max_length=128, num_beams=4) print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
Для тех, кто хочет потыкать модель руками без написания кода, я развернул Gradio-приложение на бесплатных серверах Hugging Face:
👉 Live Demo: AI-Radiologist-RU Space (предупреждение: работает на CPU, ответ генерируется 10-15 секунд).
Весь исходный код, пайплайны подготовки данных и ноутбуки доступны в моем репозитории:
👉 GitHub: Multimodal-XRay-Analyzer-RU
Буду рад код-ревью, советам по улучшению датасета и конструктивной критике в комментариях!
(Дисклеймер: Модель создана исключительно в исследовательских целях и не является заменой реального врача).
