Введение
Машинное обучение уже везде и, пожалуй, почти невозможно найти софт, не использующий его прямо или косвенно. Давайте создадим небольшое приложение, способное загружать изображения на сервер для последующего распознавания с помощью ML. А после сделаем их доступными через мобильное приложение с текстовым поиском по содержимому.
Мы будем использовать Flask для нашего REST API, Flutter для мобильного приложения и Keras для машинного обучения. В качестве базы данных для хранения информации о содержимом изображений используем MongoDB, а для получения информации возьмём уже натренированную модель ResNet50. При необходимости мы сможем заменить модель, используя методы save_model() и load_model(), доступные в Keras. Последний потребует около 100 Мб при первоначальной загрузке модели. Почитать о других доступных моделях можно в документации.
Начнём с Flask
Если вы незнакомы с Flask, то создать роут на нём можно просто добавив к контроллеру декоратор app.route('/'), где app — переменная приложения. Пример:
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello_world():
return 'Hello, World!'
При запуске и переходе по дефолтному адресу 127.0.0.1:5000/ мы увидим ответ Hello World! О том, как сделать что-то посложнее, можно почитать в документации.
Приступим же к созданию полноценного бэкенда:
import os
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as img
from keras.preprocessing.image import img_to_array
import numpy as np
from PIL import Image
from keras.applications.resnet50 import ResNet50,decode_predictions,preprocess_input
from datetime import datetime
import io
from flask import Flask,Blueprint,request,render_template,jsonify
from modules.dataBase import collection as db
Как можно заметить импорты содержат tensorflow, который мы будем использовать как бэкенд для keras, а так же numpy для работы с мультиразмерными массивами.
mod = Blueprint('backend', __name__, template_folder='templates', static_folder='./static')
UPLOAD_URL = 'http://192.168.1.103:5000/static/'
model = ResNet50(weights='imagenet')
model._make_predict_function()
На первой строчке мы создаём блюпринт для более удобной организации приложения. Из-за этого надо будет использовать mod.route('/') для декорирования контроллера. Предварительно натренированная на imagenet модель Resnet50 нуждается в вызове _make_predict_function() для инициализации. Без этого шага есть вероятность получить ошибку. А другую модель можно использовать, заменив строку
model = ResNet50(weights='imagenet')
на
model = load_model('saved_model.h5')
Вот как будет выглядеть контроллер:
@mod.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
# проверяем, что прислали файл
if 'file' not in request.files:
return "someting went wrong 1"
user_file = request.files['file']
temp = request.files['file']
if user_file.filename == '':
return "file name not found ..."
else:
path = os.path.join(os.getcwd()+'\\modules\\static\\'+user_file.filename)
user_file.save(path)
classes = identifyImage(path)
db.addNewImage(
user_file.filename,
classes[0][0][1],
str(classes[0][0][2]),
datetime.now(),
UPLOAD_URL+user_file.filename)
return jsonify({
"status":"success",
"prediction":classes[0][0][1],
"confidence":str(classes[0][0][2]),
"upload_time":datetime.now()
})
В коде выше загруженное изображение передаётся в метод identifyImage(file_path), который реализован так:
def identifyImage(img_path):
image = img.load_img(img_path, target_size=(224,224))
x = img_to_array(image)
x = np.expand_dims(x, axis=0)
# images = np.vstack([x])
x = preprocess_input(x)
preds = model.predict(x)
preds = decode_predictions(preds, top=1)
print(preds)
return preds
Сначала мы преобразуем изображение к размеру 224*224, т.к. именно он нужен нашей модели. Затем передаём в model.predict() предварительно обработанные байты изображения. Теперь наша модель может предсказать, что находится на изображении (top=1 нужен чтобы получить единственный самый вероятный результат).
Сохраним полученные данные о содержимом изображения в MongoDB с помощью функции db.addData(). Вот релевантная часть кода:
from pymongo import MongoClient
from bson import ObjectId
client = MongoClient("mongodb://localhost:27017") # host uri
db = client.image_predition #Select the database
image_details = db.imageData
def addNewImage(i_name, prediction, conf, time, url):
image_details.insert({
"file_name":i_name,
"prediction":prediction,
"confidence":conf,
"upload_time":time,
"url":url
})
def getAllImages():
data = image_details.find()
return data
Так как мы использовали блюпринт, код для API можно разместить в отдельном файле:
from flask import Flask,render_template,jsonify,Blueprint
mod = Blueprint('api',__name__,template_folder='templates')
from modules.dataBase import collection as db
from bson.json_util import dumps
@mod.route('/')
def api():
return dumps(db.getAllImages())
Как можно заметить, для возвращения данных БД мы используем json. Посмотреть на результат можно по адресу 127.0.0.1:5000/api
Выше, разумеется, только самые важные куски кода. Полностью проект можно посмотреть в GitHub репозитории. А больше о Pymongo можно почитать здесь.
Создаём приложение Flutter
Мобильная версия будет получать изображения и данные об их содержимом по REST API. Вот что получится в итоге:
ImageData класс инкапсулирует данные об изображении:
import 'dart:convert';
import 'package:http/http.dart' as http;
import 'dart:async';
class ImageData
{
// static String BASE_URL ='http://192.168.1.103:5000/';
String uri;
String prediction;
ImageData(this.uri,this.prediction);
}
Future<List<ImageData>> LoadImages() async
{
List<ImageData> list;
//complete fetch ....
var data = await http.get('http://192.168.1.103:5000/api/');
var jsondata = json.decode(data.body);
List<ImageData> newslist = [];
for (var data in jsondata) {
ImageData n = ImageData(data['url'],data['prediction']);
newslist.add(n);
}
return newslist;
}
Здесь мы получаем json, преобразуем его в список объектов ImageData и возвращаем во Future Builder с помощью функции LoadImages()
Загрузка изображений на сервер
uploadImageToServer(File imageFile) async {
print("attempting to connecto server......");
var stream =
new http.ByteStream(DelegatingStream.typed(imageFile.openRead()));
var length = await imageFile.length();
print(length);
var uri = Uri.parse('http://192.168.1.103:5000/predict');
print("connection established.");
var request = new http.MultipartRequest("POST", uri);
var multipartFile = new http.MultipartFile('file', stream, length,
filename: basename(imageFile.path));
//contentType: new MediaType('image', 'png'));
request.files.add(multipartFile);
var response = await request.send();
print(response.statusCode);
}
Чтобы сделать Flask доступным в локальной сети отключите режим дебага и найдите ipv4 адрес, используя ipconfig. Запустить локальный сервер можно так:
app.run(debug=False, host='192.168.1.103', port=5000)
Иногда файрвол может мешать приложению обращаться к локалхосту, тогда его придётся перенастроить или отключить.
Весь исходный код приложения доступен на гитхабе. Вот ссылки, которые помогут разобраться в происходящем:
Keras : https://keras.io/
Flutter : https://flutter.dev/
MongoDB : https://www.tutorialspoint.com/mongodb/
Курс Гарварда по Python и Flask: https://www.youtube.com/watch?v=j5wysXqaIV8&t=5515s (особенно важны лекции 2,3,4)
GitHub : https://github.com/SHARONZACHARIA