Деплоим ML проект, используя Flask как REST API, и делаем доступным через приложение на Flutter

Автор оригинала: SHARON ZACHARIA
  • Перевод
  • Tutorial


Введение


Машинное обучение уже везде и, пожалуй, почти невозможно найти софт, не использующий его прямо или косвенно. Давайте создадим небольшое приложение, способное загружать изображения на сервер для последующего распознавания с помощью ML. А после сделаем их доступными через мобильное приложение с текстовым поиском по содержимому.


Мы будем использовать Flask для нашего REST API, Flutter для мобильного приложения и Keras для машинного обучения. В качестве базы данных для хранения информации о содержимом изображений используем MongoDB, а для получения информации возьмём уже натренированную модель ResNet50. При необходимости мы сможем заменить модель, используя методы save_model() и load_model(), доступные в Keras. Последний потребует около 100 Мб при первоначальной загрузке модели. Почитать о других доступных моделях можно в документации.


Начнём с Flask


Если вы незнакомы с Flask, то создать роут на нём можно просто добавив к контроллеру декоратор app.route('/'), где app — переменная приложения. Пример:


from flask import Flask

app = Flask(__name__)

@app.route('/') 
def hello_world():
    return 'Hello, World!'

При запуске и переходе по дефолтному адресу 127.0.0.1:5000/ мы увидим ответ Hello World! О том, как сделать что-то посложнее, можно почитать в документации.


Приступим же к созданию полноценного бэкенда:


import os
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as img
from keras.preprocessing.image import img_to_array
import numpy as np
from PIL import Image
from keras.applications.resnet50 import ResNet50,decode_predictions,preprocess_input
from datetime import datetime
import io
from flask import Flask,Blueprint,request,render_template,jsonify
from modules.dataBase import collection as db

Как можно заметить импорты содержат tensorflow, который мы будем использовать как бэкенд для keras, а так же numpy для работы с мультиразмерными массивами.


mod = Blueprint('backend', __name__, template_folder='templates', static_folder='./static')
UPLOAD_URL = 'http://192.168.1.103:5000/static/'
model = ResNet50(weights='imagenet')
model._make_predict_function()

На первой строчке мы создаём блюпринт для более удобной организации приложения. Из-за этого надо будет использовать mod.route('/') для декорирования контроллера. Предварительно натренированная на imagenet модель Resnet50 нуждается в вызове _make_predict_function() для инициализации. Без этого шага есть вероятность получить ошибку. А другую модель можно использовать, заменив строку


model = ResNet50(weights='imagenet')

на 


model = load_model('saved_model.h5')

Вот как будет выглядеть контроллер:


@mod.route('/predict', methods=['POST'])
def predict():  
     if request.method == 'POST':
        # проверяем, что прислали файл
        if 'file' not in request.files:
           return "someting went wrong 1"
      
        user_file = request.files['file']
        temp = request.files['file']
        if user_file.filename == '':
            return "file name not found ..." 
       
        else:
            path = os.path.join(os.getcwd()+'\\modules\\static\\'+user_file.filename)
            user_file.save(path)
            classes = identifyImage(path)
            db.addNewImage(
                user_file.filename,
                classes[0][0][1],
                str(classes[0][0][2]),
                datetime.now(),
                UPLOAD_URL+user_file.filename)

            return jsonify({
                "status":"success",
                "prediction":classes[0][0][1],
                "confidence":str(classes[0][0][2]),
                "upload_time":datetime.now()
                })

В коде выше загруженное изображение передаётся в метод identifyImage(file_path), который реализован так:


def identifyImage(img_path):   
    image = img.load_img(img_path, target_size=(224,224))
    x = img_to_array(image)
    x = np.expand_dims(x, axis=0)
    # images = np.vstack([x])
    x = preprocess_input(x)
    preds = model.predict(x)
    preds = decode_predictions(preds, top=1)
    print(preds)
    return preds

Сначала мы преобразуем изображение к размеру 224*224, т.к. именно он нужен нашей модели. Затем передаём в model.predict() предварительно обработанные байты изображения. Теперь наша модель может предсказать, что находится на изображении (top=1 нужен чтобы получить единственный самый вероятный результат).


Сохраним полученные данные о содержимом изображения в MongoDB с помощью функции db.addData(). Вот релевантная часть кода:


from pymongo import MongoClient
from bson import ObjectId 

client = MongoClient("mongodb://localhost:27017")  # host uri  
db = client.image_predition #Select the database  
image_details = db.imageData

def addNewImage(i_name, prediction, conf, time, url):
    image_details.insert({
        "file_name":i_name,
        "prediction":prediction,
        "confidence":conf,
        "upload_time":time,
        "url":url
    })
    
def getAllImages():
    data = image_details.find()
    return data

Так как мы использовали блюпринт, код для API можно разместить в отдельном файле:


from flask import Flask,render_template,jsonify,Blueprint
mod = Blueprint('api',__name__,template_folder='templates')
from modules.dataBase import collection as db
from bson.json_util import dumps

@mod.route('/')
def api():
    return dumps(db.getAllImages())

Как можно заметить, для возвращения данных БД мы используем json. Посмотреть на результат можно по адресу 127.0.0.1:5000/api


Выше, разумеется, только самые важные куски кода. Полностью проект можно посмотреть в GitHub репозитории. А больше о Pymongo можно почитать здесь.


Создаём приложение Flutter


Мобильная версия будет получать изображения и данные об их содержимом по REST API. Вот что получится в итоге:



ImageData класс инкапсулирует данные об изображении:


import 'dart:convert';
import 'package:http/http.dart' as http;
import 'dart:async';

class ImageData
{
//  static String  BASE_URL ='http://192.168.1.103:5000/';
  String uri;
  String prediction;
  ImageData(this.uri,this.prediction);
}

Future<List<ImageData>> LoadImages() async
{
  List<ImageData> list;
  //complete fetch ....
  var data = await http.get('http://192.168.1.103:5000/api/');
  var jsondata = json.decode(data.body);
  List<ImageData> newslist = [];

  for (var data in jsondata) {
    ImageData n = ImageData(data['url'],data['prediction']);
    newslist.add(n);
  }

  return newslist; 
}

Здесь мы получаем json, преобразуем его в список объектов ImageData и возвращаем во Future Builder с помощью функции LoadImages()


Загрузка изображений на сервер


uploadImageToServer(File imageFile) async {
  print("attempting to connecto server......");
  var stream =
      new http.ByteStream(DelegatingStream.typed(imageFile.openRead()));
  var length = await imageFile.length();
  print(length);

  var uri = Uri.parse('http://192.168.1.103:5000/predict');
  print("connection established.");
  var request = new http.MultipartRequest("POST", uri);
  var multipartFile = new http.MultipartFile('file', stream, length,
      filename: basename(imageFile.path));
  //contentType: new MediaType('image', 'png'));

  request.files.add(multipartFile);
  var response = await request.send();
  print(response.statusCode);
}

Чтобы сделать Flask доступным в локальной сети отключите режим дебага и найдите ipv4 адрес, используя ipconfig. Запустить локальный сервер можно так:


app.run(debug=False, host='192.168.1.103', port=5000)

Иногда файрвол может мешать приложению обращаться к локалхосту, тогда его придётся перенастроить или отключить.




Весь исходный код приложения доступен на гитхабе. Вот ссылки, которые помогут разобраться в происходящем:


Keras : https://keras.io/


Flutter : https://flutter.dev/


MongoDB : https://www.tutorialspoint.com/mongodb/


Курс Гарварда по Python и Flask: https://www.youtube.com/watch?v=j5wysXqaIV8&t=5515s (особенно важны лекции 2,3,4)


GitHub : https://github.com/SHARONZACHARIA

Похожие публикации

AdBlock похитил этот баннер, но баннеры не зубы — отрастут

Подробнее
Реклама

Комментарии 5

    0
    Пожалуйста, поправьте форматирование для кода мобильного приложения. Если не трудно)
      0
      Готово!
        0
        Благодарю!
      +2
      Шутки про индусский код еще актуальны?
      Кое-кому не помешало бы узнать про .gitignore
        +1
        Хорошая статья. И кстати забавно, месяц назад с ребятами тоже сделали подобное приложение, из любопытства. Также использовали Python, Flask, MongoDB, Keras с TensorFlow. Ну и всё это завернули в докер (контейнер с приложением и контейнер с БД). Правда обычная UI, выдающая результаты совпадений и сохраняющая сессию по ip. Готовые архитектуры и веса требуют еще долгой доработки, так как достаточно криво работают. ИМХО.

        Только полноправные пользователи могут оставлять комментарии. Войдите, пожалуйста.

        Самое читаемое