Pull to refresh

Хайп вокруг аппаратного ускорения ИИ и реальная ситуация. Обучение модели на телефоне и результаты в миллисекундах

Level of difficultyEasy
Reading time8 min
Views3.4K

Сегодня в ленте была очередная порция хайпа про ИИ. Смешно читать про «аппаратное ускорение AI» на пользовательских устройствах. Автор, вы сами попробуйте добраться до этого аппаратного ускорения, и если найдете как — напишите статью. А то элементарная попытка использования GPU для работы TensorFlow Lite приводит только к потерянному времени, а ускорители NPU больше не поддерживаются именно там, где должны были бы. То есть за хайпом вокруг «аппаратного ускорения ИИ» производители создали новую категорию устройств, и теперь стандартно ноутбук будет стоить в 2 раза больше, чем было раньше. А по факту пользоваться этим ускорением будут только компании‑производители, чтобы еще больше заработать денег на пользователях через рекламу, «правильные» модели и торговлю персональными данными.

А мы сегодня запустим TensorFlow Lite на устройствах разного класса и года выпуска и посмотрим, что там с производительностью и ускорением.

Для начала берем модель, подготовленную на предыдущем шаге, и копируем ее в ассеты проекта Android Studio. Можно взять модель с сохраненными весами после обучения (ruLearnModel.tflite), можно сохранить необученную модель (в моем случае ruLearnModel_noWeights.tflite). А можно сделать и то, и другое:

Все как обычно в программировании под Android!
Все как обычно в программировании под Android!

Дальше находим на гитхабе пример с обучением модели на устройстве и копируем оттуда зависимости для нашего проекта. Я посчитал по опыту предыдущей статьи, что лучше использовать старые версии библиотек, которые работают, чем новые, для которых вообще непонятно как писать код.

    // Tensorflow lite dependencies
    implementation 'org.tensorflow:tensorflow-lite:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.4.2'
    implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.9.0'

Напомню, что текущая версия TensorFlow 2.17. Наша модель сделана на Python для версии 2.8.

Теперь по архитектуре. Создадим класс-singleton TFLiteHelper, в котором главное действующее лицо - это "interpreter" - экземпляр работающей модели TensorFlow Lite. Он может быть для моего приложения в единственном экземпляре и будет выполнять разные функции - предсказывать или учиться.

    private TFLiteHelper(Context context, String courseName) {
        this.context = context;
        this.courseName = courseName;
        interpreter = new Interpreter(loadModelfile(context));
    }

    public static TFLiteHelper getInstance(Context context, String courseName) {
        if (instance == null)
            instance = new TFLiteHelper(context, courseName);
        return instance;
    }

Передаем в конструктор context и название курса и запоминаем их для дальшейшего использования. Создаем новый интерпретатор, загружая модель из ассетов:

    private MappedByteBuffer loadModelfile(Context context) {
        try {
            AssetFileDescriptor fileDescriptor = context.getAssets().openFd("ruLearnModel_noWeights.tflite");
            FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
            FileChannel fileChannel = inputStream.getChannel();
            long startOffset = fileDescriptor.getStartOffset();
            long declaredLength = fileDescriptor.getDeclaredLength();
            return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
        } catch (IOException e) {
            //добавим обработку исключений позже
            //в какой ситуации может не найтись модель, которая была в ассетах во время компиляции?
            //это был риторический вопрос
        }
    }

Интерпретатор загружен. Что делать дальше? Можно запустить предсказание, для этого нужны 1) веса модели (они есть, если мы сохранили модель с весами) 2) нормализация (мы можем взять пример уже нормализованных значений из предыдущей статьи. И тогда этот код будет работать:

   private float[] doInference(TensorBuffer tb) {
        int tbSize = tb.getShape()[0];
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("x", tb.getFloatArray());
        Map<String, Object> outputs = new HashMap<>();
        FloatBuffer output = FloatBuffer.allocate(tbSize);
        outputs.put("output_0", output);
        long millis = System.currentTimeMillis();
        interpreter.runSignature(inputs, outputs, "infer");
        long xxx = System.currentTimeMillis() - millis;
        System.out.println("model ran for: " + xxx + " milliseconds");
        FloatBuffer buffer = (FloatBuffer) outputs.get("output_0");
        float[] outPutArray = buffer.array();
        return outPutArray;
    }

На входе нужен TensorBuffer - внутренний объект для TF Lite. Создается он так:

TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(new int[]{numRowCount, 4}, DataType.FLOAT32);
float[] featureArray = new float[]{
       -1.2826425, 0.7050209, -0.6746891, 1.5294912
};
tensorBuffer.loadArray(featureArray);

где new int[]{numRowCount, 4} - это форма буфера. То есть он будет из произвольного количества рядов с 4 колонками. DataType.FLOAT32 - это чуть ли не единственный тип данных, нормально поддерживаемый TF Lite. В буфер надо загрузить одномерный массив, и tensorBuffer сам переупакует его в соответствии с заданной формой.

Для выходных значений нужно тоже подготовить буфер. В нашем случае для каждого ряда параметров на выходе будет одно значение float. Поэтому буфер должен быть равен <число рядов * размер float32>. Число рядов равно tb.getShape()[0] - первому элементу формы буфера. Дальше идет загадочная история с упаковкой буферов в Map. Это, по всей видимости, нужно для передачи именованных параметров, как было в Python (x=test_data):

y_lite = infer(x=test_data)
y_original = m.infer(x=test_data)

То есть в нашем случае мы записываем имя параметра в ключ Map, а данные - в значение.

Засекаем системное время, запускаем интерпретатор с функцией infer и смотрим, сколько ушло на предсказание. В статье, где мы первый раз запускали модель, на предсказание всей таблицы в цикле по одному ряду ушло 167 миллисекунд. Здесь у нас есть возможность сделать это за один подход и результат равен 11 миллисекунд на том же устройстве.

Обучение на устройстве и вспомогательные функции

Для обучения лучше взять пустую модель, без весов. На самом деле, данных у нас не прибавилось, поэтому учить натренированную модель бессмысленно. Но вначале нужно заняться нормализацией, то есть пересчитать данные по формуле:

normalizedInput = (input - mean) / sqrt(var)

Для этого можно использовать следующий код:

    private TensorBuffer doNormalize(int numRowCount, float[] mlArray) throws IOException {
        TensorBuffer notNormalizedTensorBuffer = TensorBuffer.createFixedSize(new int[]{numRowCount, 4}, DataType.FLOAT32);
        TensorBuffer normalizedTensorBuffer; //объявлять форму не надо, буфер вернет TensorProcessor
        notNormalizedTensorBuffer.loadArray(mlArray);
        float[] mean = new float[4];
        float[] stddev = new float[4];
        getMeansAndStddev(mean, stddev); //внутри метода файловая операция
        TensorProcessor processor = new TensorProcessor.Builder().add(new NormalizeOp(mean, stddev)).build();
        normalizedTensorBuffer = processor.process(notNormalizedTensorBuffer);
        return normalizedTensorBuffer;
    }

Метод getMeansAndStddev(mean, stddev) возвращает через параметры средние значения и стандартные отклонения для датасета. У нас 4 параметра в модели, поэтому 4 средних по каждой колонке значений и 4 отклонения. Можно взять его из предыдущей статьи или посчитать самостоятельно самому или применить библиотечные функции из Apache Commons Statistics. В моем случае я читаю эти значения из файла, поэтому throws IOException.

Давайте посмотрим, как выглядит код тренировки модели:

    public void doTrain(ArrayList<MLData> input, int num_epoch) throws IOException{
        float[] dataArray = mlDataToArrays(input).getData();
        float[] resultArray = mlDataToArrays(input).getResults();
        calculateMeansAndStandardDeviation(input); //файловые операции
        TensorBuffer tb_in = doNormalize(input.size(), dataArray);
        float[][] dataArray2dim = new float[][]{tb_in.getFloatArray()};
        float[][] resultArray2dim = new float[][]{resultArray};
        long millis = System.currentTimeMillis();
        for (int i = 0; i < num_epoch; i++) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("x", dataArray2dim);
            inputs.put("y", resultArray2dim);
            Map<String, Object> outputs = new HashMap<>();
            FloatBuffer loss = FloatBuffer.allocate(1);
            outputs.put("loss", loss);
            interpreter.runSignature(inputs, outputs, "train");
        }
        System.out.println("training took " + (System.currentTimeMillis() - millis) + " ms");
        saveCheckpoint();
    }

Начинается все с преобразования датасета, приходящего в метод из базы SQLite в форме ArrayList<упакованные в объект значения из таблицы с данными, на которых строилась модель>. Для работы TensorFlow нужны буферы или массивы, поэтому я переупаковываю данные соответственно в массивы для значений и результатов, на которых модель будет учиться. Дальше в функции calculateMeansAndStandardDeviation(input) по всему датасету считаю средние значения и стандартное отклонение, записываю эти значения в файл. Дальше почему-то данные и результаты надо переупаковать в двумерные массивы. Почему - непонятно и на самом деле код не работал, пока я не вспомнил, что и в Python было именно так:

test_data = np.array([[-1.2826425, 0.7050209, -0.6746891, 1.5294912],
                      [2.559783, -0.18800573,  0.62913984, -0.8701883 ],
                      [ 2.560703,   -0.18800573,  1.672203,   -0.87582266],
                      [-1.2808022, 0.7050209,  -0.6746891, 0.24507335]],dtype='float32')

Без этого преобразования происходит одна из ошибок TensorFlow с "очень полезными" описаниями типа "TypeError: src data type = 17 is not supported". В данном случае такая:

И на этом все!
И на этом все!

Дальше как обычно, упаковка в Map и запуск соответствующей функции interpreter.runSignature(inputs, outputs, "train"). Заодно засекаем время. Что с результатами? (это все под дебагом, в реальности будет намного быстрее)

  • Телефон Xiaomi Mi 1 2017 года, 60 000 Antutu: 2000 мс

  • Телефон Xiaomi 11 Lite 5G 2020 года, 300 000 Antutu: 400 мс

  • Телефон Samsung Galaxy S22 2022 года, 1 000 000 Antutu: 265 мс.

Очевидно, что обучение с нуля такой модели на современном устройстве - не проблема. Можно переучивать модель хоть при каждом старте приложения и пользователь это не заметит. А если использовать Checkpoint'ы, то тем более. Так или иначе, нам нужно сохраниться, поэтому привожу код метода saveCheckpoint():

    private void saveCheckpoint() {
        File outputFile = new File(context.getFilesDir(), courseName + ".ckpt");
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("checkpoint_path", outputFile.getAbsolutePath());
        Map<String, Object> outputs = new HashMap<>();
        interpreter.runSignature(inputs, outputs, "save");
    }

Метод Inference, который можно сделать публичным

После обучения у нас есть все необходимое, чтобы предоставить метод для предсказания значений, который будет принимать на вход данные из таблицы:

    public ArrayList<Float> runInference(ArrayList<MLData> input) throws IOException{
        openCheckpoint(); //файловая операция!
        float[] floatArray = mlDataToArrays(input).getData();
        TensorBuffer tb_in = doNormalize(input.size(), floatArray);
        floatArray = doInference(tb_in);
        return mlArrayToData(floatArray);
    }

    private void openCheckpoint() throws IOException{
        File outputFile = new File(context.getFilesDir(), courseName + ".ckpt");
        if(outputFile.exists()) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("checkpoint_path", outputFile.getAbsolutePath());
            Map<String, Object> outputs = new HashMap<>();
            interpreter.runSignature(inputs, outputs, "restore");
        }
        else
            throw new IOException("checkpoint not found");
    }

Преобразуем данные из ArrayList в массив, нормализуем и запускаем метод doInference, который мы рассмотрели выше. И все, мы получаем предсказание на базе модели, которую мы обучили на телефоне!

Выводы

Машинное обучение для простых моделей вполне можно запускать на мобильных устройствах. Аппаратное ускорение в этом случае может и не понадобиться, не говоря о том, что даже на GPU мне не удалось запустить интерпретатор ни на одном из имеющихся телефонов. При этом доступ к NPU, даже если он сегодня еще и возможен, будет запрещен для TensorFlow Lite начиная с пятнадцатой версии Android. А ведь TensowFlow Lite - единственный официальный инструмент от Google для машинного обучения на Android!

Сколько миллиардов было потрачено на то, чтобы произвести все эти NPU, к которым просто нет доступа, поразительно. Даже на Orange Pi 3, который у меня дома выполняет функции NAS, есть "built-in AI accelerator NPU with 0.8Tops computing power".

Несмотря на то, что документация по TensorFlow Lite устаревшая и неполная, все еще можно написать код для своего приложения. Ну а вам будет легче, ведь теперь на Хабре есть туториал :)

Tags:
Hubs:
Total votes 8: ↑8 and ↓0+11
Comments0

Articles