Pull to refresh

Могут ли KAN справляться с задачами компьютерного зрения?

Level of difficultyMedium
Reading time10 min
Views2.9K
Original author: Ivan Drokin

Недавняя статья об новой архитектуре нейронных сетей на основе теоремы Колмогорова-Арнольда (KAN Kolmogorov-Arnold Networks) вызвала большой ажиотаж: уже было представлено множество вариаций того, как правильно создавать такие сети, ведутся горячие дебаты, а рабочая ли схема и имеет ли право на жизнь и многое другое. Цель этой статьи постараться ответить на простой вопрос: могут ли KAN справляться с компьютерным зрением?

Исходный код всех экспериментов данной статьи можете найти по ссылке.

Пока ещё не ушли далеко, сразу же скажем, что у этой статьи два автора: Иван Дрокин, который написал оригинал на английском языке, проводил эксперименты и написал собственную библиотеку свёрток для KAN — torch-conv-kan. За вольный перевод для читателей Хабра, многочисленные ревью и правки оригинала статьи, а также часть глупых мемов отвечал Антон Клочков — автор телеграм-канала MLE шатает Produnction (нет, в названии нет опечатки) про нейронные сети, мемы, языки программирования и рассуждения об общих топиках в IT.

Все рассуждения в статье ведутся от лица Ивана.

Основы

Давайте начнем с самого фундаментального — математики. Если говорить кратко, то многослойные перцептроны (MLP — Multi-Layer Perceptron) представляют собой нелинейную функцию от взвешенной суммы входных данных, тогда как сети Колмогорова-Арнольда представляют собой сумму нелинейных унарных функций от входных данных.

Работоспособность MLP основывается на теореме Цыбенко или более известной в широких кругах, как универсальной теореме апроксимации. Если не вдаваться в детали, то она утверждает, что любую функцию можно апроксимировать с любой точностью с помощью одного достаточно широкого скрытого слоя с нелинейной функцией активации.

Чтобы не расписывать формулы в статье, скрин из Wikipedia с необходимым утверждением
Чтобы не расписывать формулы в статье, скрин из Wikipedia с необходимым утверждением

С другой стороны, как уже было написано в начале статьи, работа KAN базируется на основе теоремы Колмогорова-Арнольда:

Чтобы не расписывать формулы в статье, скрин из Wikipedia с необходимым утверждением
Чтобы не расписывать формулы в статье, скрин из Wikipedia с необходимым утверждением

Авторы статьи "KAN: Kolmogorov-Arnold Networks" разработали новый тип архитектуры нейронных сетей: обучаемые активации на рёбрах и суммирование результатов в узлах. В противовес этому, в MLP в узлах применяется фиксированная нелинейная функция, а на ребрах производится линейная проекция входов.

Наглядное представление разницы между двумя архитектурами. Картинка взята из статьи "KAN: Kolmogorov-Arnold Networks".
Наглядное представление разницы между двумя архитектурами. Картинка взята из статьи "KAN: Kolmogorov-Arnold Networks".

Что же это за "обучаемые активации" и как их обучать? В оригинальной статье было предложено использовать B-сплайны (с классными визуализациями можно глянуть следующую статью). Сплайн-функция степени n представляет собой кусочно-полиномиальную функцию степени n-1. Места, где отрезки соединяются, называются узлами. Ключевое свойство сплайн-функций заключается в том, что они и их производные могут быть непрерывными, в зависимости от кратности узлов.

Визуализация B-сплайна (синяя линия) и контролирующих её точек (красным цветом). Источник.
Визуализация B-сплайна (синяя линия) и контролирующих её точек (красным цветом). Источник.

Плюсы и минусы KAN

Авторы выделяют следующие плюсы использование KAN относительно MLP:

  1. Выше качество: KAN показали точность выше на большом множестве задач в сравнении с MLP. Первые могут более эффективно (в терминах качества) представлять сложные многомерные функции (спасибо теореме Колмогорова-Арнольда), что положительно сказывается на качестве.

  2. Интерпретируемость: MLP — это чёрный ящик, сложно сказать, что происходит у них внутри, а KAN может предложить интересные возможности. Например, можно разложить сложную многомерную функцию на более простые компоненты, анализ которых может служить инсайтом о поведении модели, о работе на конкретных данных.

  3. Гибкость и обобщаемость: за счёт обучаемости активаций можно лучше находить нелинейные зависимости в данных, что также ведёт к обобщаемости (но это не так просто).

  4. Устойчивость к шумным данным и Adversarial Attacks: способность KAN улавливать более устойчивые представления данных с помощью адаптивных функций активации позволяют KAN быть более устойчивым к шумам и атакам.

Но есть и определённые сложности с KAN (No free lunch, помните?):

  1. Чувствительность к гиперпараметрам: как и любая другая нейронная сеть, KAN чувствительна к множеству гиперпараметров, таких как learning rate, силу регуляризации, самой архитектуре. Проблема подбора правильных параметров остаётся актуальной и может в значительной степени влиять на сходимость. Тут стоит отметить, что на настоящий момент нет чётких рецептов того, как именно тюнить разного рода гиперпараметры (L1/L2 веса, параметры активаций, dropout коэффициенты и т.д.). Это чем-то напоминает времена, когда только начиналось развитие свёрточных нейронных сетей.

  2. Высокие вычислительные затраты: адаптивные функции активации (например, B-сплайны) могут требовать больше ресурсов, чем классические MLP, что отражается на длительности и стоимости обучения и инференса;

  3. Сложность моделей и масштабируемость: KAN масштабируемы в терминах гибкости модели, но при этом более глубокие сети ещё сильнее увеличивают вычислительные затраты. Масштабирование KAN на большие датасеты и сложные задачи, где вычислительная эффективность и интерпретируемость выходят на передний план, является пока сложной задачей.

Разновидности KAN

Некоторые из проблем, описанные выше, частично были решены в следующих модификациях.

В первую очередь, была представлена версия Fast KAN, в которой B-сплайны заменены радиальными базисными функциями (RBF). Это изменение помогает уменьшить вычислительные затраты на использование сплайнов.

Также появилось несколько Polynomial KAN: Лежандра, Чебышева, Якоби, Грама, Бернштейна, и Wavelet-based KAN. Про каждый из них можно узнать подробнее, перейдя по соответствующим ссылкам.

Свёртки для KAN

Для начала вспомним, что такое свёрточный слой? Наиболее распространенный тип свёрточного слоя — это слой двухмерной свёртки, обычно сокращаемый как Conv2D. В этом слое фильтр (или ядро) "скользит" по двухмерным входным данным (например, картинкам), выполняя поэлементное умножение. Результаты суммируются в одно число. Ядро выполняет ту же операцию для каждой позиции, по которой оно скользит, преобразуя двухмерную матрицу признаков в другую.

Хотя свёртки в одномерном и трёхмерном пространствах имеют тот же принцип, они используют разные ядра, размеры входных данных и выходных данных. Однако для упрощения мы сосредоточимся на двухмерной свёрточном слое. Если вы хотите углубиться в эту тему, прочитайте этот хороший пост.

Визуализация работы двухмерного свёрточного слоя. Пример взят из Wikimedia.
Визуализация работы двухмерного свёрточного слоя. Пример взят из Wikimedia.

Как правило, после слоя свёртки применяются слой нормализации (например, BatchNorm, InstanceNorm и т.д.) и нелинейные активации (ReLU, LeakyReLU, SiLU и т.д.).

Более формально: предположим, что у нас есть входное изображение  y размером  N \times N. Для упрощения мы опустим ось канала (т.е. рассмотрим одноканальные изображения), она добавляет еще одно суммирование по оси канала. Итак, сначала нам нужно выполнить свертку с нашим ядром W размером m \times m:

x_{ij} = \sum_{a=0}^{m-1} \sum_{b=0}^{m-1} \omega_{ab} y_{i+a,j+b}; \quad i, j = 1, N - m + 1

Затем применяем батч нормализацию и нелинейность, например, ReLU:

x=ReLU(BatchNorm(x))

Свёрточный слой в KAN будет работать иначе: ядро будет содержать не обучаемые веса (конкретные числа), а унарные, обучаемые нелинейные функции. В этом случае ядро "скользит" по двухмерным входным данным, выполняя поэлементное применение функций активации из этого фильтра. Результатом каждого из применений будет число, которые потом суммируются.

Опять же, более формально: давайте применим к нашему входному изображению y свёрточный слой в терминах теоремы Колмогорова-Арнольда:

x_{ij} = \sum_{a=0}^{m-1} \sum_{b=0}^{m-1} \phi_{ab}(y_{i+a,j+b}); \quad i, j = 1, N - m + 1

Каждая \phi_{ab} является унарной нелинейной обучаемой функцией со своими обучаемыми параметрами. В оригинальной статье авторы предлагают использовать функцию следующего вида:

\phi(x) = \omega_{b}b(x) + w_{s}spline(x)

где w_{b} и w_{s} — обучаемые параметры, spline(x) — B-сплайн. — это \omega_b b(x)esidual activation functions, похожие на residual connctions из ResNet сетей. В оригинальной статье авторы предлагают выбрать в качестве b(x) функцию SiLU:

b(x) = SiLU(x) = \frac {x}{1 + e^{-x}}

Как было уже упомянуто выше, недавно было предложено использовать вместо сплайнов полиномиальные функции или RBF.

Подытожим вышесказанное: в классических свёрточных слоях ядра содержат просто веса (числа), тогда как KAN-свёртки содержат унарные функции:

K = \begin{bmatrix} w_{1,1} & w_{1,2} & w_{1,3} \\ w_{2,1} & w_{2,2} & w_{2,3} \\ w_{3,1} & w_{3,2} & w_{3,3} \end{bmatrix}, \quad K_{\text{kan}} = \begin{bmatrix} \phi_{1,1} & \phi_{1,2} & \phi_{1,3} \\ \phi_{2,1} & \phi_{2,2} & \phi_{2,3} \\ \phi_{3,1} & \phi_{3,2} & \phi_{3,3} \end{bmatrix}

Эксперименты

Реализацию различных типов KAN-свёрток, моделей, датасетов, сетапов экспериментов и многое другое можете найти моём репозитории torch-conv-kan.

MNIST

Итак, давайте начнём эксперименты со всем знакомого MNISTа.

Бейзлайн модели — простая нейронная сеть, состоящая из четырёх свёрточных слоёв. Для уменьшения размерности во втором и третьем свёрточных слоях используется параметр dilation=2.

Количество каналов в свёрточных слоях одинаково для всех моделей: 32, 64, 128, 256. После свёрток применяется Global Average Pooling, за которым следует линейный выходной слой. Кроме того, для регуляризиции используется Dropout слой: с параметром p = 0.25 в свёрточных слоях и p = 0.5 перед выходным слоем.

Пример реализации с использованием Pytorch и torch-conv-kan
import torch
import torch.nn as nn

from kan_convs import KANConv2DLayer


class SimpleConvKAN(nn.Module):
    def __init__(
            self,
            layer_sizes,
            num_classes: int = 10,
            input_channels: int = 1,
            spline_order: int = 3,
            groups: int = 1):
        super(SimpleConvKAN, self).__init__()
        self.layers = nn.Sequential(
            KANConv2DLayer(input_channels, layer_sizes[0], spline_order, kernel_size=3, groups=1, padding=1, stride=1,
                           dilation=1),
            KANConv2DLayer(layer_sizes[0], layer_sizes[1], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=2, dilation=1),
            KANConv2DLayer(layer_sizes[1], layer_sizes[2], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=2, dilation=1),
            KANConv2DLayer(layer_sizes[2], layer_sizes[3], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=1, dilation=1),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.output = nn.Linear(layer_sizes[3], num_classes)
        self.drop = nn.Dropout(p=0.25)

    def forward(self, x):
        x = self.layers(x)
        x = torch.flatten(x, 1)
        x = self.drop(x)
        x = self.output(x)
        return x

Заметьте, что в случае классических свёрточных слоев, структура слоёв была бы примерно следующей: Conv2D -> Batch Normalization -> ReLU.

Для проведения экспериментов используются аугментации, которые вы можете посмотреть под катом.

Пример аугментаций с использованием torchvision
from torchvision.transforms import v2

transform_train = v2.Compose([
    v2.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    v2.ColorJitter(brightness=0.2, contrast=0.2),
    v2.ToTensor(),
    v2.Normalize((0.5,), (0.5,))
])

Кроме того, нам также необходимо исследовать влияние различных слоёв нормализации внутри свёрток KAN и влияние L1 регуляризации. В столбце Norm Layer во всех таблицах указывается, какой слой нормализации использовался во время эксперимента, а в столбце Affine указывается, был ли параметр affine слоя нормализации BatchNorm2d установлен как True или False.

Все эксперименты были выполнены с использованием NVIDIA RTX 3090 с идентичными параметрами.

Model

Accuracy

Parameters

Eval Time, s

Norm Layer

Affine

L1

SimpleConv, 4 layers

99.42

101066

0.7008

BatchNorm2D

False

0

SimpleKANConv, 4 layers

96.80

3488814

2.8306

InstanceNorm2D

False

0

SimpleKANConv, 4 layers

99.41

3489774

2.6362

InstanceNorm2D

True

0

SimpleKANConv, 4 layers

99.00

3489774

2.6401

BatchNorm2D

True

0

SimpleKANConv, 4 layers

98.89

3489774

2.4138

BatchNorm2D

True

1e-05

Модель на основе классических свёрток работает куда лучше нейронной сети с KAN-свёртами, которая к тому же имеет в 34 раза больше параметров и требует куда большего времени исполнения. Кажется, это не та "революция" в нейронных сетях, которую мы ожидали.

Давайте попробуем воспользоваться подходом Fast KAN и заменим сплайны на RBF:

Model

Accuracy

Parameters

Eval Time, s

Norm Layer

Affine

L1

SimpleConv, 4 layers

99.42

101066

0.7008

BatchNorm2D

True

0

SimpleFastKANConv, 4 layers

99.26

3488810

1.5636

InstanceNorm2D

False

0

SimpleFastKANConv, 4 layers

99.01

3489260

1.7406

InstanceNorm2D

True

0

SimpleFastKANConv, 4 layers

97.65

3489260

1.5999

BatchNorm2D

True

0

SimpleFastKANConv, 4 layers

95.62

3489260

1.6158

BatchNorm2D

True

1e-05

По времени стало получше, но при этом сами результаты неоднозначные: где-то стало лучше, где-то хуже.

Тогда попробуем ещё одну функцию вместо сплайнов: дискретные полиномы Чебышева. Сам факт что полиномы дискретны в теории должно дать неплохие результаты для обработки дискретных данных — изображений и текстов. Давайте узнаем, так ли это:

Model

Accuracy

Parameters

Eval Time, s

Norm Layer

Affine

L1

SimpleConv, 4 layers

99.42

101066

0.7008

BatchNorm2D

True

0

SimpleKAGNConv, 4 layers

98.21

487866

1.8506

InstanceNorm2D

False

0

SimpleKAGNConv, 4 layers

99.46

488826

1.8813

InstanceNorm2D

True

0

SimpleKAGNConv, 4 layers

99.49

488826

1.7253

BatchNorm2D

True

0

SimpleKAGNConv, 4 layers

99.44

488826

1.8979

BatchNorm2D

True

1e-05

Новые KAN-свёртки выдают качество чуть лучше традиционной модели, но работают в 2.5 раза медленнее и имеют почти в пять раз больше параметров. L1 регуляризация немного снижает производительность модели, но это область для дальнейших улучшений.

CIFAR 100

Бейзлайн модели здесь меняется. Теперь у нас будет восемь свёрточных слоёв . Для уменьшения размерности во втором, третьем и шестом свёрточных слоях используется параметр dilation=2. Для каждого слоя число каналов равно соответвующе 16, 32, 64, 128, 256, 256, 512, 512. Всё остальное в модели остается без изменений.

Аугментации, которые используются в экспериментах с CIFAR 100 можете посмотреть под катом.

Пример аугментаций с использованием torchvision
from torchvision.transforms import v2
from torchvision.transforms.autoaugment import AutoAugmentPolicy

transform_train = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
        v2.AutoAugment(AutoAugmentPolicy.IMAGENET),
        v2.AutoAugment(AutoAugmentPolicy.SVHN),
        v2.TrivialAugmentWide()]),
    v2.ToTensor(),
    v2.Normalize((0.5,), (0.5,))
])

Model

Accuracy

Parameters

Eval Time, s

Norm Layer

Affine

L1

SimpleConv, 8 layers

57.52

1187172

1.8265

BatchNorm2D

True

0

SimpleKAGNConv, 8 layers

29.39

22655732

2.5358

InstanceNorm2D

False

0

SimpleKAGNConv, 8 layers

48.56

22659284

2.0454

InstanceNorm2D

True

0

SimpleKAGNConv, 8 layers

59.27

22659284

2.6460

BatchNorm2D

True

0

SimpleKAGNConv, 8 layers

58.07

22659284

2.2583

BatchNorm2D

True

1e-05

KAN-свёртки на основе дискретных полиномов Чебышева показывают лучшее качество, хотя и с бОльшими временными затратами и существенно бОльшим количеством параметров (более чем в 20 раз больше). BatchNorm2D, по-видимому, является лучшим вариантом для нормализации внутренних признаков в KAN-свёртках на дискретных полиномов Чебышева.

Использование последних кажется многообещающим для дальнейших экспериментов на ImageNet1k и других "реальных" наборах данных.

Заключение

Итак, могут ли KAN справляться с компьютерным зрением? Кажется, что да! Смогут ли они заменить классические CNN? Это еще предстоит выяснить.

MLP используется уже много лет и нуждается в обновлении. Мы уже видели подобные изменения. Например, шесть лет назад сети Long Short-Term Memory (LSTM), которые долгое время были основой для моделирования последовательностей, были заменены трансформерами в качестве стандартного строительного блока для архитектуры языковых моделей. Похожий сдвиг для MLP был бы интригующим.

Свёрточные нейронные сети, которые доминировали в течение многих лет (и до сих пор являются основой для компьютерного зрения), в конечном итоге были частично заменены визуальными трансформерами (ViT). Возможно, пришло время для нового лидера в этой области?

Однако прежде чем это произойдет, сообществу необходимо найти эффективные методы для обучения сетей Колмогорова-Арнольда, свёрточных сетей Колмогорова-Арнольда (ConvKAN) и ViT-KAN, а также решить проблемы, которые связаны с этими моделями (например, скорость инференса, количество параметров).

Хотя меня очень вдохновляет эта новая архитектура и начальные эксперименты показывают обнадеживающие результаты, я остаюсь несколько скептичен. Необходимы ещё эксперименты. Оставайтесь с нами, мы собираемся углубиться в эту тему.

Tags:
Hubs:
Total votes 17: ↑17 and ↓0+23
Comments2

Articles