Простая очередь задач в Django, подключение Kandinsky 2.1
Большинство разработчиков рано или поздно сталкивается с необходимостью реализовать очереди выполнения, сложных вычислительных процессов. Изобилие готовых решений, позволяет выбрать именно то, что вам нужно в вашем текущем задании. Моя задача была достаточно простой и не требовала сложного алгоритма регулирования процесса: один GPU, выполняет какую либо ресурсоёмкую функцию последовательно (в этой статье примером будет Kandinsky 2.1), без обработки ошибок с повторным запуском процесса, функция запускаются пользователем в django. Python разработчика, с такой постановкой вопроса, интернет в первую очередь приводит к Celery. Кратко ознакомившись с документацией, подсознательно закрываешь задачу, все готово, за 15 минут подключу. Но в процессе реализации столкнулся с проблемой, существенно влияющей на скорость выполнения процесса. Каждый раз при выполнении функции в Celery, процесс загружает веса модели, в случае Kandinsky 2.1, веса имеют большой объем. Подобная проблема обсуждалась в stackoverflow. Попытки обойти эту преграду в «сельдерей», приводили к новым ошибкам. Один из вариантов решения ведет к этой статье Scaling AllenNLP/PyTorch in Production, вместо Celery автор предлагает использовать ZeroMQ, с этой реализацией буду работать для решения задачи. В статье вычисления выполняются в CPU, вообщем рекомендую к прочтению. ZeroMQ - будет протоколом обмена между ресурсоёмкой функцией и django, для запуска фонового процесса в очереди буду использовать Django Channels, рабочие и фоновые задачи.
Примерный алгоритм:
Пользователь, из web интерфейса отправляет запрос в Django.
Из основного канала Django, переслать полученный запрос пользователя в «другой» фиксированный канал Django.
Полученный запрос из основного канала отправить по TCP ZeroMQ к ресурсоёмкой функции.
После выполнения функции, вернуть ответ по TCP ZeroMQ в «другой» канал Django.
Вернуть ответ пользователю.
Клиентская часть javascript очень простая, не нуждается в дополнительных комментариях. В моем примере, взаимодействие клиент - сервер происходит с помощью WebSocket, благодаря гибкости Django Channels, вы не ограничены в протоколе обмена данными для запуска фоновых задач.
var ws_wall = new WebSocket("ws://"+ IP_ADDR +":"+PORT+"/");
ws_wall.onmessage = function(event) {
// что то делаем с полученным ответом
}
// отправить данные в Django
function send_wall() {
var body = document.getElementById('id_body');
if (body.value == "") {
return false;
}
if (ws_wall.readyState != WebSocket.OPEN) {
return false;
}
var data = JSON.stringify({body: body.value,
event: "wallpost"});
ws_wall.send(data);
}
Серверный кусок кода основного потока Django Channels. Большинство, практикующих разработчиков взаимодействующих с Django, не увидят чего то нового, стандартные фрагменты кода из документации. Подробно рассказывать о структуре orm django модели не имеет смысла, по коду будет понятно, что всё очень просто.
import json
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from myapp.models import User, Post
from asgiref.sync import sync_to_async
from django.utils import dateformat
class WallHandler(AsyncJsonWebsocketConsumer):
async def connect(self):
"""
инициализация подключения
"""
self.room_group_name = "wall"
self.sender_id = self.scope['user'].id
self.sender_name = self.scope['user']
if str(self.scope['user']) != 'AnonymousUser':
self.path_data = self.scope['user'].path_data
await self.channel_layer.group_add(
self.room_group_name,
self.channel_name
)
await self.accept()
async def disconnect(self, close_code):
"""
обработка отключения
покинуть группу
"""
print("error code: ", close_code)
await self.channel_layer.group_discard(
self.room_group_name,
self.channel_name
)
async def receive(self, text_data):
"""
обработка полученных данных от килиента (WebSocket)
получить событие и отправиь соответствующее событие
"""
response = json.loads(text_data)
event = response.get("event", None)
if self.scope['user'].is_authenticated:
if event == "wallpost":
post = Post()
post.body = response["body"]
post.path_data = self.path_data
post.user_post = self.scope['user']
post_async = sync_to_async(post.save) # взаимодействие с синхронным Django
await post_async()
"""
из основного канала Django,
отправить полученный запрос от пользователя
в «другой» фиксированный канал Django
"""
_temp_dict = {}
_temp_dict["body"] = response["body"]
_temp_dict["path_data"] = self.path_data
_temp_dict["post"] = str(post.id)
_temp_dict["type"] = "triggerWorker"
_temp_dict["room_group_name"] = self.room_group_name
await self.channel_layer.send('nnapp', _temp_dict) # синтаксис await не блокирует работу основного потока
"""
сервер продолжает работу,
отправить текущие готовые данные клиенту,
без окончания выполнения ресурсоёмкой функции
"""
_data = {"type": "wallpost",
"timestamp": dateformat.format(post.date_post, 'U'),
"text":response["body"],
"user_post": str(self.sender_name),
"user_id": str(self.sender_id),
"id": str(post.id),
"status" : "wallpost"
}
await self.channel_layer.group_send(self.room_group_name, _data)
async def wallpost(self, res):
"""
отправить сообщение клиенту(WebSocket)
"""
await self.send(text_data=json.dumps(res))
ВАЖНО!!! Дополнительные настройки в конфигурационном файле asgi.py Django проекта.
"""
ASGI config for app project.
It exposes the ASGI callable as a module-level variable named ``application``.
For more information on this file, see
https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/
"""
import os
import django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'app.settings')
django.setup()
from django.core.asgi import get_asgi_application
from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter, ChannelNameRouter
import app.routing
from wall import nnapp
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'app.settings')
application = ProtocolTypeRouter({
"http": get_asgi_application(),
"websocket": AuthMiddlewareStack(
URLRouter(
app.routing.websocket_urlpatterns,
)
),
# добавил «другой» фиксированный канал Django
"channel": ChannelNameRouter({
"nnapp": nnapp.NNHandler.as_asgi(),
}),
})
Создам «другой» канал Django, для обработки фоновой задачи. В новый файл nnapp.py, копирую клиентскую часть из статьи Scaling AllenNLP/PyTorch in Production и изменяю для своих нужд.
from myapp.models import User, Post
from asgiref.sync import async_to_sync
from channels.consumer import SyncConsumer
from channels.layers import get_channel_layer
import uuid
import os
import json
import zlib
import pickle
import zmq
# глобальные переменные из статьи
work_publisher = None
result_subscriber = None
TOPIC = 'snaptravel'
RECEIVE_PORT = 5555
SEND_PORT = 5556
# получить доступ к channel layer
channel_layer = get_channel_layer()
# низкоуровневый протокол подразумивает работу с байтами
def compress(obj):
p = pickle.dumps(obj)
return zlib.compress(p)
def decompress(pickled):
p = zlib.decompress(pickled)
return pickle.loads(p)
def start():
global work_publisher, result_subscriber
context = zmq.Context()
work_publisher = context.socket(zmq.PUB)
work_publisher.connect(f'tcp://127.0.0.1:{SEND_PORT}')
def _parse_recv_for_json(result, topic=TOPIC):
compressed_json = result[len(topic) + 1:]
return decompress(compressed_json)
def send(args, model=None, topic=TOPIC):
id = str(uuid.uuid4())
message = {'body': args["title"], 'model': model, 'id': id}
compressed_message = compress(message)
work_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
return id
def get(id, topic=TOPIC):
context = zmq.Context()
result_subscriber = context.socket(zmq.SUB)
result_subscriber.setsockopt(zmq.SUBSCRIBE, topic.encode('utf8'))
result_subscriber.connect(f'tcp://127.0.0.1:{RECEIVE_PORT}')
result = _parse_recv_for_json(result_subscriber.recv())
while result['id'] != id:
result = _parse_recv_for_json(result_subscriber.recv())
result_subscriber.close()
if result.get('error'):
raise Exception(result['error_msg'])
return result
# эта функция немного отличаеться от оригинала из статьи
def send_and_get(args, model=None):
id = send(args, model=model)
res = get(id)
namefile = f'{id}.jpg'
res['prediction'][0].save(f'media/data_image/{args["path_data"]}/{namefile}', format="JPEG")
post = Post.objects.get(id=args["post"])
post.image = namefile
post.save()
_data = {"type": "wallpost", "status":"Kandinsky-2.1", "path_data": args["path_data"],
"data": f'{namefile}', "post":args["post"]}
# возвращаем результат в основной канал Django Channels
async_to_sync(channel_layer.group_send)(args["room_group_name"], _data)
# запуск фиксированого канала с клиентом ZeroMQ
class NNHandler(SyncConsumer):
start()
def triggerWorker(self, message):
print ("data for a background task: ", message)
send_and_get(message, model='Kandinsky-2.1')
Сервер ZeroMQ с ресурсоёмкой функцией, в моём случае Kadinsky 2.1, по большей части полностью совпадает с кодом статьи, с незначительными дополнениями. Для работы с русским языком, использую обученную модель переводчика Helsinki-NLP/opus-mt-ru-en.
import os, time
from types import SimpleNamespace
import zmq
import zlib
import pickle
import torch.multiprocessing as mp
import threading
import cv2
from kandinsky2 import get_kandinsky2
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import uuid
QUEUE_SIZE = mp.Value('i', 0)
def compress(obj):
p = pickle.dumps(obj)
return zlib.compress(p)
def decompress(pickled):
p = zlib.decompress(pickled)
return pickle.loads(p)
TOPIC = 'snaptravel'
prediction_functions = {}
RECEIVE_PORT = 5556
SEND_PORT = 5555
# «модель» генерация картинки
model = get_kandinsky2('cuda', task_type='text2img', model_version='2.1', use_flash_attention=False)
# «модель» переводчик
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
model_translater = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
def _parse_recv_for_json(result, topic=TOPIC):
compressed_json = result[len(topic) + 1:]
return decompress(compressed_json)
def _decrease_queue():
with QUEUE_SIZE.get_lock():
QUEUE_SIZE.value -= 1
def _increase_queue():
with QUEUE_SIZE.get_lock():
QUEUE_SIZE.value += 1
def send_prediction(message, result_publisher, topic=TOPIC):
_increase_queue()
model_name = message['model']
body = message['body']
id = message['id']
# подготовка входных данных для обученной «модели» переводчик
tokenized_text = tokenizer([str(body).lower()], return_tensors='pt')
# перевод
translation = model_translater.generate(**tokenized_text)
body = tokenizer.batch_decode(translation, skip_special_tokens=True)[0]
print(body)
# генерация изображения
images = model.generate_text2img(
str(body).lower(),
num_steps=70,
batch_size=1,
guidance_scale=4,
h=768, w=768,
sampler='p_sampler',
prior_cf_scale=4,
prior_steps="5"
)
result = {"result": images}
if result.get('result') is None:
time.sleep(1)
compressed_message = compress({'error': True, 'error_msg': 'No result was given: ' + str(result), 'id': id})
result_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
_decrease_queue()
return
prediction = result['result']
compressed_message = compress({'prediction': prediction, 'id': id})
result_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
_decrease_queue()
print ("SERVER", message, f'{topic} '.encode('utf8'))
def queue_size():
return QUEUE_SIZE.value
def load_models():
models = SimpleNamespace()
return models
def start():
global prediction_functions
models = load_models()
prediction_functions = {
'queue': queue_size
}
print(f'Connecting to {RECEIVE_PORT} in server', TOPIC.encode('utf8'))
context = zmq.Context()
work_subscriber = context.socket(zmq.SUB)
work_subscriber.setsockopt(zmq.SUBSCRIBE, TOPIC.encode('utf8'))
work_subscriber.bind(f'tcp://127.0.0.1:{RECEIVE_PORT}')
# send work
print(f'Connecting to {SEND_PORT} in server')
result_publisher = context.socket(zmq.PUB)
result_publisher.bind(f'tcp://127.0.0.1:{SEND_PORT}')
print('Server started')
while True:
message = _parse_recv_for_json(work_subscriber.recv())
threading.Thread(target=send_prediction, args=(message, result_publisher), kwargs={'topic': TOPIC}).start()
if __name__ == '__main__':
start()
Запускать django channels для фоновых задач, отдельным процессом:
python manage.py runworker nnapp
Эти части кода из работающего проекта моей предыдущей статьи, на столько популярной и комментируемой что даже не нуждается в упоминании. Хороший пример Build a Pytorch Server with celery and RabbitMQ, но плохо масштабируемый и немного сложней в реализации чем обсуждаемый в этой статье. Были еще несколько идей, но решил пойти по пути наименьшего сопротивления... Для работы с кластером gpu, нужен более сложные алгоритм отслеживания загруженности каждого отдельного gpu, поэтому это простой пример на основе чужого труда. Работающий код https://github.com/naturalkind/social-network.