Недавняя статья об новой архитектуре нейронных сетей на основе теоремы Колмогорова-Арнольда (KAN Kolmogorov-Arnold Networks) вызвала большой ажиотаж: уже было представлено множество вариаций того, как правильно создавать такие сети, ведутся горячие дебаты, а рабочая ли схема и имеет ли право на жизнь и многое другое. Цель этой статьи постараться ответить на простой вопрос: могут ли KAN справляться с компьютерным зрением?

Исходный код всех экспериментов данной статьи можете найти по ссылке.

Пока ещё не ушли далеко, сразу же скажем, что у этой статьи два автора: Иван Дрокин, который написал оригинал на английском языке, проводил эксперименты и написал собственную библиотеку свёрток для KAN — torch-conv-kan. За вольный перевод для читателей Хабра, многочисленные ревью и правки оригинала статьи, а также часть глупых мемов отвечал Антон Клочков — автор телеграм-канала MLE шатает Produnction (нет, в названии нет опечатки) про нейронные сети, мемы, языки программирования и рассуждения об общих топиках в IT.

Все рассуждения в статье ведутся от лица Ивана.

Основы

Давайте начнем с самого фундаментального — математики. Если говорить кратко, то многослойные перцептроны (MLP — Multi-Layer Perceptron) представляют собой нелинейную функцию от взвешенной суммы входных данных, тогда как сети Колмогорова-Арнольда представляют собой сумму нелинейных унарных функций от входных данных.

Работоспособность MLP основывается на теореме Цыбенко или более известной в широких кругах, как универсальной теореме апроксимации. Если не вдаваться в детали, то она утверждает, что любую функцию можно апроксимировать с любой точностью с помощью одного достаточно широкого скрытого слоя с нелинейной функцией активации.

Чтобы не расписывать формулы в статье, скрин из Wikipedia с необходимым утверждением

С другой стороны, как уже было написано в начале статьи, работа KAN базируется на основе теоремы Колмогорова-Арнольда:

Чтобы не расписывать формулы в статье, скрин из Wikipedia с необходимым утверждением

Авторы статьи "KAN: Kolmogorov-Arnold Networks" разработали новый тип архитектуры нейронных сетей: обучаемые активации на рёбрах и суммирование результатов в узлах. В противовес этому, в MLP в узлах применяется фиксированная нелинейная функция, а на ребрах производится линейная проекция входов.

Наглядное представление разницы между двумя архитектурами. Картинка взята из статьи "KAN: Kolmogorov-Arnold Networks".

Что же это за "обучаемые активации" и как их обучать? В оригинальной статье было предложено использовать B-сплайны (с классными визуализациями можно глянуть следующую статью). Сплайн-функция степени представляет собой кусочно-полиномиальную функцию степени . Места, где отрезки соединяются, называются узлами. Ключевое свойство сплайн-функций заключается в том, что они и их производные могут быть непрерывными, в зависимости от кратности узлов.

Визуализация B-сплайна (синяя линия) и контролирующих её точек (красным цветом). Источник.

Плюсы и минусы KAN

Авторы выделяют следующие плюсы использование KAN относительно MLP:

  1. Выше качество: KAN показали точность выше на большом множестве задач в сравнении с MLP. Первые могут более эффективно (в терминах качества) представлять сложные многомерные функции (спасибо теореме Колмогорова-Арнольда), что положительно сказывается на качестве.

  2. Интерпретируемость: MLP — это чёрный ящик, сложно сказать, что происходит у них внутри, а KAN может предложить интересные возможности. Например, можно разложить сложную многомерную функцию на более простые компоненты, анализ которых может служить инсайтом о поведении модели, о работе на конкретных данных.

  3. Гибкость и обобщаемость: за счёт обучаемости активаций можно лучше находить нелинейные зависимости в данных, что также ведёт к обобщаемости (но это не так просто).

  4. Устойчивость к шумным данным и Adversarial Attacks: способность KAN улавливать более устойчивые представления данных с помощью адаптивных функций активации позволяют KAN быть более устойчивым к шумам и атакам.

Но есть и определённые сложности с KAN (No free lunch, помните?):

  1. Чувствительность к гиперпараметрам: как и любая другая нейронная сеть, KAN чувствительна к множеству гиперпараметров, таких как learning rate, силу регуляризации, самой архитектуре. Проблема подбора правильных параметров остаётся актуальной и может в значительной степени влиять на сходимость. Тут стоит отметить, что на настоящий момент нет чётких рецептов того, как именно тюнить разного рода гиперпараметры (L1/L2 веса, параметры активаций, dropout коэффициенты и т.д.). Это чем-то напоминает времена, когда только начиналось развитие свёрточных нейронных сетей.

  2. Высокие вычислительные затраты: адаптивные функции активации (например, B-сплайны) могут требовать больше ресурсов, чем классические MLP, что отражается на длительности и стоимости обучения и инференса;

  3. Сложность моделей и масштабируемость: KAN масштабируемы в терминах гибкости модели, но при этом более глубокие сети ещё сильнее увеличивают вычислительные затраты. Масштабирование KAN на большие датасеты и сложные задачи, где вычислительная эффективность и интерпретируемость выходят на передний план, является пока сложной задачей.

Разновидности KAN

Некоторые из проблем, описанные выше, частично были решены в следующих модификациях.

В первую очередь, была представлена версия Fast KAN, в которой B-сплайны заменены радиальными базисными функциями (RBF). Это изменение помогает уменьшить вычислительные затраты на использование сплайнов.

Также появилось несколько Polynomial KAN: Лежандра, Чебышева, Якоби, Грама, Бернштейна, и Wavelet-based KAN. Про каждый из них можно узнать подробнее, перейдя по соответствующим ссылкам.

Свёртки для KAN

Для начала вспомним, что такое свёрточный слой? Наиболее распространенный тип свёрточного слоя — это слой двухмерной свёртки, обычно сокращаемый как Conv2D. В этом слое фильтр (или ядро) "скользит" по двухмерным входным данным (например, картинкам), выполняя поэлементное умножение. Результаты суммируются в одно число. Ядро выполняет ту же операцию для каждой позиции, по которой оно скользит, преобразуя двухмерную матрицу признаков в другую.

Хотя свёртки в одномерном и трёхмерном пространствах имеют тот же принцип, они используют разные ядра, размеры входных данных и выходных данных. Однако для упрощения мы сосредоточимся на двухмерной свёрточном слое. Если вы хотите углубиться в эту тему, прочитайте этот хороший пост.

Визуализация работы двухмерного свёрточного слоя. Пример взят из Wikimedia.

Как правило, после слоя свёртки применяются слой нормализации (например, BatchNorm, InstanceNorm и т.д.) и нелинейные активации (ReLU, LeakyReLU, SiLU и т.д.).

Более формально: предположим, что у нас есть входное изображение размером . Для упрощения мы опустим ось канала (т.е. рассмотрим одноканальные изображения), она добавляет еще одно суммирование по оси канала. Итак, сначала нам нужно выполнить свертку с нашим ядром размером :

Затем применяем батч нормализацию и нелинейность, например, ReLU:

Свёрточный слой в KAN будет работать иначе: ядро будет содержать не обучаемые веса (конкретные числа), а унарные, обучаемые нелинейные функции. В этом случае ядро "скользит" по двухмерным входным данным, выполняя поэлементное применение функций активации из этого фильтра. Результатом каждого из применений будет число, которые потом суммируются.

Опять же, более формально: давайте применим к нашему входному изображению свёрточный слой в терминах теоремы Колмогорова-Арнольда:

Каждая является унарной нелинейной обучаемой функцией со своими обучаемыми параметрами. В оригинальной статье авторы предлагают использовать функцию следующего вида:

где и — обучаемые параметры, — B-сплайн. — это esidual activation functions, похожие на residual connctions из ResNet сетей. В оригинальной статье авторы предлагают выбрать в качестве функцию SiLU:

Как было уже упомянуто выше, недавно было предложено использовать вместо сплайнов полиномиальные функции или RBF.

Подытожим вышесказанное: в классических свёрточных слоях ядра содержат просто веса (числа), тогда как KAN-свёртки содержат унарные функции:

Эксперименты

Реализацию различных типов KAN-свёрток, моделей, датасетов, сетапов экспериментов и многое другое можете найти моём репозитории torch-conv-kan.

MNIST

Итак, давайте начнём эксперименты со всем знакомого MNISTа.

Бейзлайн модели — простая нейронная сеть, состоящая из четырёх свёрточных слоёв. Для уменьшения размерности во втором и третьем свёрточных слоях используется параметр dilation=2.

Количество каналов в свёрточных слоях одинаково для всех моделей: 32, 64, 128, 256. После свёрток применяется Global Average Pooling, за которым следует линейный выходной слой. Кроме того, для регуляризиции используется Dropout слой: с параметром p = 0.25 в свёрточных слоях и p = 0.5 перед выходным слоем.

Пример реализации с использованием Pytorch и torch-conv-kan
import torch
import torch.nn as nn

from kan_convs import KANConv2DLayer


class SimpleConvKAN(nn.Module):
    def __init__(
            self,
            layer_sizes,
            num_classes: int = 10,
            input_channels: int = 1,
            spline_order: int = 3,
            groups: int = 1):
        super(SimpleConvKAN, self).__init__()
        self.layers = nn.Sequential(
            KANConv2DLayer(input_channels, layer_sizes[0], spline_order, kernel_size=3, groups=1, padding=1, stride=1,
                           dilation=1),
            KANConv2DLayer(layer_sizes[0], layer_sizes[1], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=2, dilation=1),
            KANConv2DLayer(layer_sizes[1], layer_sizes[2], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=2, dilation=1),
            KANConv2DLayer(layer_sizes[2], layer_sizes[3], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=1, dilation=1),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.output = nn.Linear(layer_sizes[3], num_classes)
        self.drop = nn.Dropout(p=0.25)

    def forward(self, x):
        x = self.layers(x)
        x = torch.flatten(x, 1)
        x = self.drop(x)
        x = self.output(x)
        return x

Заметьте, что в случае классических свёрточных слоев, структура слоёв была бы примерно следующей: Conv2D -> Batch Normalization -> ReLU.

Для проведения экспериментов используются аугментации, которые вы можете посмотреть под катом.

Пример аугментаций с использованием torchvision
from torchvision.transforms import v2

transform_train = v2.Compose([
    v2.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    v2.ColorJitter(brightness=0.2, contrast=0.2),
    v2.ToTensor(),
    v2.Normalize((0.5,), (0.5,))
])

Кроме того, нам также необходимо исследовать влияние различных слоёв нормализации внутри свёрток KAN и влияние L1 регуляризации. В столбце Norm Layer во всех таблицах указывается, какой слой нормализации использовался во время эксперимента, а в столбце Affine указывается, был ли параметр affine слоя нормализации BatchNorm2d установлен как True или False.

Все эксперименты были выполнены с использованием NVIDIA RTX 3090 с идентичными параметрами.

Model

Accuracy

Parameters

Eval Time, s

Norm Layer

Affine

L1

SimpleConv, 4 layers

99.42

101066

0.7008

BatchNorm2D

False

0

SimpleKANConv, 4 layers

96.80

3488814

2.8306

InstanceNorm2D

False

0

SimpleKANConv, 4 layers

99.41

3489774

2.6362

InstanceNorm2D

True

0

SimpleKANConv, 4 layers

99.00

3489774

2.6401

BatchNorm2D

True

0

SimpleKANConv, 4 layers

98.89

3489774

2.4138

BatchNorm2D

True

1e-05

Модель на основе классических свёрток работает куда лучше нейронной сети с KAN-свёртами, которая к тому же имеет в 34 раза больше параметров и требует куда большего времени исполнения. Кажется, это не та "революция" в нейронных сетях, которую мы ожидали.

Давайте попробуем воспользоваться подходом Fast KAN и заменим сплайны на RBF:

Model

Accuracy

Parameters

Eval Time, s

Norm Layer

Affine

L1

SimpleConv, 4 layers

99.42

101066

0.7008

BatchNorm2D

True

0

SimpleFastKANConv, 4 layers

99.26

3488810

1.5636

InstanceNorm2D

False

0

SimpleFastKANConv, 4 layers

99.01

3489260

1.7406

InstanceNorm2D

True

0

SimpleFastKANConv, 4 layers

97.65

3489260

1.5999

BatchNorm2D

True

0

SimpleFastKANConv, 4 layers

95.62

3489260

1.6158

BatchNorm2D

True

1e-05

По времени стало получше, но при этом сами результаты неоднозначные: где-то стало лучше, где-то хуже.

Тогда попробуем ещё одну функцию вместо сплайнов: дискретные полиномы Чебышева. Сам факт что полиномы дискретны в теории должно дать неплохие результаты для обработки дискретных данных — изображений и текстов. Давайте узнаем, так ли это:

Model

Accuracy

Parameters

Eval Time, s

Norm Layer

Affine

L1

SimpleConv, 4 layers

99.42

101066

0.7008

BatchNorm2D

True

0

SimpleKAGNConv, 4 layers

98.21

487866

1.8506

InstanceNorm2D

False

0

SimpleKAGNConv, 4 layers

99.46

488826

1.8813

InstanceNorm2D

True

0

SimpleKAGNConv, 4 layers

99.49

488826

1.7253

BatchNorm2D

True

0

SimpleKAGNConv, 4 layers

99.44

488826

1.8979

BatchNorm2D

True

1e-05

Новые KAN-свёртки выдают качество чуть лучше традиционной модели, но работают в 2.5 раза медленнее и имеют почти в пять раз больше параметров. L1 регуляризация немного снижает производительность модели, но это область для дальнейших улучшений.

CIFAR 100

Бейзлайн модели здесь меняется. Теперь у нас будет восемь свёрточных слоёв . Для уменьшения размерности во втором, третьем и шестом свёрточных слоях используется параметр dilation=2. Для каждого слоя число каналов равно соответвующе 16, 32, 64, 128, 256, 256, 512, 512. Всё остальное в модели остается без изменений.

Аугментации, которые используются в экспериментах с CIFAR 100 можете посмотреть под катом.

Пример аугментаций с использованием torchvision
from torchvision.transforms import v2
from torchvision.transforms.autoaugment import AutoAugmentPolicy

transform_train = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomChoice([v2.AutoAugment(AutoAugmentPolicy.CIFAR10),
        v2.AutoAugment(AutoAugmentPolicy.IMAGENET),
        v2.AutoAugment(AutoAugmentPolicy.SVHN),
        v2.TrivialAugmentWide()]),
    v2.ToTensor(),
    v2.Normalize((0.5,), (0.5,))
])

Model

Accuracy

Parameters

Eval Time, s

Norm Layer

Affine

L1

SimpleConv, 8 layers

57.52

1187172

1.8265

BatchNorm2D

True

0

SimpleKAGNConv, 8 layers

29.39

22655732

2.5358

InstanceNorm2D

False

0

SimpleKAGNConv, 8 layers

48.56

22659284

2.0454

InstanceNorm2D

True

0

SimpleKAGNConv, 8 layers

59.27

22659284

2.6460

BatchNorm2D

True

0

SimpleKAGNConv, 8 layers

58.07

22659284

2.2583

BatchNorm2D

True

1e-05

KAN-свёртки на основе дискретных полиномов Чебышева показывают лучшее качество, хотя и с бОльшими временными затратами и существенно бОльшим количеством параметров (более чем в 20 раз больше). BatchNorm2D, по-видимому, является лучшим вариантом для нормализации внутренних признаков в KAN-свёртках на дискретных полиномов Чебышева.

Использование последних кажется многообещающим для дальнейших экспериментов на ImageNet1k и других "реальных" наборах данных.

Заключение

Итак, могут ли KAN справляться с компьютерным зрением? Кажется, что да! Смогут ли они заменить классические CNN? Это еще предстоит выяснить.

MLP используется уже много лет и нуждается в обновлении. Мы уже видели подобные изменения. Например, шесть лет назад сети Long Short-Term Memory (LSTM), которые долгое время были основой для моделирования последовательностей, были заменены трансформерами в качестве стандартного строительного блока для архитектуры языковых моделей. Похожий сдвиг для MLP был бы интригующим.

Свёрточные нейронные сети, которые доминировали в течение многих лет (и до сих пор являются основой для компьютерного зрения), в конечном итоге были частично заменены визуальными трансформерами (ViT). Возможно, пришло время для нового лидера в этой области?

Однако прежде чем это произойдет, сообществу необходимо найти эффективные методы для обучения сетей Колмогорова-Арнольда, свёрточных сетей Колмогорова-Арнольда (ConvKAN) и ViT-KAN, а также решить проблемы, которые связаны с этими моделями (например, скорость инференса, количество параметров).

Хотя меня очень вдохновляет эта новая архитектура и начальные эксперименты показывают обнадеживающие результаты, я остаюсь несколько скептичен. Необходимы ещё эксперименты. Оставайтесь с нами, мы собираемся углубиться в эту тему.