Pull to refresh

Как прикрутить нейросеть к сайту по-быстрому

Reading time7 min
Views15K


В данном материале предлагается, приложив небольшие усилия, соединить python 3.7+flask+tensorflow 2.0+keras+небольшие вкрапления js и вывести на web-страницу определенный интерактив. Пользователь, рисуя на холсте, будет отправлять на распознавание цифры, а ранее обученная модель, использующая архитектуру CNN, будет распознавать полученный рисунок и выводить результат. Модель обучена на известном наборе рукописных цифр MNIST, поэтому и распознавать будет только цифры от 0 до 9 включительно. В качестве системы, на которой все это будет крутиться, используется windows 7.

Небольшое вступление


Чем печальны книги по машинному обучению, так, пожалуй, тем, что код устаревает почти с выходом самой книги. И хорошо, если автор издания поддерживает свое дитя, сопровождая и обновляя код, но, зачастую все ограничивается тем, что пишут — вот вам requirements.txt, ставьте устаревшие пакеты, и все заработает.

Так вышло и в этот раз. Читая «Hands-On Python Deep Learning for the Web» авторства Anubhav Singh, Sayak Paul, сначала все шло хорошо. Однако, после первой главы праздник закончился. Самое неприятное было то, что заявленные требования в requirements в целом соблюдались.

Масло в огонь подлили и сами разработчики пакетов tensorflow и keras. Один пакет работает только с определенным другим и, либо даунгрейд одного из них либо бубен шамана.
Но и это еще не все. Оказывается, что некоторые пакеты еще и зависимы от архитектуры используемого железа!

Так, за неимением алтернативы железа, устанавливался tensorflow 2.0 на платформу с Celeron j1900 и, как оказалось, там нет инструкции AVX2:


И вариант через pip install tensorflow не работал.

Но не все так грустно при наличии желания и интернета!

Вариант с tensorflow 2.0 удалось реализовать через wheel — github.com/fo40225/tensorflow-windows-wheel/tree/master/2.0.0/py37/CPU/sse2 и установку x86: vc_redist.x86.exe, x64: vc_redist.x64.exe (https://support.microsoft.com/en-us/help/2977003/the-latest-supported-visual-c-downloads).

Keras был установлен с минимальной версией, с которой он «стал совместим» с tensorflow — Keras==2.3.0.

Поэтому

pip install tensorflow-2.0.0-cp37-cp37m-win_amd64.whl

и

pip install keras==2.3.0

Основное приложение


Рассмотрим код основной программы.

flask_app.py

#code work with scipy==1.6.1, tensorflow @ file:///D:/python64/tensorflow-2.0.0-cp37-cp37m-win_amd64.whl,
#Keras==2.3.0

from flask import Flask, render_template, request
import imageio
#https://imageio.readthedocs.io/en/stable/examples.html
#from scipy.misc import imread, imresize
#from matplotlib.pyplot import imread
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import model_from_json
from skimage import transform,io

json_file = open('model.json','r')
model_json = json_file.read()
json_file.close()
model = model_from_json(model_json)
model.load_weights("weights.h5")
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
#graph = tf.get_default_graph()
graph = tf.compat.v1.get_default_graph()

app = Flask(__name__)

@app.route('/')
def index():
    return render_template("index.html")
import re
import base64

def convertImage(imgData1):
    imgstr = re.search(r'base64,(.*)', str(imgData1)).group(1)
    with open('output.png', 'wb') as output:
        output.write(base64.b64decode(imgstr))

@app.route('/predict/', methods=['GET', 'POST'])
def predict():
    global model, graph
    
    imgData = request.get_data()
    convertImage(imgData)
    #print(imgData)
   
    #x = imread('output.png', mode='L')
    #x.shape
    #(280, 280)
    x = imageio.imread('output.png',pilmode='L')
    #x = imresize(x, (28, 28))
    #x = x.resize(x, (28, 28))
    x = transform.resize(x, (28,28), mode='symmetric', preserve_range=True)
    #(28, 28)
    #type(x)
    #<class 'numpy.ndarray'>

    x = x.reshape(1, 28, 28, 1)
    #(1, 28, 28, 1) 
    x = tf.cast(x, tf.float32)
    
    # perform the prediction
    out = model.predict(x)        
    #print(np.argmax(out, axis=1))
    # convert the response to a string
    response = np.argmax(out, axis=1)
    return str(response[0])

if __name__ == "__main__":
    # run the app locally on the given port
    app.run(host='0.0.0.0', port=80)
# optional if we want to run in debugging mode
    app.run(debug=True)



Подгрузили пакеты:

from flask import Flask, render_template, request
import imageio
#https://imageio.readthedocs.io/en/stable/examples.html
#from scipy.misc import imread, imresize
#from matplotlib.pyplot import imread
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import model_from_json
from skimage import transform,io

Как выяснилось imread, imresize устарели еще со времен scipy==1.0. Непонятно, как у автора все работало, учитывая, что книга относительно нова (2019). С современной scipy==1.6.1 книжный вариант кода не работал.

Загружаем с диска, компилируем модель нейросети:


json_file = open('model.json','r')
model_json = json_file.read()
json_file.close()
model = model_from_json(model_json)
model.load_weights("weights.h5")
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
#graph = tf.get_default_graph()
graph = tf.compat.v1.get_default_graph()

Здесь произведена замена на tf.compat.v1.get_default_graph() в виду несовместимости.

Далее часть, относящаяся к серверу на flask. «Прорисовка» шаблона страницы:


@app.route('/')
def index():
    return render_template("index.html")

Часть, преобразующая картинку в числовой массив:


import re
import base64

def convertImage(imgData1):
    imgstr = re.search(r'base64,(.*)', str(imgData1)).group(1)
    with open('output.png', 'wb') as output:
        output.write(base64.b64decode(imgstr))

Основная функция предсказания:


def predict():
    global model, graph
    
    imgData = request.get_data()
    convertImage(imgData)
    #print(imgData)
   
    #x = imread('output.png', mode='L')
    #x.shape
    #(280, 280)
    x = imageio.imread('output.png',pilmode='L')
    #x = imresize(x, (28, 28))
    #x = x.resize(x, (28, 28))
    x = transform.resize(x, (28,28), mode='symmetric', preserve_range=True)
    #(28, 28)
    #type(x)
    #<class 'numpy.ndarray'>

    x = x.reshape(1, 28, 28, 1)
    #(1, 28, 28, 1) 
    x = tf.cast(x, tf.float32)
    
    # perform the prediction
    out = model.predict(x)        
    #print(np.argmax(out, axis=1))
    # convert the response to a string
    response = np.argmax(out, axis=1)
    return str(response[0])

Закоментированы строки, которые были заменены на рабочие, а также оставлены выводы отдельных строк для наглядности.

Как все работает


После запуска командой python flask_app.py запускается локальный flask-сервер, который выводит index.html с вкраплением js.

Пользователь рисует на холсте цифру, нажимает «predict». Картинка «улетает» на сервер, где сохраняется и преобразуется в цифровой массив. Далее в бой вступает CNN, распознающая цифру и возвращающая ответ в виде цифры.

Сеть не всегда дает верный ответ, т.к. обучалась всего на 10 эпохах. Это можно наблюдать, если нарисовать «спорную» цифру, которая может трактоваться по-разному.

*Можно покрутить слайдер, увеличивая или уменьшая толщину начертания цифры для целей распознавания.

Второй вариант программы — через API,curl


Поользователь загружает на сервер свое изображение с цифрой для распознавания и нажимает «отправить»:



Заменим index.js на следующий:

index.js:
$("form").submit(function(evt){
	evt.preventDefault();
	var formData = new FormData($(this)[0]);
	$.ajax({
		url: '/predict/',
		type: 'POST',
		data: formData,
		async: false,
		cache: false,
		contentType: false,
		enctype: 'multipart/form-data',
		processData: false,
		success: function (response) {
			$('#result').empty().append(response);
		}
	});
	return false;
});


Шаблон страницы также изменится:

index.html
<!DOCTYPE html>
<html lang="en">
<head>
<title>MNIST CNN</title>
</head>
<body>
<h1>MNIST Handwritten Digits Prediction</h1>
<form>
<input type="file" name="img"></input>
<input type="submit"></input>
</form>
<hr>
<h3>Prediction: <span id="result"></span></h3>
<script
src='https://code.jquery.com/jquery-3.6.0.min.js'></script>
<script src="{{ url_for('static',filename='index.js') }}"></script>
</body>
</html>


Немного изменится и основная программа:

flask_app2.py

#code work with scipy==1.6.1, tensorflow @ file:///D:/python64/tensorflow-2.0.0-cp37-cp37m-win_amd64.whl,
#Keras==2.3.0

from flask import Flask, render_template, request
import imageio
#https://imageio.readthedocs.io/en/stable/examples.html
#from scipy.misc import imread, imresize
#from matplotlib.pyplot import imread
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import model_from_json
from skimage import transform,io


json_file = open('model.json','r')
model_json = json_file.read()
json_file.close()
model = model_from_json(model_json)
model.load_weights("weights.h5")
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
#graph = tf.get_default_graph()
graph = tf.compat.v1.get_default_graph()

app = Flask(__name__)

@app.route('/')
def index():
    return render_template("index.html")

import re
import base64

def convertImage(imgData1):
    imgstr = re.search(r'base64,(.*)', str(imgData1)).group(1)
    with open('output.png', 'wb') as output:
        output.write(base64.b64decode(imgstr))

@app.route('/predict/', methods=['POST'])
def predict():
    global model, graph
    
    imgData = request.get_data()
    try:
        stringToImage(imgData)
    except:
        f = request.files['img']
        f.save('image.png')
       
    #x = imread('output.png', mode='L')
    #x.shape
    #(280, 280)
    x = imageio.imread('image.png',pilmode='L')
    #x = imresize(x, (28, 28))
    #x = x.resize(x, (28, 28))
    x = transform.resize(x, (28,28), mode='symmetric', preserve_range=True)
    #(28, 28)
    #type(x)
    #<class 'numpy.ndarray'>

    x = x.reshape(1, 28, 28, 1)
    #(1, 28, 28, 1) 
    x = tf.cast(x, tf.float32)
    
    # perform the prediction
    out = model.predict(x)        
    #print(np.argmax(out, axis=1))
    # convert the response to a string
    response = np.argmax(out, axis=1)
    return str(response[0])

if __name__ == "__main__":

    # run the app locally on the given port
    app.run(host='0.0.0.0', port=80)
# optional if we want to run in debugging mode
    app.run(debug=True)



Запускается все похоже — python flask_app2.py

Вариант с curl (для windows)


Скачиваем curl

В командной строке windows отправляем команду:


curl -X POST -F img=@1.png http://localhost/predict/

где 1.png — картинка с цифрой (или она же с путем к ней).
В ответ прилетит распознанная цифра.

Файлы для скачивания — скачать.
Tags:
Hubs:
Total votes 4: ↑4 and ↓0+4
Comments6

Articles