Мы уже писали о том, что предложили новую модель квантования нейронных сетей, позволяющую ускорить их на 40% на центральных процессорах, а также о том, как она устроена вот тут.

Сегодня мы расскажем о том, как мы в Smart Engines обучали 4.6-битные сети. Основная проблема обучения квантованных моделей в том, что градиентные методы модифицируют веса непрерывным образом, а у квантованных сетей они дискретные. 

Кроме того, распространен сценарий, когда сначала обучается вещественная сеть, а затем ее хочется отквантовать, чтобы ускорить систему. Часто обучающие данные при этом уже недоступны или доступны не в полном объеме.

Поэтому методы квантования делятся на 2 большие группы:

  • Post Training Quantization (PTQ): когда мы квантуем уже обученную сеть. Как правило, они минимизируют ошибку между выходами каждого слоя вещественной и квантованной сетей. Такие методы хорошо работают на больших моделях, в которых есть избыточные веса, а в случае компактных сетей для центральных процессоров отквантоваться без снижения точности обычно не получается.

  • Quantization Aware Training (QAT): когда мы уже во время обучения применяем методы, повышающие качество квантованной модели.

В Smart Engines все модели уже компактные, а обучающие данные никуда не деваются, так что нам в первую очередь интересны QAT методы или их комбинации с PTQ методами. Расскажем о них поподробнее.

Метод обучения

Самый очевидный способ обучить квантованную сеть – игнорировать квантование во время обратного прохода по сети, а модифицированные веса округлить. Тогда прямой проход будет уже квантованным и соответствовать “рабочему” режиму сети. Так работает Straight Through Estimator (STE), один из старейших и популярных методов обучения квантованных сетей. На удивление, он неплохо работает, если взять уже обученную 8-битную модель и ее постепенно квантовать, то есть 256 значений для сети оказывается не так уж и мало. Однако округление на каждой итерации делает процесс обучения “шумным” и может привести к невозможности обучить сеть вообще, что и наблюдается в случае меньших разрядностей.

Поэтому даже для 8-битных моделей есть методы получше, например, инкрементное обучение.

Идея инкрементного обучения или понейронной конвертации заключается в том, что мы постепенно преобразуем часть весов каждого слоя в квантованный вид. Уже квантованная часть сети немного дообучается (например, с помощью STE) и “замораживается” (веса фиксируются и больше не меняются). Вещественная часть сети дообучается, чтобы адекватно обрабатывать изменившиеся входные значения от квантованной части. Этот подход прост, эффективен в отношении 4- и 8-битных сетей, но требует времени.

AdaQuant – PTQ метод, работающий послойно. Его идея которого заключается в том, что мы используем небольшой калибровочный датасет, чтобы определить динамические диапазоны активаций, а также коэффициент масштабирования и точку нуля так, чтобы минимизировать среднеквадратичную ошибку между выходом квантованного и вещественного слоев.

В данной работе мы использовали комбинацию этих подходов. Сначала параметры каждого слоя отквантовали с помощью Adaquant, а затем применили инкрементное обучение, а именно прошли от начала к концу и еще немного дообучили каждый из квантованных слоев (остальные “замораживались”) с помощью стохастического градиентного спуска с моментом 0.9 и скоростью обучения (learning rate) .

Далее мы рассмотрели три задачи компьютерного зрения, решаемые с помощью разных нейросетевых архитектур, и применили в них 4.6-битные сети. Кстати, наш код экспериментов есть на GitHub: https://github.com/SmartEngines/QNN_training_4.6bit

CIFAR-10

Это давно набившая оскомину небольшая выборка с цветными изображениями объектов 10 классов размера 32 на 32 (см. примеры на рис. 1). В качестве аугментации использовали отражения вдоль вертикальной оси, вырезание случайных регионов со сдвигами и случайные повороты на +-9 градусов с дополнением изображений на 4 пикселя с краев.

Рис. 1. Примеры изображений из выборки CIFAR-10.

На ней мы обучили несколько простых сверточных моделей. Их архитектуры приведены в таблице 1. 

Обозначения:

  • conv(c, f, k, [p]) – сверточный слой с c-канальным входом, f фильтрами размера k х k и паддингом p (в обоих направлениях, по умолчанию 0),

  • pool(n) – 2D max-pooling с окном n на n,

  • bn – batch normalization,

  • HardTanh – активация HardTanh(x) = min(1, max(-1, x)),

  • relu6 – активация ReLU6(x) = min(max(0, x),6),

  • tanh – гиперболический тангенс,

  • fc(n) – полносвязный слой с n выходами.

Таблица 1. Архитектуры сверточных сетей для CIFAR-10. 

CNN6

CNN7

CNN8

CNN9

CNN10

conv(3, 4, 1)

HardTanh

conv(3, 8, 1)

HardTanh

conv(3, 8, 1)

HardTanh

conv(3, 8, 1)

HardTanh

conv(3, 8, 1)

HardTanh

conv(4, 8, 5)

bn+relu6

pool(2)

conv(8, 8, 3)

bn+relu6

conv(8, 12, 3)

bn+relu6

pool(2)

conv(8, 8, 3)

bn+relu6

conv(8, 12, 3)

bn+relu6

pool(2)

conv(8, 8, 3)

bn+relu6

conv(8, 12, 3)

bn+relu6

pool(2)

conv(8, 16, 3, 1)

bn+relu6

conv(16, 32, 3, 1)

bn+relu6

pool(2)

conv(8, 16, 3)

bn+relu6

pool(2) 

conv(12, 16, 3)

bn+relu6

pool(2) 

conv(12, 24, 3)

bn+relu6

pool(2) 

conv(12, 12, 3, 1)

bn+relu6

conv(12, 24, 3)

bn+relu6

pool(2) 

conv(32, 32, 3, 1)

bn+relu6

conv(32, 64, 3, 1)

bn+relu6 

pool(2) 

conv(16, 32, 3)

bn+relu6

pool(2)

conv(16, 32, 3)

bn+relu6

pool(2)

conv(24, 24, 3)

bn+relu6

conv(24, 40, 3) bn+relu6

conv(24, 24, 3)

bn+relu6

conv(24, 48, 3)

bn+relu6

conv(64, 64, 3)

bn+relu6

conv(64, 64, 3)

bn+relu6

conv(64, 128, 3)

bn+relu6

fc(64)

tanh

fc(64)

tanh

fc(64)

tanh

fc(96)

tanh

fc(256)

tanh

fc(10)

fc(10)

fc(10)

fc(10)

fc(10)

Trainable parameters

15.6k

16.9k

29.1k

40.7k

315.6k

Сначала мы обучили вещественные версии этих сетей 250 эпох с оптимизатором AdamW с параметрами по умолчанию, кроме коэффициента убывания весов (weight decay), который был установлен в , и скорости обучения (learning rate), которая была задана как и уменьшалась вдвое каждые 50 эпох. Размер батча был 100.

Дальше слои батч нормализации были интегрированы в свертки и мы приступили к квантованию. Первый и последний слой мы не квантовали, так как это улучшает результаты обучения и практически не снижает вычислительную эффективность. Результаты приведены в таблице 2. Напомним, что у нашей схемы квантования есть параметры и , которые обозначают количество дискретов для активаций и весов соответственно, мы перебрали все возможные их комбинации. Эксперименты повторили 5 раз с разной начальной инициализацией сети и посчитали среднюю точность и погрешность.

Таблица 2. Точность классификации на выборке CIFAR-10

Quantization

Accuracy, %

CNN6

CNN7

CNN8

CNN9

CNN10

5

127

63.5 ± 0.3

69.6 ± 0.1

70.6 ± 0.3

71.7 ± 0.3

81.3 ± 0.3

7

85

68.4 ± 0.3

73.4 ± 0.1

74.7 ± 0.2

76.2 ± 0.2

85.4 ± 0.1

9

63

70.9 ± 0.1

75.0 ± 0.2

76.4 ± 0.2

78.0 ± 0.2

86.9 ± 0.2

11

51

71.8 ± 0.3

75.7 ± 0.1

77.4 ± 0.1

79.0 ± 0.2

87.6 ± 0.1

13

43

72.7 ± 0.2

76.2 ± 0.1

77.8 ± 0.1

79.5 ± 0.1

88.0 ± 0.1

15

37

73.1 ± 0.2

76.4 ± 0.1

78.1 ± 0.1

79.8 ± 0.2

88.2 ± 0.2

17

31

73.0 ± 0.1

76.6 ± 0.2

78.1 ± 0.3

79.8 ± 0.2

88.2 ± 0.1

19

29

73.1 ± 0.2

76.4 ± 0.2

78.5 ± 0.1

80.0 ± 0.3

88.5 ± 0.3

21

25

73.4 ± 0.2

76.7 ± 0.3

78.3 ± 0.2

79.9 ± 0.1

88.4 ± 0.2

23

23

73.3 ± 0.3

76.5 ± 0.1

78.2 ± 0.2

79.9 ± 0.3

88.4 ± 0.2

25

21

73.1 ± 0.1

76.6 ± 0.2

78.2 ± 0.2

79.9 ± 0.2

88.5 ± 0.2

29

19

73.0 ± 0.1

76.3 ± 0.1

78.3 ± 0.2

79.9 ± 0.2

88.3 ± 0.2

31

17

73.1 ± 0.2

76.1 ± 0.1

78.0 ± 0.3

79.7 ± 0.2

88.2 ± 0.1

37

15

72.8 ± 0.2

75.5 ± 0.4

77.7 ± 0.3

79.4 ± 0.2

87.9 ± 0.3

43

13

72.0 ± 0.4

74.8 ± 0.2

77.5 ± 0.2

79.0 ± 0.3

87.9 ± 0.1

51

11

70.9 ± 0.3

74.0 ± 0.1

76.0 ± 0.3

78.1 ± 0.2

87.5 ± 0.1

63

9

69.0 ± 0.4

71.7 ± 0.3

74.3 ± 0.5

76.7 ± 0.4

86.3 ± 0.1

85

7

65.9 ± 0.5

67.7 ± 1.0

70.6 ± 0.4

73.4 ± 0.7

84.5 ± 0.3

127

5

47.5 ± 0.4

55.2 ± 0.6

58.2 ± 1.1

67.5 ± 0.4

74.9 ± 2.3

4 бита

72.0 ± 0.2

75.4 ± 0.2

77.3 ± 0.2

79.3 ± 0.3

87.7 ± 0.2

8 бит

74.7 ± 0.1

77.6 ± 0.1

79.4 ± 0.1

80.8 ± 0.1

89.2 ± 0.1

float32

74.95

77.83

79.66

81.4

89.07

Можно видеть, что 8-битные модели демонстрируют почти ту же точность, что и вещественные, а вот 4-битные им сильно уступают. Для 4.6-битных моделей оказалось, что лучше всего распределять дискреты между активациями и весами равномерно:  наилучшие результаты показывают схемы квантования от (15, 37) до (31, 17). Отметим, что точность классификации в этом диапазоне заметно лучше, чем у 4-битных моделей. Тем не менее, 4.6-битные сети все же немного уступили 8-битным и вещественным моделям, что говорит о том, что метод обучения можно придумать и получше.

ImageNet

Это уже стандартная выборка для оценки точности классифицирующих моделей (см. примеры на рис. 2). 

Рис. 2. Примеры изображений из ImageNet.

Здесь мы использовали ResNet’ы: ResNet-18 (11.7М весов) и ResNet-34 (21.8М весов), предобученные модели из PyTorch's Torchvision. Все активации ReLU заменили на ReLU6, батч нормализацию интегрировали в свертки и отквантовали, но первый и последний слои также не трогали. 

Калибровочная выборка была меньше: 25 батчей по 64 изображения для ResNet-18 и 5 батчей по 64 изображения для ResNet-34.

Точности классификации top-1 и top-5 для выборки ImageNet показаны в таблице 3. Лучшая пара параметров для 4.6-битных моделей оказалась (, ) = (23, 23). Также как и на CIFAR-10 8-битные модели имеют сравнимое качество с вещественными, а 4-битные сильно хуже. 4.6-битные сети оказались почти ровно посередине. Это ухудшение не такое большое и с ним уже можно работать!

Таблица 3. Точность классификации на ImageNet.

Quantization

ResNet-18

ResNet-34

top-1, %

top-5, %

top-1, %

top-5, %

29

19

65.6 ± 0.3

86.7 ± 0.1

68.6 ± 0.1

88.5 ± 0.1

25

21

65.9 ± 0.2

86.9 ± 0.1

68.8 ± 0.3

88.7 ± 0.2

23

23

66.1 ± 0.1

87.0 ± 0.1

69.1 ± 0.2

88.9 ± 0.1

4 бита

64.2 ± 0.2

85.7 ± 0.1

66.1 ± 0.3

87.0 ± 0.2

8 бит

68.3 ± 0.1

88.3 ± 0.1

71.4 ± 0.1

90.1 ± 0.1

float32

68.7

88.5

72.3

90.8

TCIA

Это выборка с изображениями МРТ головного мозга, изображения RGB, размер каждого 256 на 256. На этих изображениях ставится задача сегментации: нужно найти и отметить аномалии. Эта задача прекрасно  решается с помощью модели U-Net с 7.76М параметров. Мы взяли предобученную модель тут. Она принимает на вход изображение и выдает бинарную маску, которая отмечает аномалии. 

Также как и в ResNet’ах мы заменили активации ReLU на ReLU6 и отквантовали сеть. Размер калибровочной выборки был 5 батчей по 10 изображений. Для оценки качества сегментации использовались средние значения Dice и Intersection over Union (IoU) для тех изображений, где аномалии были. Также посчитали метрики качества бинарной классификации (есть аномалия или нет): accuracy, precision, recall, ошибку 1-го рода (false positive rate) и ошибку 2-го рода (false negative rate). Метрики приведены в таблице 4, а примеры работы на рис. 3.

Таблица 4. Качество сегментации моделей U-Net на выборке TCIA.

float32

8 бит

4.6 бита

4 бита

Dice ↑

0.7643

0.7843 ± 0.0008

0.769 ± 0.006

0.746 ± 0.009

IoU ↑

0.6875

0.7046 ± 0.0008

0.688 ± 0.005

0.662 ± 0.009

Accuracy ↑

0.8119

0.8124 ± 0.0013

0.781 ± 0.013

0.57 ± 0.06

Precision ↑

0.6654

0.6624 ± 0.0016

0.623 ± 0.015

0.45 ± 0.04

Recall ↑

0.9286

0.9447 ± 0.0009

0.948 ± 0.004

0.969 ± 0.013

Type I ↓

0.1631

0.1682 ± 0.0011

0.201 ± 0.013

0.42 ± 0.07

Type II ↓

0.0249

0.0193 ± 0.0003

0.0182 ± 0.0013

0.011 ± 0.005

Рис. 3. Примеры изображений TCIA, разметка и результаты сегментации.

Можно видеть, что 8-битная сеть даже немного лучше вещественной: у нее чуть больше ошибка 1-го рода, но при этом выше качество сегментации. 4-битной сети откровенно плохо как по метрикам, так и визуально: добавился заметный шум. А 4.6-битная модель работает разумным образом и демонстрирует качество близкое к качеству 8-битной и вещественной моделей.

Обсуждение

Наши эксперименты ясно показали, что даже при использовании простейших методов обучения, не слишком отличающихся от подходов для 4- и 8-битных моделей, 4.6-битные сети работают заметно лучше 4-битных при практически такой же скорости работы. Это было ожидаемо, так как 4.6-битная схема квантования предлагает больше дискретов, чем 4-битная, а мы только проверили это теоретическое обстоятельство на практике. 

На сегодняшний день в нише быстрых моделей для центральных процессоров основную роль играют 8-битные модели за счет сочетания факторов: заметное ускорение и малые усилия по конвертации к 8-битному формату. Именно их наиболее актуально сравнивать с 4.6-битными. Здесь 4.6-битные сети однозначно выиграли по скорости, продемонстрировав ускорение на 30-40%. При использовании простых в имплементации методов обучения они несколько проиграли 8-битным по качеству, однако можно взять чуть более сложную 4.6-битную модель, которая даст сходное качество, но будет все еще работать быстрее исходной. Поэтому на практике мы получаем ускорение без снижения качества, а также активно работаем над усовершенствованием метода обучения 4.6-битных сетей.

На самом деле здесь есть много простора для фантазии:

  • параметры , позволяют модифицировать схему квантования под разные задачи,

  • никто не говорит, что они должны быть одинаковые в разных слоях модели,

  • AdaQuant все-таки устарел, есть и другие методы, которые могут подойти лучше.

Поэтому мы считаем 4.6-битные сети крайне перспективными для практического использования и дальнейшего усовершенствования. Мы уже используем их в наших системах распознавания паспортов, распознавания документов, а также подали заявку на патентование этой технологии в РФ и США.