Как стать автором
Обновить

SAUNet: Shape Attentive U-Net for Interpretable Medical Image Segmentation

Время на прочтение3 мин
Количество просмотров1.5K

Все чаще для сегментации изображений используется глубокое обучение и сверточные нейронные сети. В случае медицинских картинок достаточно сильно проявляются основные проблемы этого метода: не хватает робастности и интерпретируемости. Происходит это в основном из-за того, что CNN обучаются на текстуре изображения, а не на форме, или требуются дополнительные вычисления post hoc, которые, как было показано, ненадежны с точки зрения интерпретируемости.

В статье SAUNet: Shape Attentive U-Net for Interpretable Medical Image Segmentation авторы (Jesse Sun, Fatemeh Darbehani, Mark Zaidi, Bo Wang) предлагают добавить к модели U-Net второй поток данных о форме, а также использовать dual-attention декодер. Такой метод позволил получить очень хорошие результаты на датасетах изображений МРТ сердца SUN09 и AC17, обеспечивая высокую интерпретируемость при различных разрешениях.

Работа опирается на последние достижения в области моделей channel-attention с использованием модулей сжатия и возбуждения, предложенных Hu и др., и spatial attention c оценкой внимания, предложенных Jetley и др..

Суть подхода

ASAUNet состоит из двух потоков: texture stream и gated shape stream.

Texture stream

Текстурный поток имеет структуру U-Net, но энкодер заменен на структуру блоков из DenseNet-121, которые похожи на TiramisuNet. Также авторы предлагают декодер двойного внимания собственной разработки.

Dual attention decoder block

Этот блок объединяет карты, выдаваемые энкодером, с картами декодера с более низким разрешением, которые захватывают больше контекстной и пространственной информации.
Особенно для медицины важно понимать, основываясь на каких особенностях, модель принимает решения.

Авторы предлагают следующую структуру этого блока:

  • стандартная нормализованная свертка 3\times 3

  • spatial attention path

  • channel-wise attention path

Spatial Attention Path

Как раз используется для большей интерпретируемости решений. Состоит из двух сверток 1\times 1, первая из C каналов получает\tfrac{C}{2}, вторая 1. Далее применяется сигмоида для вложения в отрезок [0, 1]. Затем складываем C копий в один канал для дальнейшего умножения.

Channel Attention Path

Состоит из модуля сжатия и внимания, выдает коэффициент сжатия для каждого канала, затем масштабирует на него.

Gated shape stream

Второй поток обрабатывает информацию о форме и границах, используя объекты, обработанные энкодером из первого потока текстур.
На каждом шаге данные из U-Net модели передаются в Shape Stream, где вычисляется карта внимания границ \alpha_lпо следующей формуле:

\alpha_l = \sigma\Bigl( C_{1\times 1} (S_l \| C_{1\times 1}(T_t))\Bigr)

где \sigma - сигмоидная функция, \| - конкатенация карт по каналам, T_t и S_l - feature maps из потоков, C_{1\times 1} - нормализованная свертка 1\times 1.

Feature map следующего слоя этого потока получают применением Residual блока, состоящего из двух нормализованных конволюций 3\times 3 с skip-connection, к предыдущему, поэлементно домноженному на \alpha_l:

S_{l+1} = R(S_l \otimes \alpha_l)

Функция потерь

Состоит из обычной функции кросс энтропии и dice loss, вычисляющей перекрытие и сходство между двумя наборами.

Первая определена как

L_{CE} (\hat y, y) = \frac{1}{\lvert \Omega \rvert} \sum_i^{\Omega} - y_i \log (\hat y_i) - (1-y_i) \log(1-\hat y_i)

Вторая

L_{Dice} (\hat y, y) = 1 - \frac{2}{K} \sum_k^{K-1} \frac{\sum_i^{\Omega} y^k \hat y_i^k}{\sum_i^{\Omega} y_i^k + \hat y_i^k},

гдеK- количество классов, аy_k^kобозначаетi-й пиксельk-го индексированного класса матрицыy.

Также определим L_{Edge} - двоичная кросс энтропия предсказанных границ.

Итоговая функция потерь имеет вид

L_{total} = \lambda_1 L_{CE} + \lambda_2 L_{Dice} + \lambda_3 L_{Edge},

где\lambda_1, \lambda_2, \lambda_3- гиперпараметры. Авторы используют их равными единице, и все работает хорошо, при этом утверждают, что если занулить один из гиперпараметров, модель плохо и медленно обучается.

Эксперименты

Эксперименты проводились на двух датасетах:

  • SUN09, два класса (endocardium, epicardium), всего 260 и 135 соответственно двумерных MРT кадров 128 на 128 пикселей. Все модели работали 120 эпох, с размером батча 4.

  • AC17, 200 наборов МРТ снимков от 100 уникальных пациентов с разрешением 256 на 256, 180 эпох.

Дополнительно авторы провели сравнение на втором датасете, подтверждающее эффективность дополнительного потока.

Итоговые оценки на датасетах получились следующие:

Интерпретируемось

Для получения карты значимости каждого изображения с использованием модели SmoothGrad приходится прогонять модель 25-50 раз вперед-назад, из-за этого на все 384 изображения ушло 24 минуты. А для модели авторов всего 20 секунд! Кроме этого новый метод предлагает и методы определения значимости на разных уровнях и разрешениях, которые другие модели, основанные на градиенте, не предлагают.

Пример приведен ниже:

Заключение

Получилась неплохая модель, особенно учитывая, что она гораздо проще интерпретируемая, чем ее предшественники, поэтому авторы надеются, что приблизили методы глубокого обучения к клиническому применению, обещают проделать дополнительную работу по интерперетируемости.

P.S.
Этот пост написан для https://github.com/spbu-math-cs/ml-course

Теги:
Хабы:
Всего голосов 4: ↑4 и ↓0+4
Комментарии0

Публикации

Истории

Работа

Data Scientist
61 вакансия

Ближайшие события