Как стать автором
Обновить

Простая очередь задач в Django, подключение Kandinsky 2.1

Время на прочтение8 мин
Количество просмотров3.7K

Большинство разработчиков рано или поздно сталкивается с необходимостью реализовать очереди выполнения, сложных вычислительных процессов. Изобилие готовых решений, позволяет выбрать именно то, что вам нужно в вашем текущем задании. Моя задача была достаточно простой и не требовала сложного алгоритма регулирования процесса: один GPU, выполняет какую либо ресурсоёмкую функцию последовательно (в этой статье примером будет Kandinsky 2.1), без обработки ошибок с повторным запуском процесса, функция запускаются пользователем в django. Python разработчика, с такой постановкой вопроса, интернет в первую очередь приводит к Celery. Кратко ознакомившись с документацией, подсознательно закрываешь задачу, все готово, за 15 минут подключу. Но в процессе реализации столкнулся с проблемой, существенно влияющей на скорость выполнения процесса. Каждый раз при выполнении функции в Celery, процесс загружает веса модели, в случае Kandinsky 2.1, веса имеют большой объем. Подобная проблема обсуждалась в stackoverflow. Попытки обойти эту преграду в «сельдерей», приводили к новым ошибкам. Один из вариантов решения ведет к этой статье Scaling AllenNLP/PyTorch in Production, вместо Celery автор предлагает использовать ZeroMQ, с этой реализацией буду работать для решения задачи. В статье вычисления выполняются в CPU, вообщем рекомендую к прочтению. ZeroMQ - будет протоколом обмена между ресурсоёмкой функцией и django, для запуска фонового процесса в очереди буду использовать Django Channels, рабочие и фоновые задачи.

Примерный алгоритм:

  1. Пользователь, из web интерфейса отправляет запрос в Django.

  2. Из основного канала Django, переслать полученный запрос пользователя в «другой» фиксированный канал Django.

  3. Полученный запрос из основного канала отправить по TCP ZeroMQ к ресурсоёмкой функции.

  4. После выполнения функции, вернуть ответ по TCP ZeroMQ в «другой» канал Django.

  5. Вернуть ответ пользователю.

Клиентская часть javascript очень простая, не нуждается в дополнительных комментариях. В моем примере, взаимодействие клиент - сервер происходит с помощью WebSocket, благодаря гибкости Django Channels, вы не ограничены в протоколе обмена данными для запуска фоновых задач.

var ws_wall = new WebSocket("ws://"+ IP_ADDR +":"+PORT+"/");

ws_wall.onmessage = function(event) {
  // что то делаем с полученным ответом
}

// отправить данные в Django
function send_wall() {
    var body = document.getElementById('id_body');
    if (body.value == "") {
        return false;
    }
    if (ws_wall.readyState != WebSocket.OPEN) {
        return false;
    }
    var data = JSON.stringify({body: body.value,
                               event: "wallpost"});
    ws_wall.send(data);
}

Серверный кусок кода основного потока Django Channels. Большинство, практикующих разработчиков взаимодействующих с Django, не увидят чего то нового, стандартные фрагменты кода из документации. Подробно рассказывать о структуре orm django модели не имеет смысла, по коду будет понятно, что всё очень просто.

import json
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from myapp.models import User, Post
from asgiref.sync import sync_to_async
from django.utils import dateformat

class WallHandler(AsyncJsonWebsocketConsumer):
    async def connect(self):
        """
        инициализация подключения
        """
        self.room_group_name = "wall"
        self.sender_id = self.scope['user'].id
        self.sender_name = self.scope['user']
        if str(self.scope['user']) != 'AnonymousUser':
            self.path_data = self.scope['user'].path_data
        await self.channel_layer.group_add(
            self.room_group_name,
            self.channel_name
        )
        await self.accept()
        
    async def disconnect(self, close_code):
        """
        обработка отключения
        покинуть группу
        """
        print("error code: ", close_code)
        await self.channel_layer.group_discard(
            self.room_group_name,
            self.channel_name
        )      
        
    async def receive(self, text_data):
        """
        обработка полученных данных от килиента (WebSocket)
        получить событие и отправиь соответствующее событие
        """
        response = json.loads(text_data)
        event = response.get("event", None)
        if self.scope['user'].is_authenticated: 
            if event == "wallpost":
                post = Post()
                post.body = response["body"]
                post.path_data = self.path_data
                post.user_post = self.scope['user']
                post_async = sync_to_async(post.save) # взаимодействие с синхронным Django
                await post_async()

                """
                из основного канала Django, 
                отправить полученный запрос от пользователя 
                в «другой» фиксированный канал Django
                """
                _temp_dict = {}
                _temp_dict["body"] = response["body"]
                _temp_dict["path_data"] = self.path_data
                _temp_dict["post"] = str(post.id)
                _temp_dict["type"] = "triggerWorker"
                _temp_dict["room_group_name"] = self.room_group_name
                await self.channel_layer.send('nnapp', _temp_dict) # синтаксис await не блокирует работу основного потока
                """
                сервер продолжает работу,
                отправить текущие готовые данные клиенту, 
                без окончания выполнения ресурсоёмкой функции
                """
                _data = {"type": "wallpost",
                         "timestamp": dateformat.format(post.date_post, 'U'),
                         "text":response["body"],
                         "user_post": str(self.sender_name),
                         "user_id": str(self.sender_id),
                         "id": str(post.id),
                         "status" : "wallpost"
                        }
                await self.channel_layer.group_send(self.room_group_name, _data)
    async def wallpost(self, res):
        """  
        отправить сообщение клиенту(WebSocket)
        """
        await self.send(text_data=json.dumps(res))                

ВАЖНО!!! Дополнительные настройки в конфигурационном файле asgi.py Django проекта.

"""
ASGI config for app project.

It exposes the ASGI callable as a module-level variable named ``application``.

For more information on this file, see
https://docs.djangoproject.com/en/3.2/howto/deployment/asgi/
"""

import os
import django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'app.settings')
django.setup()

from django.core.asgi import get_asgi_application
from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter, ChannelNameRouter
import app.routing
from wall import nnapp

os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'app.settings')

application = ProtocolTypeRouter({
    "http": get_asgi_application(),
    "websocket": AuthMiddlewareStack(
        URLRouter(
            app.routing.websocket_urlpatterns,
        )
    ),
    # добавил «другой» фиксированный канал Django
    "channel": ChannelNameRouter({
        "nnapp": nnapp.NNHandler.as_asgi(),
    }),
})

Создам «другой» канал Django, для обработки фоновой задачи. В новый файл nnapp.py, копирую клиентскую часть из статьи Scaling AllenNLP/PyTorch in Production и изменяю для своих нужд.

from myapp.models import User, Post
from asgiref.sync import async_to_sync
from channels.consumer import SyncConsumer
from channels.layers import get_channel_layer
import uuid
import os
import json
import zlib
import pickle
import zmq

# глобальные переменные из статьи
work_publisher = None
result_subscriber = None
TOPIC = 'snaptravel'

RECEIVE_PORT = 5555
SEND_PORT = 5556 

# получить доступ к channel layer
channel_layer = get_channel_layer()

# низкоуровневый протокол подразумивает работу с байтами
def compress(obj):
    p = pickle.dumps(obj)
    return zlib.compress(p)

def decompress(pickled):
    p = zlib.decompress(pickled)
    return pickle.loads(p)
    
def start():
    global work_publisher, result_subscriber
    context = zmq.Context()
    work_publisher = context.socket(zmq.PUB)
    work_publisher.connect(f'tcp://127.0.0.1:{SEND_PORT}') 

def _parse_recv_for_json(result, topic=TOPIC):
    compressed_json = result[len(topic) + 1:]
    return decompress(compressed_json)

def send(args, model=None, topic=TOPIC):
    id = str(uuid.uuid4())
    message = {'body': args["title"], 'model': model, 'id': id}
    compressed_message = compress(message)
    work_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
    return id

def get(id, topic=TOPIC):
    context = zmq.Context()
    result_subscriber = context.socket(zmq.SUB)
    result_subscriber.setsockopt(zmq.SUBSCRIBE, topic.encode('utf8'))
    result_subscriber.connect(f'tcp://127.0.0.1:{RECEIVE_PORT}')
    result = _parse_recv_for_json(result_subscriber.recv())
    while result['id'] != id:
        result = _parse_recv_for_json(result_subscriber.recv())
    result_subscriber.close()
    if result.get('error'):
        raise Exception(result['error_msg'])
    return result
  
# эта функция немного отличаеться от оригинала из статьи
def send_and_get(args, model=None):
    id = send(args, model=model)
    res = get(id)
    namefile = f'{id}.jpg'
    res['prediction'][0].save(f'media/data_image/{args["path_data"]}/{namefile}', format="JPEG")
    post = Post.objects.get(id=args["post"]) 
    post.image = namefile
    post.save()
    _data = {"type": "wallpost", "status":"Kandinsky-2.1", "path_data": args["path_data"],
             "data": f'{namefile}', "post":args["post"]}
    # возвращаем результат в основной канал Django Channels
    async_to_sync(channel_layer.group_send)(args["room_group_name"], _data)

# запуск фиксированого канала с клиентом ZeroMQ
class NNHandler(SyncConsumer):
    start()
    def triggerWorker(self, message):
        print ("data for a background task: ", message)
        send_and_get(message, model='Kandinsky-2.1')

Сервер ZeroMQ с ресурсоёмкой функцией, в моём случае Kadinsky 2.1, по большей части полностью совпадает с кодом статьи, с незначительными дополнениями. Для работы с русским языком, использую обученную модель переводчика Helsinki-NLP/opus-mt-ru-en.

import os, time
from types import SimpleNamespace
import zmq
import zlib
import pickle
import torch.multiprocessing as mp
import threading
import cv2
from kandinsky2 import get_kandinsky2
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import uuid

QUEUE_SIZE = mp.Value('i', 0)

def compress(obj):
    p = pickle.dumps(obj)
    return zlib.compress(p)

def decompress(pickled):
    p = zlib.decompress(pickled)
    return pickle.loads(p)

TOPIC = 'snaptravel'
prediction_functions = {}

RECEIVE_PORT = 5556
SEND_PORT = 5555

# «модель» генерация картинки
model = get_kandinsky2('cuda', task_type='text2img', model_version='2.1', use_flash_attention=False)

# «модель» переводчик
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
model_translater = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en")


def _parse_recv_for_json(result, topic=TOPIC):
    compressed_json = result[len(topic) + 1:]
    return decompress(compressed_json)

def _decrease_queue():
    with QUEUE_SIZE.get_lock():
        QUEUE_SIZE.value -= 1

def _increase_queue():
    with QUEUE_SIZE.get_lock():
        QUEUE_SIZE.value += 1
    
def send_prediction(message, result_publisher, topic=TOPIC):
    _increase_queue()
    model_name = message['model']
    body = message['body']
    id = message['id']
    
    # подготовка входных данных для обученной «модели» переводчик
    tokenized_text = tokenizer([str(body).lower()], return_tensors='pt')

    # перевод
    translation = model_translater.generate(**tokenized_text)
    body = tokenizer.batch_decode(translation, skip_special_tokens=True)[0]
    print(body)

    # генерация изображения
    images = model.generate_text2img(
        str(body).lower(), 
        num_steps=70,
        batch_size=1, 
        guidance_scale=4,
        h=768, w=768,
        sampler='p_sampler', 
        prior_cf_scale=4,
        prior_steps="5"
    )
    result = {"result": images}

    if result.get('result') is None:
        time.sleep(1)
        compressed_message = compress({'error': True, 'error_msg': 'No result was given: ' + str(result), 'id': id})
        result_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
        _decrease_queue()
        return
      
    prediction = result['result']
    compressed_message = compress({'prediction': prediction, 'id': id})
    result_publisher.send(f'{topic} '.encode('utf8') + compressed_message)
    _decrease_queue()
    print ("SERVER", message, f'{topic} '.encode('utf8'))

def queue_size():
    return QUEUE_SIZE.value

def load_models():
    models = SimpleNamespace()
    return models

def start():
    global prediction_functions

    models = load_models()
    prediction_functions = {
    'queue': queue_size
    }

    print(f'Connecting to {RECEIVE_PORT} in server', TOPIC.encode('utf8'))
    context = zmq.Context()
    work_subscriber = context.socket(zmq.SUB)
    work_subscriber.setsockopt(zmq.SUBSCRIBE, TOPIC.encode('utf8'))
    work_subscriber.bind(f'tcp://127.0.0.1:{RECEIVE_PORT}')

    # send work
    print(f'Connecting to {SEND_PORT} in server')
    result_publisher = context.socket(zmq.PUB)
    result_publisher.bind(f'tcp://127.0.0.1:{SEND_PORT}')

    print('Server started')
    while True:
        message = _parse_recv_for_json(work_subscriber.recv())
        threading.Thread(target=send_prediction, args=(message, result_publisher), kwargs={'topic': TOPIC}).start()

if __name__ == '__main__':
  start()

Запускать django channels для фоновых задач, отдельным процессом:
python manage.py runworker nnapp

Эти части кода из работающего проекта моей предыдущей статьи, на столько популярной и комментируемой что даже не нуждается в упоминании. Хороший пример Build a Pytorch Server with celery and RabbitMQ, но плохо масштабируемый и немного сложней в реализации чем обсуждаемый в этой статье. Были еще несколько идей, но решил пойти по пути наименьшего сопротивления... Для работы с кластером gpu, нужен более сложные алгоритм отслеживания загруженности каждого отдельного gpu, поэтому это простой пример на основе чужого труда. Работающий код https://github.com/naturalkind/social-network.

Теги:
Хабы:
Всего голосов 3: ↑3 и ↓0+3
Комментарии5

Публикации

Истории

Работа

Python разработчик
136 вакансий
Data Scientist
60 вакансий

Ближайшие события