Сегодня в ленте была очередная порция хайпа про ИИ. Смешно читать про «аппаратное ускорение AI» на пользовательских устройствах. Автор, вы сами попробуйте добраться до этого аппаратного ускорения, и если найдете как — напишите статью. А то элементарная попытка использования GPU для работы TensorFlow Lite приводит только к потерянному времени, а ускорители NPU больше не поддерживаются именно там, где должны были бы. То есть за хайпом вокруг «аппаратного ускорения ИИ» производители создали новую категорию устройств, и теперь стандартно ноутбук будет стоить в 2 раза больше, чем было раньше. А по факту пользоваться этим ускорением будут только компании‑производители, чтобы еще больше заработать денег на пользователях через рекламу, «правильные» модели и торговлю персональными данными.
А мы сегодня запустим TensorFlow Lite на устройствах разного класса и года выпуска и посмотрим, что там с производительностью и ускорением.
Для начала берем модель, подготовленную на предыдущем шаге, и копируем ее в ассеты проекта Android Studio. Можно взять модель с сохраненными весами после обучения (ruLearnModel.tflite), можно сохранить необученную модель (в моем случае ruLearnModel_noWeights.tflite). А можно сделать и то, и другое:
Дальше находим на гитхабе пример с обучением модели на устройстве и копируем оттуда зависимости для нашего проекта. Я посчитал по опыту предыдущей статьи, что лучше использовать старые версии библиотек, которые работают, чем новые, для которых вообще непонятно как писать код.
// Tensorflow lite dependencies
implementation 'org.tensorflow:tensorflow-lite:2.9.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.4.2'
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.9.0'
Напомню, что текущая версия TensorFlow 2.17. Наша модель сделана на Python для версии 2.8.
Теперь по архитектуре. Создадим класс-singleton TFLiteHelper, в котором главное действующее лицо - это "interpreter" - экземпляр работающей модели TensorFlow Lite. Он может быть для моего приложения в единственном экземпляре и будет выполнять разные функции - предсказывать или учиться.
private TFLiteHelper(Context context, String courseName) {
this.context = context;
this.courseName = courseName;
interpreter = new Interpreter(loadModelfile(context));
}
public static TFLiteHelper getInstance(Context context, String courseName) {
if (instance == null)
instance = new TFLiteHelper(context, courseName);
return instance;
}
Передаем в конструктор context и название курса и запоминаем их для дальшейшего использования. Создаем новый интерпретатор, загружая модель из ассетов:
private MappedByteBuffer loadModelfile(Context context) {
try {
AssetFileDescriptor fileDescriptor = context.getAssets().openFd("ruLearnModel_noWeights.tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
} catch (IOException e) {
//добавим обработку исключений позже
//в какой ситуации может не найтись модель, которая была в ассетах во время компиляции?
//это был риторический вопрос
}
}
Интерпретатор загружен. Что делать дальше? Можно запустить предсказание, для этого нужны 1) веса модели (они есть, если мы сохранили модель с весами) 2) нормализация (мы можем взять пример уже нормализованных значений из предыдущей статьи. И тогда этот код будет работать:
private float[] doInference(TensorBuffer tb) {
int tbSize = tb.getShape()[0];
Map<String, Object> inputs = new HashMap<>();
inputs.put("x", tb.getFloatArray());
Map<String, Object> outputs = new HashMap<>();
FloatBuffer output = FloatBuffer.allocate(tbSize);
outputs.put("output_0", output);
long millis = System.currentTimeMillis();
interpreter.runSignature(inputs, outputs, "infer");
long xxx = System.currentTimeMillis() - millis;
System.out.println("model ran for: " + xxx + " milliseconds");
FloatBuffer buffer = (FloatBuffer) outputs.get("output_0");
float[] outPutArray = buffer.array();
return outPutArray;
}
На входе нужен TensorBuffer - внутренний объект для TF Lite. Создается он так:
TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[]{numRowCount, 4}, DataType.FLOAT32);
float[] featureArray = new float[]{
-1.2826425, 0.7050209, -0.6746891, 1.5294912
};
tensorBuffer.loadArray(featureArray);
где new int[]{numRowCount, 4} - это форма буфера. То есть он будет из произвольного количества рядов с 4 колонками. DataType.FLOAT32 - это чуть ли не единственный тип данных, нормально поддерживаемый TF Lite. В буфер надо загрузить одномерный массив, и tensorBuffer сам переупакует его в соответствии с заданной формой.
Для выходных значений нужно тоже подготовить буфер. В нашем случае для каждого ряда параметров на выходе будет одно значение float. Поэтому буфер должен быть равен <число рядов * размер float32>. Число рядов равно tb.getShape()[0] - первому элементу формы буфера. Дальше идет загадочная история с упаковкой буферов в Map. Это, по всей видимости, нужно для передачи именованных параметров, как было в Python (x=test_data):
y_lite = infer(x=test_data)
y_original = m.infer(x=test_data)
То есть в нашем случае мы записываем имя параметра в ключ Map, а данные - в значение.
Засекаем системное время, запускаем интерпретатор с функцией infer и смотрим, сколько ушло на предсказание. В статье, где мы первый раз запускали модель, на предсказание всей таблицы в цикле по одному ряду ушло 167 миллисекунд. Здесь у нас есть возможность сделать это за один подход и результат равен 11 миллисекунд на том же устройстве.
Обучение на устройстве и вспомогательные функции
Для обучения лучше взять пустую модель, без весов. На самом деле, данных у нас не прибавилось, поэтому учить натренированную модель бессмысленно. Но вначале нужно заняться нормализацией, то есть пересчитать данные по формуле:
Для этого можно использовать следующий код:
private TensorBuffer doNormalize(int numRowCount, float[] mlArray) throws IOException {
TensorBuffer notNormalizedTensorBuffer = TensorBuffer.createFixedSize(new int[]{numRowCount, 4}, DataType.FLOAT32);
TensorBuffer normalizedTensorBuffer; //объявлять форму не надо, буфер вернет TensorProcessor
notNormalizedTensorBuffer.loadArray(mlArray);
float[] mean = new float[4];
float[] stddev = new float[4];
getMeansAndStddev(mean, stddev); //внутри метода файловая операция
TensorProcessor processor = new TensorProcessor.Builder().add(new NormalizeOp(mean, stddev)).build();
normalizedTensorBuffer = processor.process(notNormalizedTensorBuffer);
return normalizedTensorBuffer;
}
Метод getMeansAndStddev(mean, stddev) возвращает через параметры средние значения и стандартные отклонения для датасета. У нас 4 параметра в модели, поэтому 4 средних по каждой колонке значений и 4 отклонения. Можно взять его из предыдущей статьи или посчитать самостоятельно самому или применить библиотечные функции из Apache Commons Statistics. В моем случае я читаю эти значения из файла, поэтому throws IOException.
Давайте посмотрим, как выглядит код тренировки модели:
public void doTrain(ArrayList<MLData> input, int num_epoch) throws IOException{
float[] dataArray = mlDataToArrays(input).getData();
float[] resultArray = mlDataToArrays(input).getResults();
calculateMeansAndStandardDeviation(input); //файловые операции
TensorBuffer tb_in = doNormalize(input.size(), dataArray);
float[][] dataArray2dim = new float[][]{tb_in.getFloatArray()};
float[][] resultArray2dim = new float[][]{resultArray};
long millis = System.currentTimeMillis();
for (int i = 0; i < num_epoch; i++) {
Map<String, Object> inputs = new HashMap<>();
inputs.put("x", dataArray2dim);
inputs.put("y", resultArray2dim);
Map<String, Object> outputs = new HashMap<>();
FloatBuffer loss = FloatBuffer.allocate(1);
outputs.put("loss", loss);
interpreter.runSignature(inputs, outputs, "train");
}
System.out.println("training took " + (System.currentTimeMillis() - millis) + " ms");
saveCheckpoint();
}
Начинается все с преобразования датасета, приходящего в метод из базы SQLite в форме ArrayList<упакованные в объект значения из таблицы с данными, на которых строилась модель>. Для работы TensorFlow нужны буферы или массивы, поэтому я переупаковываю данные соответственно в массивы для значений и результатов, на которых модель будет учиться. Дальше в функции calculateMeansAndStandardDeviation(input) по всему датасету считаю средние значения и стандартное отклонение, записываю эти значения в файл. Дальше почему-то данные и результаты надо переупаковать в двумерные массивы. Почему - непонятно и на самом деле код не работал, пока я не вспомнил, что и в Python было именно так:
test_data = np.array([[-1.2826425, 0.7050209, -0.6746891, 1.5294912],
[2.559783, -0.18800573, 0.62913984, -0.8701883 ],
[ 2.560703, -0.18800573, 1.672203, -0.87582266],
[-1.2808022, 0.7050209, -0.6746891, 0.24507335]],dtype='float32')
Без этого преобразования происходит одна из ошибок TensorFlow с "очень полезными" описаниями типа "TypeError: src data type = 17 is not supported". В данном случае такая:
Дальше как обычно, упаковка в Map и запуск соответствующей функции interpreter.runSignature(inputs, outputs, "train"). Заодно засекаем время. Что с результатами? (это все под дебагом, в реальности будет намного быстрее)
Телефон Xiaomi Mi 1 2017 года, 60 000 Antutu: 2000 мс
Телефон Xiaomi 11 Lite 5G 2020 года, 300 000 Antutu: 400 мс
Телефон Samsung Galaxy S22 2022 года, 1 000 000 Antutu: 265 мс.
Очевидно, что обучение с нуля такой модели на современном устройстве - не проблема. Можно переучивать модель хоть при каждом старте приложения и пользователь это не заметит. А если использовать Checkpoint'ы, то тем более. Так или иначе, нам нужно сохраниться, поэтому привожу код метода saveCheckpoint():
private void saveCheckpoint() {
File outputFile = new File(context.getFilesDir(), courseName + ".ckpt");
Map<String, Object> inputs = new HashMap<>();
inputs.put("checkpoint_path", outputFile.getAbsolutePath());
Map<String, Object> outputs = new HashMap<>();
interpreter.runSignature(inputs, outputs, "save");
}
Метод Inference, который можно сделать публичным
После обучения у нас есть все необходимое, чтобы предоставить метод для предсказания значений, который будет принимать на вход данные из таблицы:
public ArrayList<Float> runInference(ArrayList<MLData> input) throws IOException{
openCheckpoint(); //файловая операция!
float[] floatArray = mlDataToArrays(input).getData();
TensorBuffer tb_in = doNormalize(input.size(), floatArray);
floatArray = doInference(tb_in);
return mlArrayToData(floatArray);
}
private void openCheckpoint() throws IOException{
File outputFile = new File(context.getFilesDir(), courseName + ".ckpt");
if(outputFile.exists()) {
Map<String, Object> inputs = new HashMap<>();
inputs.put("checkpoint_path", outputFile.getAbsolutePath());
Map<String, Object> outputs = new HashMap<>();
interpreter.runSignature(inputs, outputs, "restore");
}
else
throw new IOException("checkpoint not found");
}
Преобразуем данные из ArrayList в массив, нормализуем и запускаем метод doInference, который мы рассмотрели выше. И все, мы получаем предсказание на базе модели, которую мы обучили на телефоне!
Выводы
Машинное обучение для простых моделей вполне можно запускать на мобильных устройствах. Аппаратное ускорение в этом случае может и не понадобиться, не говоря о том, что даже на GPU мне не удалось запустить интерпретатор ни на одном из имеющихся телефонов. При этом доступ к NPU, даже если он сегодня еще и возможен, будет запрещен для TensorFlow Lite начиная с пятнадцатой версии Android. А ведь TensowFlow Lite - единственный официальный инструмент от Google для машинного обучения на Android!
Сколько миллиардов было потрачено на то, чтобы произвести все эти NPU, к которым просто нет доступа, поразительно. Даже на Orange Pi 3, который у меня дома выполняет функции NAS, есть "built-in AI accelerator NPU with 0.8Tops computing power".
Несмотря на то, что документация по TensorFlow Lite устаревшая и неполная, все еще можно написать код для своего приложения. Ну а вам будет легче, ведь теперь на Хабре есть туториал :)