Pull to refresh

Обзор архитектуры Swin Transformer

Reading time7 min
Views14K

Трансформеры шагают по планете! В статье вспомним/узнаем как работает visual attention, поймём, что с ним не так, а главное как его поправить, чтобы получить на выходе best paper ICCV21.

CV-трансформеры in a nutshell

Attention Is All You Need

Начнём издалека, а именно с 2017 года, когда A Vaswani et al. опубликовали знаменитую статью «Attention Is All You Need», в которой была предложена архитектура нейронной сети Transformer для решения задачи seq2seq и в частности машинного перевода. Не буду говорить о том, насколько значимым было это событие для всего NLP. Скажу лишь, что на данный момент почти каждое ML решение, работающее с текстом, пожинает плоды того успеха, используя Transformer-based архитектуру напрямую, работая с эмбеддингами из BERT-а или еще каким-нибудь образом. Ключевым и идейно чуть ли не единственным компонентом трансформера является слой Multi-Head Attention. В применении к задаче машинного перевода он дал возможность учитытвать взаимодействие между словами, находящимися на произвольно большом расстоянии в тексте, что выгодно выделило трансформер на фоне других моделей перевода и позволило ему занять место под солнцем. Формально этот слой записывается в терминах следующих преобразований:

\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k}}\right)V,\text{MultiHeadAttention}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O,

где

\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V).

С 2017 года было предложено бесчисленное множество модификаций трансформера (Linformer, Reformer, Perfomer, etc.), делающих его более вычислительно эффективным, стабилизирующих обучение и так далее. Такой бум трансформеров не мог не затронуть другие сферы применения глубинного обучения помимо NLP, и с ~2020 года они начали проникать в CV.

Трансформеризация Computer Vision-а

Вообще идея применения трансформера к изображениям изначально может смутить читателя. Всё-таки текст и картинки — это довольно отличающиеся модальности, как минимум тем, что текст является последовательностью слов. Изображение — это тоже в некотором смысле последовательность (пикселей), направление которой можно определить искусственно, например построчно, правда у такого определения не будет семантического смысла в отличие от текста. Однако не стоит забывать, что Multi-Head Attention на самом деле является операцией не над последовательностями, а над неупорядоченными множествами векторов, а последовательной структурой текст наделяется искусственно с помощью positional encoding-а, так что аргумент выше становится невалидным. Валидным аргументом против Visual Transformer-ов может быть отсутствие в них приятных inductive bias-ов, имеющихся у свёрточных сетей: эквивариантности относительно сдвигов и предположении о пространственной локальности принзнаков. Однако это так же спорный момент, подробнее про который можно почитать здесь. А пока сомневающиеся сомневались, исследовательские группы делали, и явили миру несколько Visual Transoformer-ов (неплохой survey), в том числе ViT, на примере которого мы выясним как же переформулировать Multi-Head Attention для изображений.

An Image Is Worth 16x16 Words

Авторы ViT-иа предложили довольно прямолинейную архитектуру:

Исходная картинка нарезается на патчи 16x16, они вытягиваются в вектора и все пропускаются через линейный слой. Далее к ним прибавляются обучаемые вектора, играющие роль positional embedding-ов, а также к набору добавлется отдельный обучаемый эмбеддинг, являющийся прямым аналогом CLS-токена BERT-а. А на этом то и всё! Далее идёт самый обычный Transformer Encoder (N x Multi-Head Attention если угодно), и класс изображения предсказывается маленьким перцептроном, берущим на вход то, что получилось на месте CLS-токена. Как и ~любой трансформер, модель получилась очень прожорливой в том смысле, что для получения околосотовых результатов ей нужно предобучаться на громадных датасетах, таких как закрытый гугловский JFT-300M. Тем не менее в определённом сетапе сетка смогла обойти сотовых BiT-L и Noisy Student-а, что можно считать успехом. За подробностям отсылаю читателя к оригинальной статье, много интересного можно найти в ablation-е, особенно советую изучить графики Mean Attention Distance-а, являющимся аналогом receptive field-а.

Всё вроде и неплохо, но не классификацией единой занимаются в CV. Есть задачи по типу Object Detection-а, в которых зачастую важны мелкие детали, или же задачи сегментации, для которой вообще необходимо делать pixel-level предсказание. Все это требует как минимум возможности работы с изображениями высокого разрешения, то есть значительного увеличения размера входа. А как нетрудно видеть, Attention работает за квадратичное по входу время, что в случае картинок 1920х1920 является острейшей проблемой, так как время forward pass-а взмывает до небес. К тому же мелкие детали могу потеряться уже на первом слое, который суть свёртка 16х16 со страйдом 16. Кто виноват и что делать? На первый вопрос ответ +- понятен — дело в слишком твердолобой адаптации трансформерной архитектуры под CV. А ответу на второй вопрос посвящается оставшаяся часть этой статьи.

Swin Transformer

Проблемы ViT-а обозначились ещё в предыдущем параграфе, поэтому не будем ходить вокруг да около и сразу перейдем к рассмотрению архитектуры, предложенной в статье «Swin Transformer: Hierarchical Vision Transformer using Shifted Windows»:

Первый слой качественно такой же, как и в ViT-е — исходная картинка нарезается на патчи и проецируется линейным слоем. Единственное отличие в том, что в Swin-е на первом слое патчи имеют размер 4х4, что позволяет обрабатывать более мелкий контекст. Далее идут несколько Patch Merging и Swin Transformer Block слоёв. Patch Merging занимается тем, что конкатенирует фичи соседних (в окне 2х2) токенов и понижает размерность, получая более высокоуровневое представление. Таким образом, после каждого Stage-а образуются «карты» признаков, содержащие информацию на разных пространственных масштабах, что как раз и позволяет получить иерархическое представление изображения, полезное для дальнейшей сегментации/Object Detection-а/etc:

Благодаря этому Swin Transfomer может служить универсальным backbone-ом для различных задач CV.

Swin Transformer Block — ключевая изюминка всей архитектуры:

Как видно из схемы, два последовательных блока представляют собой два классических трансформерных блока с MLP, LayerNorm-ами и Pre-Activation Residual-ами, однако Attention заменён на нечто более хитрое, к разбору чего мы непременно переходим.

(Shifted) Window Multi-Head Attention

Как было упомянуто, проблемой Multi-Head Attention-а является его квадратичная сложность, больно стреляющая в ногу при применении на картинках высокого разрешения. На ум приходит довольно простое решение, представленное еще в статье про Longformer — давайте для каждого токена считать Attention не со всеми другими токенами, а только с находящимися в некотором окне фиксированного размера (Window Mutli-Head Attention). Если размерность токенов — C, а размер окна — MxM, то сложности для (Window) Multi-Head Self Attention-ов получаются следующие:

\Omega(MSA) = 4hwC^2 + 2(hw)^2C,\Omega(W\text{-}MSA) = 4hwC^2 + 2M^2hwC

То есть Attention теперь работает за линейное по hw время! Однако такой подход уменьшает общую репрезентативную способность сети, так как токены из различных окон никак не будут взаимодействовать. Чтобы исправить ситуацию, авторы поступили любопытным образом. После каждого блока с Window Multi-Head Attention-ом они поставили аналогичный слой, со смещёнными по диагонали окнами Attention-а:

Это вернуло взаимодействие между токенами, оставив при этом линейную вычислительную сложность.

Как проиллюстрировано выше, сдвиг окон Attention-а увеличивает их количество. Это значит, что реализация этого слоя с наивным паддингом исходной «карты» признаков нулями обяжет считать больше Attention-ов (9 вместо 4 в примере), чем мы посчитали бы без сдвига. Чтобы не производить лишних вычислений, авторы предложили перед подсчётом циклически сдвигать само изображение и вычислять уже маскированный Attention, чтобы исключить взаимодействие не соседних токенов. Такой подход вычислительно эффективнее наивного, так как количество считаемых Attention-ов не увеличивается:

Также в Swin-е авторы использовали несколько другие positional embedding-и. Их заменили на обучаемую матрицу В, называемую relative position bias, которая прибавляется к произведению query и key под софтмаксом:

\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^{\mathsf{T}}}{\sqrt{d_k}}+ B\right)V.

Как оказалось, такой подод приводит к лучшему качеству.

Эксперименты и результаты

Всего авторы предложили 4 модели разных размеров:

Для честного сравнения параметры были подобраны так, чтобы по размерам и количеству вычислений Swin-B примерно соответствовал ViT-B/DeiT-B, а Swin-T и Swin-S ResNet-50 и ResNet-101 соответственно.

ImageNet-1k классификация

В данном бенчмарке были проверены два сетапа: обучение на ImageNet-1k и предобучение на ImageNet-22K с дообучением на ImageNet-1K. Модели сравнивались по top-1 accuracy.

В первой постановке Swin-ы более чем на 1.5% обошли другие Visual Transformer-ы, в том числе ViT-ы, отстающие от первых на целых 4%. Сотовые же EfficienetNet-ы и RegNet-ы оказались соперником посерьёзнее — статистически значимо тут можно говорить разве что об улучшении баланса между точностью и быстродействием. Во второй постановке предобучение на ImageNet-22K дало ~2%-ый прирост точности, а Swin-L достиг 87.3% top-1 accuracy. Это ещё раз подтверждает важность предобучения в особенности для трансформерных архитектур.

COCO детектирование объектов

Для оценки Swin-а в качестве backbone-а для детекции авторы использовали его вместе с такими фреймворками детекции как Cascade Mask R-CNN, ATSS, RepPointsV2 и Sparse R-CNN. В качестве backbone-ов для сравнения были взяты ResNe(X)t, DeiT и несколько сотовых свёрточных аритектур.

Для всех фреймворков Swin backbone дал уверенные +3.5-4.2% AP относительно классического ResNet50. Относительно ResNe(X)t-а Swin также показал рост в ~3% AP сразу для нескольких его версий Swin-T, Swin-S и Swin-B. DeiT проиграл Swin-у чуть меньше — около 2% AP, но был сильно медленее из-за честного Multi-Head Attention-а по всей картинке. Ну и относительно большого набора сотовых детекторов Swin-L с HTC показал улучшение в ~2.6 AP.

ADE20K семантическая сегментация

Для сегментации с помощью Swin-а был выбран фреймворк UperNet, он сравнивался с несколькими популярными сегментаторами, а так же с моделью на основе DeiT. Swin-S обошел Deit-S на целых 5.3 mIoU, а ResNet-101 и ResNeSt на 4.4 и 2.4 mIou соответственно. При этом Swin-L, предобученный на ImageNet-22k, выбил 53.5 mIoU, обойдя SETR на 3.2 mIoU.

Итоги

В результате имеем следующее: авторам удалось несколько переформулировать трансформерную архитектуру под задачи CV, сделав её вычислительно более оптимальной за счёт использования локального Attention-а. При этом Shifted Window Multi-Head Attention оставил репрезентативную способность сети на уровне, достаточном, чтобы соревноваться с текущими сотовыми моделями. Благодаря этому стало возможным построить архитеткуру, позволяющую извлекать из изображений фичи на разных пространственных масштабах, что позволило успешно использовать Swin как backbone в задачах сегментации и детекции, где до этого трансформеры были на более низких позициях. This is success!

Полезные ссылки

Tags:
Hubs:
Total votes 6: ↑6 and ↓0+6
Comments1

Articles