Как стать автором
Обновить

PyTriton inference server c Gradio: быстро и просто делаем демо для ML-проектов

Уровень сложностиСредний
Время на прочтение10 мин
Количество просмотров5.6K

Введение

Предположим вы обучили какую-то модель решать нужную задачу. Метрики хорошие, код работает как нужно. Теперь стоит задача как можно быстро и с минимальными затратами собственного времени написать сервер и предоставить стандартный API для использования вашей модели. Желательно переиспользовать наработки для эксплуатации в дальнейшем.

Первое из таких приложений будет более наглядно демонстрировать результаты работы вашей модели, которое также надо написать самостоятельно. Это будет минимальный UI, который может показать то как работает модель на разных данных.

Базовый список проблем обычно следующий:

  • Нужно написать код сервера и настроить его запуск.

  • Моделей и их версий может быть несколько.

  • Нужно чтобы время на обработку запросов было как можно меньше.

  • Нужен базовый мониторинг и логирование.

  • Эффективно использовать GPU, CPU, память.

  • Написать минимальный UI для приложения.

  • Поддержка разных типов обработки (offline, несколько примеров данных за раз (batch), по одному примеру за раз, потоковая обработка (streaming)).

В зависимости от требований что-то можно убрать или добавить из списка проблем. В любом случае получается достаточное число задач на одного человека. Особенно если написание backend-серверов и UI для вас не основной вид деятельности.

Разберём по порядку что можно использовать для каждой из задач. Весь код для этой статьи есть на GitHub. По умолчанию всё настроено так чтобы работало на CPU. Если у вас есть видеокарта с поддержкой CUDA, то инструкции есть описание как переделать на то чтобы использовалась видеокарта.

Написание сервера для модели

Для такого рода задач уже существуют готовые решения. В этой статье будет рассматриваться PyTriton. PyTriton упрощает работу с triton inference server и позволяет оставаться в рамках Python.

На момент написания статьи последней версией была 0.3.0. Процесс установки есть в документации. Каких-то проблемы быть не должно. Работает только в Linux. Для пользователей Windows 10 и выше можно запустить в WSL для тестов. До версии 0.3.0 была привязка к Python 3.8. Нужно было либо использовать Python 3.8 сразу или создавать дополнительное виртуальное окружение.

PyTriton решает задачи из основного списка:

  • Готовый HTTP/gRPC сервер.

  • Можно зарегистрировать несколько разных моделей, версий одной модели.

  • Есть возможность получить метрики по использованию CPU, GPU и т. п.

  • Базовое логирование.

  • Эффективное использование GPU.

  • Можно реализовать разные типы обработок. На момент написания потоковая обработка была ещё в разработке.

  • Получаем базовый API для проверки статуса сервера: готов ли он принимать запросы и находится ли в рабочем состоянии. Может быть важно, если дальше планируется запускать в Kubernetes.

  • Есть возможность кешировать запросы.

Для того чтобы написать и запустить свой первый сервер необходимо выполнить следующие шаги:

  1. Подготовить модель.

  2. Написать логику инициализации модели, её регистрации, обработки запроса.

  3. Запустить сервер и проверить работу.

  4. Упаковать всё в Docker образ и запускать через docker compose. Опциональный шаг.

Подготовка модели

Для простоты будем использовать готовые модели в формате ONNX на примере переноса стиля. Для каждого стиля нужно использовать свою модель. Возьмём их в качестве примера регистрации и использования нескольких моделей на сервере. Всего на сервере будет зарегистрировано две модели. Каждая модель отвечает за свой стиль.

Для своей модели можно воспользоваться соответствующими инструментами для конвертации в ONNX или использовать Model navigator. Не любую модель можно конвертировать просто так без изменения логики её определения в коде, поэтому желательно заранее писать код так чтобы можно было сохранить в нужном формате. Например, написать тест на возможность конвертации при старте разработки.

В репозитории с примером для этой статьи достаточно выполнить:

dvc pull
dvc repro

В директории ./models/converted репозитория будут ONNX модели конвертированные до последней версии opset.

Определение логики сервера

Всю логику обработки изображения реализуем на стороне сервере. Клиенту необходимо только передать изображение и получить результат. Так как готовые ONNX модели содержат определения с фиксированным размером входа, то реализовать обработку сразу нескольких изображений произвольного размера за раз не получится. Придётся приводить все изображения к нужному размеру. Качество результата пострадает из-за этого, но для демонстрационных целей это не страшно.

Схема обработки представлена на рис. ниже (изображение можно увеличить):

Определим класс для переноса стиля. Передаём путь к модели и то под каким именем клиент может получить стилизованное изображение. Модель запускается через ONNX runtime. В качестве Execution Provider передаётся список доступных. В конструкторе класса получаем информацию о названии для входных и выходных данных, размере входных данных. В методе __call__ выполняем необходимые преобразования и возвращаем результат.

Код класса для переноса стиля
class StyleTransferONNX:
    def __init__(self, path_to_model: str, out_api_name: str):
        self._session = ort.InferenceSession(
            path_to_model, providers=ort.get_available_providers())
        input_info = self._session.get_inputs()[0]
        self._input_name = input_info.name
        self._in_height, self._in_width = input_info.shape[2:]
        self._max_image_size = min(self._in_height, self._in_width)
        self._out_model_name = self._session.get_outputs()[0].name
        self._out_api_name = out_api_name

    @sample
    def __cal__(self, image: np.ndarray):
        orig_height, orig_width = image.shape[:2]
        image = resize_image(image, self._max_image_size)
        image, pad_info = pad_image(
            image, self._in_width, self._in_height)
        image = image.transpose(2, 0, 1)[np.newaxis, ...].astype(np.float32)

        out_image_batch = self._session.run(
            [self._out_model_name], {self._input_name: image})[0]

        out_image = out_image_batch[0].transpose(1, 2, 0)[:pad_info.orig_image_height,
                                                          :pad_info.orig_image_width, ...]
        out_image = out_image.clip(0, 255).astype(np.uint8)

        out_image = cv2.resize(out_image, (orig_width, orig_height),
                               interpolation=cv2.INTER_LANCZOS4)

        return {self._out_api_name: out_image}

Тут стоит обратить внимание на то какой тип у входного параметра, результата выполнения, а также на декоратор sample. Тип входного параметра sample это np.ndarray. PyTriton делает за нас преобразование данных из входного запроса в массивы numpy. В качестве типа результата словарь, где ключи это название выходов, значения это какие-то numpy массива. В нашем случае это стилизованное изображение. Декоратор sample определён в библиотеке PyTriton и он обозначает, что наша функция может обработать только один пример за раз. Он занимается преобразованием данных запроса в вызов нашей функции с передачей аргументов по соответствующим именам. Если в запросе указан параметр с именем image, то значение этого параметра будет установлено как аргумент для параметра image нашей функции __call__. Если была бы возможность делать обработку нескольких примеров за раз, то можно использовать декоратор batch.

Теперь можно приступить к инициализации моделей, их регистрации и запуска сервера.

Логика запуска и инициализации сервера
if __name__ == "__main__":
    model_dir = pathlib.Path("models", "converted")
    onnx_model_paths = list(model_dir.glob("*.onnx"))
    assert len(onnx_model_paths) > 0, "Cannot find any ONNX model"

    with Triton(config=TritonConfig(strict_readiness=True)) as triton:
        out_api_name = "styled_image"

        for onnx_path in onnx_model_paths:
            model = StyleTransferONNX(str(onnx_path), out_api_name)
            model_name = onnx_path.stem.split("-")[0]

            triton.bind(
                model_name=model_name,
                infer_func=model.__cal__,
                inputs=[
                    Tensor(dtype=np.uint8, shape=(-1, -1, 3), name="image"),
                ],
                outputs=[
                    Tensor(dtype=np.uint8, shape=(-1, -1, 3),
                           name=out_api_name),
                ],
                config=ModelConfig(batching=False)
            )

        triton.serve()

Создаём через контекстный менеджер объект класса Triton. Передаём явно некоторые настройки. А именно strict_readiness, которые обозначает, что на запрос о готовности сервера он будет возвращать true в случае, если все модели готовы и сервера готов отвечать на запросы. Другие настройки можно передать также здесь. Далее создаем модель со стилем StyleTransferONNX, определяем её имя. Функция bind регистрируем модель. Передаём её имя (model_name), какую функцию вызывать для обработки запросов infer_func, какие входные и выходные данные (inputs, outputs), конфигурационный файл модели config. В нашем примере отключаем режим обработки нескольких примеров за раз (batching).

Для определения входных и выходных данных нам доступен только класс Tensor. Это намёк на то все данные должны передаваться в виде массивов. Просто так массив строк или произвольный json не передать. Есть вариант всё закодировать в байтах и передавать информацию как массив байт. В Tensor нужно передать название параметра, его тип, размер, является ли он опциональным. Если размер может быть переменным, то указать -1. Это работает для любого измерения. Например, в нашем случае вход и выход это один массив - одно изображения типа uint8 произвольной высоты и ширины, но с фиксированным числом каналов, равным 3. Ожидается стандартное изображение в формате RGB. Ещё можно указать параметр strict = True чтобы проверять размер и тип того что возвращает модель после получение результата работы. После вызова bind и запуска сервера можно отправлять запросы на обработку модели по заданному имени.

В конце запускаем сам сервер triton.serve(). После запуска в логах будет выведена базовая информация: какие модели зарегистрированы, какие порты отвечают за какие протоколы и т. п.

Примерный лог после запуска сервера
| +------------+---------+--------+
| | Model      | Version | Status |
| +------------+---------+--------+
| | pointilism | 1       | READY  |
| | udne       | 1       | READY  |
| +------------+---------+--------+
...

| +----------------------------------+------------------------------------------+
| | Option                           | Value                                    |
| +----------------------------------+------------------------------------------+
| | server_id                        | triton                                   |
| | server_version                   | 2.36.0                                   |
| | server_extensions                | classification sequence model_repository |
| |                                  |  model_repository(unload_dependents) sch |
| |                                  | edule_policy model_configuration system_ |
| |                                  | shared_memory cuda_shared_memory binary_ |
| |                                  | tensor_data parameters statistics trace  |
| |                                  | logging                                  |
| | model_repository_path[0]         | /root/.cache/pytriton/workspace_igw1_6ol |
| |                                  | /model-store                             |
| | model_control_mode               | MODE_NONE                                |
| | strict_model_config              | 0                                        |
| | rate_limit                       | OFF                                      |
| | pinned_memory_pool_byte_size     | 268435456                                |
| | min_supported_compute_capability | 6.0                                      |
| | strict_readiness                 | 1                                        |
| | exit_timeout                     | 30                                       |
| | cache_enabled                    | 0                                        |
| +----------------------------------+------------------------------------------+
|
| I0926 11:35:05.431413 63 grpc_server.cc:2451] Started GRPCInferenceService at 0.0.0.0:8001
| I0926 11:35:05.431799 63 http_server.cc:3558] Started HTTPService at 0.0.0.0:8000
| I0926 11:35:05.475171 63 http_server.cc:187] Started Metrics Service at 0.0.0.0:8002

У сервера есть стандартный API для проверки статуса работы и использования моделей. Некоторые из них приведены в таблице ниже:

HTTP API

Описание

GET v2/health/ready

Готов ли сервер обрабатывать запросы

GET v2/health/live

Находится ли сервер в рабочем состоянии

POST v2/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]/infer

Вызвать модель ${MODEL_NAME} версии ${MODEL_VERSION} для получения результата

Для того чтобы получить стилизованное изображение с использованием стиля pointilism нужно вызвать метод POST v2/models/pointilism/infer или POST v2/models/pointilism/versions/1/infer. Можно вызывать несколько версий одной модели.

Для проверки сразу отправим запрос в виде json, который содержит белое изображение размера 5x5:

curl -X POST -H "Content-Type: application/json" -d @query.json http://localhost:8000/v2/models/pointilism/infer
query.json
{
  "id": "1",
  "inputs": [
    {
      "name": "image",
      "shape": [5, 5, 3],
      "datatype": "UINT8",
      "parameters": {},
      "data": [
            [
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255]
            ],
             [
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255]
            ],
             [
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255]
            ],
             [
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255]
            ],
             [
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255],
                [255, 255, 255]
            ]
        ]
    }
  ]
}

В результате получим стилизованное изображение и ответ тоже в виде json. Более подробно про описание API и его расширения можно прочитать в документации

Также можно получить метрики с сервера:

curl -X GET http://localhost:8002/metrics

В дальнейшем их можно собирать и строить графики для отслеживания того что представляет интерес.

Запуск в контейнере

Для того чтобы сервер было просто запускать и подготовить к быстрому развёртыванию внутри кластера Kubernetes или просто на отдельном сервере напишем Dockerfile и docker-compose для запуска. Детально это разбирать не будем, но остановимся на некоторых моментах.

Для использования видеокарты берём готовый образ от NVIDIA с установленной CUDA с минимальными зависимостями (runtime). Для того чтобы видеокарты была доступна внутри контейнера нужно настроить nvidia-conatiner-toolkit и раскомментировать секцию deploy в docker compose. Можно подключить несколько видеокарт или сразу все.

Для проверки работоспособности сервера используем healthcheck. В Kubernetes можно использовать livenessProbe и readinessProbe для соответствующих проверок.

В итоге для запуска достаточно выполнить:

docker compose up --build -d

Создание UI для модели

Для демонстрационных целей необходимо создать минимальный UI, который можно использовать для проверки работы модели. С учётом того что всё это делает один человек можно выбрать какие-то популярные решения для этого:

Рассмотрим Gradio т. к. там достаточно просто создать UI и много готовых компонент для разных форматов данных. Логика работы достаточно простоя. Приложение на Gradio будет принимать данные пользователей, отправлять их на сервер и отображать результат. Для нашего пример подойдёт макет из двух компонент для загрузки изображений и отображения изображений, выбора типа модели и кнопки запуска. В репозитории это всё написано в одном файле client.py. Пример UI на рис. ниже (изображение можно увеличить):

Для взаимодействия с triton inference server есть как минимум три варианта:

  1. Отправлять напрямую запросы через HTTP в json как было в примере выше.

  2. Использовать классы для работы из PyTriton.

  3. Использовать triton client.

Работа с классами из PyTriton наиболее простая, но выберем вариант с triton client как более универсальный. Запросы будут отправляться через http.

Отправка запроса и получение результата
with InferenceServerClient("localhost:8000") as client:
    model_config = client.get_model_metadata(model_name)

    input_dtype = model_config["inputs"][0]["datatype"]
    input_info = InferInput(model_config["inputs"]
                            [0]["name"], image.shape, input_dtype)

    input_info.set_data_from_numpy(image)

    out_name = model_config["outputs"][0]["name"]
    outputs = [InferRequestedOutput(out_name)]

    infer_res = client.infer(
        model_name,
        inputs=[input_info],
        outputs=outputs
    )

    return infer_res.as_numpy(out_name)

Логика работы сводится к тому что сначала подключаемся к серверу. Получаем информацию о модели. Какие есть входные, выходные данные, их тип, размеры. Подготавливаем входные данные input_info. Подготавливаем какие выходные данные нам нужны outputs. Отправляем запрос и ждём ответа client.infer(...). После того как результат получен возвращаем его как numpy массив infer_res.as_numpy(out_name). Есть вариант отправлять асинхронные запросы. Можно посмотреть дополнительные примеры в репозитории с тем какие ещё есть особенности.

Заключение

Получился минимальный рабочий пример из сервера и UI для демонстрации работы. Всё сделано одним человек в рамках программирования на Python. Потребовались дополнительные знания по Docker, использованию видеокарт в контейнерах, docker compose, но это не должно вызывать трудностей. Если вдруг выяснится, что API ещё кому-то нужен, то ничего переписывать не надо. Уже готов хороший сервер с базовыми функциями, которые обычно нужны при эксплуатации в prod. Можно сразу показать пример того как использовать разные модели и какие данные передавать или показать примеры из репозитория triton client.

В дальнейшем можно настроить сбор метрик, логов, написать deployment в Kubernetes и уже использовать в prod.

Не стоит забывать про ограничения и возможные проблемы. Пока что PyTriton находится в разработке и не весь функционал triton inference server доступен.

Теги:
Хабы:
Всего голосов 1: ↑1 и ↓0+1
Комментарии0
1

Публикации

Истории

Работа

Data Scientist
53 вакансии

Ближайшие события

27 марта
Deckhouse Conf 2025
Москва
25 – 26 апреля
IT-конференция Merge Tatarstan 2025
Казань