Как устроены 4.6-битные сети: обучение
Мы уже писали о том, что предложили новую модель квантования нейронных сетей, позволяющую ускорить их на 40% на центральных процессорах, а также о том, как она устроена вот тут.
Сегодня мы расскажем о том, как мы в Smart Engines обучали 4.6-битные сети. Основная проблема обучения квантованных моделей в том, что градиентные методы модифицируют веса непрерывным образом, а у квантованных сетей они дискретные.
Кроме того, распространен сценарий, когда сначала обучается вещественная сеть, а затем ее хочется отквантовать, чтобы ускорить систему. Часто обучающие данные при этом уже недоступны или доступны не в полном объеме.
Поэтому методы квантования делятся на 2 большие группы:
Post Training Quantization (PTQ): когда мы квантуем уже обученную сеть. Как правило, они минимизируют ошибку между выходами каждого слоя вещественной и квантованной сетей. Такие методы хорошо работают на больших моделях, в которых есть избыточные веса, а в случае компактных сетей для центральных процессоров отквантоваться без снижения точности обычно не получается.
Quantization Aware Training (QAT): когда мы уже во время обучения применяем методы, повышающие качество квантованной модели.
В Smart Engines все модели уже компактные, а обучающие данные никуда не деваются, так что нам в первую очередь интересны QAT методы или их комбинации с PTQ методами. Расскажем о них поподробнее.
Метод обучения
Самый очевидный способ обучить квантованную сеть – игнорировать квантование во время обратного прохода по сети, а модифицированные веса округлить. Тогда прямой проход будет уже квантованным и соответствовать “рабочему” режиму сети. Так работает Straight Through Estimator (STE), один из старейших и популярных методов обучения квантованных сетей. На удивление, он неплохо работает, если взять уже обученную 8-битную модель и ее постепенно квантовать, то есть 256 значений для сети оказывается не так уж и мало. Однако округление на каждой итерации делает процесс обучения “шумным” и может привести к невозможности обучить сеть вообще, что и наблюдается в случае меньших разрядностей.
Поэтому даже для 8-битных моделей есть методы получше, например, инкрементное обучение.
Идея инкрементного обучения или понейронной конвертации заключается в том, что мы постепенно преобразуем часть весов каждого слоя в квантованный вид. Уже квантованная часть сети немного дообучается (например, с помощью STE) и “замораживается” (веса фиксируются и больше не меняются). Вещественная часть сети дообучается, чтобы адекватно обрабатывать изменившиеся входные значения от квантованной части. Этот подход прост, эффективен в отношении 4- и 8-битных сетей, но требует времени.
AdaQuant – PTQ метод, работающий послойно. Его идея которого заключается в том, что мы используем небольшой калибровочный датасет, чтобы определить динамические диапазоны активаций, а также коэффициент масштабирования и точку нуля так, чтобы минимизировать среднеквадратичную ошибку между выходом квантованного и вещественного слоев.
В данной работе мы использовали комбинацию этих подходов. Сначала параметры каждого слоя отквантовали с помощью Adaquant, а затем применили инкрементное обучение, а именно прошли от начала к концу и еще немного дообучили каждый из квантованных слоев (остальные “замораживались”) с помощью стохастического градиентного спуска с моментом 0.9 и скоростью обучения (learning rate)
Далее мы рассмотрели три задачи компьютерного зрения, решаемые с помощью разных нейросетевых архитектур, и применили в них 4.6-битные сети. Кстати, наш код экспериментов есть на GitHub: https://github.com/SmartEngines/QNN_training_4.6bit
CIFAR-10
Это давно набившая оскомину небольшая выборка с цветными изображениями объектов 10 классов размера 32 на 32 (см. примеры на рис. 1). В качестве аугментации использовали отражения вдоль вертикальной оси, вырезание случайных регионов со сдвигами и случайные повороты на +-9 градусов с дополнением изображений на 4 пикселя с краев.
На ней мы обучили несколько простых сверточных моделей. Их архитектуры приведены в таблице 1.
Обозначения:
conv(c, f, k, [p]) – сверточный слой с c-канальным входом, f фильтрами размера k х k и паддингом p (в обоих направлениях, по умолчанию 0),
pool(n) – 2D max-pooling с окном n на n,
bn – batch normalization,
HardTanh – активация HardTanh(x) = min(1, max(-1, x)),
relu6 – активация ReLU6(x) = min(max(0, x),6),
tanh – гиперболический тангенс,
fc(n) – полносвязный слой с n выходами.
Таблица 1. Архитектуры сверточных сетей для CIFAR-10.
CNN6 | CNN7 | CNN8 | CNN9 | CNN10 |
conv(3, 4, 1) HardTanh | conv(3, 8, 1) HardTanh | conv(3, 8, 1) HardTanh | conv(3, 8, 1) HardTanh | conv(3, 8, 1) HardTanh |
conv(4, 8, 5) bn+relu6 pool(2) | conv(8, 8, 3) bn+relu6 conv(8, 12, 3) bn+relu6 pool(2) | conv(8, 8, 3) bn+relu6 conv(8, 12, 3) bn+relu6 pool(2) | conv(8, 8, 3) bn+relu6 conv(8, 12, 3) bn+relu6 pool(2) | conv(8, 16, 3, 1) bn+relu6 conv(16, 32, 3, 1) bn+relu6 pool(2) |
conv(8, 16, 3) bn+relu6 pool(2) | conv(12, 16, 3) bn+relu6 pool(2) | conv(12, 24, 3) bn+relu6 pool(2) | conv(12, 12, 3, 1) bn+relu6 conv(12, 24, 3) bn+relu6 pool(2) | conv(32, 32, 3, 1) bn+relu6 conv(32, 64, 3, 1) bn+relu6 pool(2) |
conv(16, 32, 3) bn+relu6 pool(2) | conv(16, 32, 3) bn+relu6 pool(2) | conv(24, 24, 3) bn+relu6 conv(24, 40, 3) bn+relu6 | conv(24, 24, 3) bn+relu6 conv(24, 48, 3) bn+relu6 | conv(64, 64, 3) bn+relu6 conv(64, 64, 3) bn+relu6 conv(64, 128, 3) bn+relu6 |
fc(64) tanh | fc(64) tanh | fc(64) tanh | fc(96) tanh | fc(256) tanh |
fc(10) | fc(10) | fc(10) | fc(10) | fc(10) |
Trainable parameters | ||||
15.6k | 16.9k | 29.1k | 40.7k | 315.6k |
Сначала мы обучили вещественные версии этих сетей 250 эпох с оптимизатором AdamW с параметрами по умолчанию, кроме коэффициента убывания весов (weight decay), который был установлен в
Дальше слои батч нормализации были интегрированы в свертки и мы приступили к квантованию. Первый и последний слой мы не квантовали, так как это улучшает результаты обучения и практически не снижает вычислительную эффективность. Результаты приведены в таблице 2. Напомним, что у нашей схемы квантования есть параметры
Таблица 2. Точность классификации на выборке CIFAR-10
Quantization | Accuracy, % | |||||
CNN6 | CNN7 | CNN8 | CNN9 | CNN10 | ||
5 | 127 | 63.5 ± 0.3 | 69.6 ± 0.1 | 70.6 ± 0.3 | 71.7 ± 0.3 | 81.3 ± 0.3 |
7 | 85 | 68.4 ± 0.3 | 73.4 ± 0.1 | 74.7 ± 0.2 | 76.2 ± 0.2 | 85.4 ± 0.1 |
9 | 63 | 70.9 ± 0.1 | 75.0 ± 0.2 | 76.4 ± 0.2 | 78.0 ± 0.2 | 86.9 ± 0.2 |
11 | 51 | 71.8 ± 0.3 | 75.7 ± 0.1 | 77.4 ± 0.1 | 79.0 ± 0.2 | 87.6 ± 0.1 |
13 | 43 | 72.7 ± 0.2 | 76.2 ± 0.1 | 77.8 ± 0.1 | 79.5 ± 0.1 | 88.0 ± 0.1 |
15 | 37 | 73.1 ± 0.2 | 76.4 ± 0.1 | 78.1 ± 0.1 | 79.8 ± 0.2 | 88.2 ± 0.2 |
17 | 31 | 73.0 ± 0.1 | 76.6 ± 0.2 | 78.1 ± 0.3 | 79.8 ± 0.2 | 88.2 ± 0.1 |
19 | 29 | 73.1 ± 0.2 | 76.4 ± 0.2 | 78.5 ± 0.1 | 80.0 ± 0.3 | 88.5 ± 0.3 |
21 | 25 | 73.4 ± 0.2 | 76.7 ± 0.3 | 78.3 ± 0.2 | 79.9 ± 0.1 | 88.4 ± 0.2 |
23 | 23 | 73.3 ± 0.3 | 76.5 ± 0.1 | 78.2 ± 0.2 | 79.9 ± 0.3 | 88.4 ± 0.2 |
25 | 21 | 73.1 ± 0.1 | 76.6 ± 0.2 | 78.2 ± 0.2 | 79.9 ± 0.2 | 88.5 ± 0.2 |
29 | 19 | 73.0 ± 0.1 | 76.3 ± 0.1 | 78.3 ± 0.2 | 79.9 ± 0.2 | 88.3 ± 0.2 |
31 | 17 | 73.1 ± 0.2 | 76.1 ± 0.1 | 78.0 ± 0.3 | 79.7 ± 0.2 | 88.2 ± 0.1 |
37 | 15 | 72.8 ± 0.2 | 75.5 ± 0.4 | 77.7 ± 0.3 | 79.4 ± 0.2 | 87.9 ± 0.3 |
43 | 13 | 72.0 ± 0.4 | 74.8 ± 0.2 | 77.5 ± 0.2 | 79.0 ± 0.3 | 87.9 ± 0.1 |
51 | 11 | 70.9 ± 0.3 | 74.0 ± 0.1 | 76.0 ± 0.3 | 78.1 ± 0.2 | 87.5 ± 0.1 |
63 | 9 | 69.0 ± 0.4 | 71.7 ± 0.3 | 74.3 ± 0.5 | 76.7 ± 0.4 | 86.3 ± 0.1 |
85 | 7 | 65.9 ± 0.5 | 67.7 ± 1.0 | 70.6 ± 0.4 | 73.4 ± 0.7 | 84.5 ± 0.3 |
127 | 5 | 47.5 ± 0.4 | 55.2 ± 0.6 | 58.2 ± 1.1 | 67.5 ± 0.4 | 74.9 ± 2.3 |
4 бита | 72.0 ± 0.2 | 75.4 ± 0.2 | 77.3 ± 0.2 | 79.3 ± 0.3 | 87.7 ± 0.2 | |
8 бит | 74.7 ± 0.1 | 77.6 ± 0.1 | 79.4 ± 0.1 | 80.8 ± 0.1 | 89.2 ± 0.1 | |
float32 | 74.95 | 77.83 | 79.66 | 81.4 | 89.07 |
Можно видеть, что 8-битные модели демонстрируют почти ту же точность, что и вещественные, а вот 4-битные им сильно уступают. Для 4.6-битных моделей оказалось, что лучше всего распределять дискреты между активациями и весами равномерно: наилучшие результаты показывают схемы квантования от (15, 37) до (31, 17). Отметим, что точность классификации в этом диапазоне заметно лучше, чем у 4-битных моделей. Тем не менее, 4.6-битные сети все же немного уступили 8-битным и вещественным моделям, что говорит о том, что метод обучения можно придумать и получше.
ImageNet
Это уже стандартная выборка для оценки точности классифицирующих моделей (см. примеры на рис. 2).
Здесь мы использовали ResNet’ы: ResNet-18 (11.7М весов) и ResNet-34 (21.8М весов), предобученные модели из PyTorch's Torchvision. Все активации ReLU заменили на ReLU6, батч нормализацию интегрировали в свертки и отквантовали, но первый и последний слои также не трогали.
Калибровочная выборка была меньше: 25 батчей по 64 изображения для ResNet-18 и 5 батчей по 64 изображения для ResNet-34.
Точности классификации top-1 и top-5 для выборки ImageNet показаны в таблице 3. Лучшая пара параметров для 4.6-битных моделей оказалась (
Таблица 3. Точность классификации на ImageNet.
Quantization | ResNet-18 | ResNet-34 | |||
top-1, % | top-5, % | top-1, % | top-5, % | ||
29 | 19 | 65.6 ± 0.3 | 86.7 ± 0.1 | 68.6 ± 0.1 | 88.5 ± 0.1 |
25 | 21 | 65.9 ± 0.2 | 86.9 ± 0.1 | 68.8 ± 0.3 | 88.7 ± 0.2 |
23 | 23 | 66.1 ± 0.1 | 87.0 ± 0.1 | 69.1 ± 0.2 | 88.9 ± 0.1 |
4 бита | 64.2 ± 0.2 | 85.7 ± 0.1 | 66.1 ± 0.3 | 87.0 ± 0.2 | |
8 бит | 68.3 ± 0.1 | 88.3 ± 0.1 | 71.4 ± 0.1 | 90.1 ± 0.1 | |
float32 | 68.7 | 88.5 | 72.3 | 90.8 |
TCIA
Это выборка с изображениями МРТ головного мозга, изображения RGB, размер каждого 256 на 256. На этих изображениях ставится задача сегментации: нужно найти и отметить аномалии. Эта задача прекрасно решается с помощью модели U-Net с 7.76М параметров. Мы взяли предобученную модель тут. Она принимает на вход изображение и выдает бинарную маску, которая отмечает аномалии.
Также как и в ResNet’ах мы заменили активации ReLU на ReLU6 и отквантовали сеть. Размер калибровочной выборки был 5 батчей по 10 изображений. Для оценки качества сегментации использовались средние значения Dice и Intersection over Union (IoU) для тех изображений, где аномалии были. Также посчитали метрики качества бинарной классификации (есть аномалия или нет): accuracy, precision, recall, ошибку 1-го рода (false positive rate) и ошибку 2-го рода (false negative rate). Метрики приведены в таблице 4, а примеры работы на рис. 3.
Таблица 4. Качество сегментации моделей U-Net на выборке TCIA.
float32 | 8 бит | 4.6 бита | 4 бита | |
Dice ↑ | 0.7643 | 0.7843 ± 0.0008 | 0.769 ± 0.006 | 0.746 ± 0.009 |
IoU ↑ | 0.6875 | 0.7046 ± 0.0008 | 0.688 ± 0.005 | 0.662 ± 0.009 |
Accuracy ↑ | 0.8119 | 0.8124 ± 0.0013 | 0.781 ± 0.013 | 0.57 ± 0.06 |
Precision ↑ | 0.6654 | 0.6624 ± 0.0016 | 0.623 ± 0.015 | 0.45 ± 0.04 |
Recall ↑ | 0.9286 | 0.9447 ± 0.0009 | 0.948 ± 0.004 | 0.969 ± 0.013 |
Type I ↓ | 0.1631 | 0.1682 ± 0.0011 | 0.201 ± 0.013 | 0.42 ± 0.07 |
Type II ↓ | 0.0249 | 0.0193 ± 0.0003 | 0.0182 ± 0.0013 | 0.011 ± 0.005 |
Можно видеть, что 8-битная сеть даже немного лучше вещественной: у нее чуть больше ошибка 1-го рода, но при этом выше качество сегментации. 4-битной сети откровенно плохо как по метрикам, так и визуально: добавился заметный шум. А 4.6-битная модель работает разумным образом и демонстрирует качество близкое к качеству 8-битной и вещественной моделей.
Обсуждение
Наши эксперименты ясно показали, что даже при использовании простейших методов обучения, не слишком отличающихся от подходов для 4- и 8-битных моделей, 4.6-битные сети работают заметно лучше 4-битных при практически такой же скорости работы. Это было ожидаемо, так как 4.6-битная схема квантования предлагает больше дискретов, чем 4-битная, а мы только проверили это теоретическое обстоятельство на практике.
На сегодняшний день в нише быстрых моделей для центральных процессоров основную роль играют 8-битные модели за счет сочетания факторов: заметное ускорение и малые усилия по конвертации к 8-битному формату. Именно их наиболее актуально сравнивать с 4.6-битными. Здесь 4.6-битные сети однозначно выиграли по скорости, продемонстрировав ускорение на 30-40%. При использовании простых в имплементации методов обучения они несколько проиграли 8-битным по качеству, однако можно взять чуть более сложную 4.6-битную модель, которая даст сходное качество, но будет все еще работать быстрее исходной. Поэтому на практике мы получаем ускорение без снижения качества, а также активно работаем над усовершенствованием метода обучения 4.6-битных сетей.
На самом деле здесь есть много простора для фантазии:
параметры
, позволяют модифицировать схему квантования под разные задачи, никто не говорит, что они должны быть одинаковые в разных слоях модели,
AdaQuant все-таки устарел, есть и другие методы, которые могут подойти лучше.
Поэтому мы считаем 4.6-битные сети крайне перспективными для практического использования и дальнейшего усовершенствования. Мы уже используем их в наших системах распознавания паспортов, распознавания документов, а также подали заявку на патентование этой технологии в РФ и США.