Как стать автором
Обновить

Гайд по overload: как написать один код на Python для разных бэкендов

Время на прочтение11 мин
Количество просмотров2.2K

Разработчики часто сталкиваются с задачами, в которых одна функция должна работать с разными типами данных и количеством аргументов. Чтобы каждый раз не создавать множество функций с разными именами, существует перегрузка (overload). Она позволяет использовать одно имя операции для обработки различных комбинаций входных данных. Благодаря перегрузке одна функция может адаптироваться под различные сценарии и делать код лаконичным и понятным. 

В статье разберемся, как работает перегрузка в статических и динамических языках программирования. В конце покажу, как и зачем мы реализовали перегрузку на Python своим собственным способом.

Что такое перегрузка функций

Перегрузка функций (function overloading) — это концепция, которая позволяет определять несколько функций или методов с одинаковым именем, но с разными сигнатурами: количеством, типами или порядком аргументов. 

Компилятор или интерпретатор выбирает подходящую версию функции на основе переданных аргументов. Это используется в строго типизированных языках, чтобы писать гибкий и читаемый код было проще.

Примеры перегрузки в C++ и Typescript

C++ — классический пример языка с поддержкой перегрузки «из коробки». У нас есть две функции с одинаковым названием sum, но с разным типом параметров — int и double:

#include <iostream>
#include <string>

int sum(int a, int b) {
    return a + b;
}

double sum(double a, double b) {
    return a + b;
}

int main() {
    std::cout << sum(2, 3) << std::endl;     // Вызовет sum(int, int)
    std::cout << sum(2.5, 3.1) << std::endl; // Вызовет sum(double, double)
    return 0;
}

В TypeScript перегрузка функций реализуется на уровне типов. Здесь две сигнатуры объявлены как перегрузка функции greet, а сама реализация одна. Она проверяет, какие аргументы пришли:

function greet(name: string): string;
function greet(name: string, age: number): string;
function greet(name: string, age?: number): string {
  if (age !== undefined) {
    return `Hello, ${name}! You are ${age} years old.`;
  } else {
    return `Hello, ${name}!`;
  }
}

console.log(greet("Alice"));       // "Hello, Alice!"
console.log(greet("Bob", 30));     // "Hello, Bob! You are 30 years old."

Почему перегрузки в чистом виде нет в Python и других динамических языках?

Python — язык с динамической типизацией. Во время исполнения любая переменная может содержать объект почти любого типа. Получается, что единственная «актуальная» сигнатура функции видна только в рантайме. 

Если сделать в Python две функции с одинаковым именем, то последняя «затрет» предыдущую:

def hello(name: str):
    print(f"Hello {name}")

def hello(age: int):
    print(f"Your age is {age}")

hello("Alice")  # "Your age is Alice" – ошибка: вызовется вторая версия, но она ждёт int.

По умолчанию никакого отдельного механизма перегрузки в Python нет. Но это не значит, что перегрузка невозможно в принципе (:

Как же создать перегрузку в Python?

Ниже опишу подходы, которые часто используются в реальном Python-коде. Первый — самый популярный, остальные — максимально простые в реализации.

1. Проверка типов внутри функции:

def hello(name_or_age):
    if isinstance(name_or_age, str):
        print(f"Hello {name_or_age}")
    elif isinstance(name_or_age, int):
        print(f"Your age is {name_or_age}")
    else:
        raise TypeError("Expected str or int")

Минус такого подхода — единая монолитная функция, которая со временем может раздуться и стать нечитаемой. А еще она не дает возможности задавать несколько версий функции с разными сигнатурами.

2. Уникальные имена для каждой комбинации:

Вместо перегрузки можно использовать разные имена функций для каждой комбинации аргументов.

def hello_str(name):
    print(f"Hello {name}")

def hello_int(age):
    print(f"Your age is {age}")

Такой метод максимально прост и прозрачен, не требует дополнительных инструментов или проверок, что делает его удобным для случаев, где важна явность. Однако это увеличивает количество функций в коде и не соответствует концепции перегрузки, так как нет единой точки входа, что может быть неудобно при работе с похожими по смыслу операциями и может сделать API библиотеки громоздким.

3. Использование декоратора functools.singledispatch (доступен с Python 3.4):

from functools import singledispatch

@singledispatch
def hello(arg):
    raise TypeError("Unsupported type")

@hello.register
def _(arg: str):
    print(f"Hello {arg}")

@hello.register
def _(arg: int):
    print(f"Your age is {arg}")

Этот декоратор позволяет регистрировать функции-обработчики для разных типов аргументов. Но singledispatch ориентирован на тип первого аргумента, а для многих случаев (например, учитывая несколько параметров, Union, Optional и т. д.) этого может быть недостаточно.

4. Использование библиотеки multipledispatch:

from multipledispatch import dispatch

@dispatch(str)
def hello(arg):
    print(f"Hello {arg}")

@dispatch(int)
def hello(arg):
    print(f"Your age is {arg}")

Эта библиотека позволяет регистрировать функции для разных сигнатур, но она не входит в стандартную библиотеку Python — так что придется устанавливать ее отдельно. Кроме того, она не поддерживает аннотации типов.

5. Использование сторонних библиотек

Нестандартные библиотеки или самодельные решения, которые пытаются проанализировать типы через аннотации и хранить разные реализации одной функции в разных местах. Именно в эту категорию и попадает описанная в вопросе реализация.

Как мы делаем перегрузку функций в Python

Мы с командой пришли к тому, что нет такого подхода, который бы идеально нам подошел. Они:

  • имеют ограничения по количеству аргументов, по которым перегружаются

  • не умеют работать с generic'ами по типу Union, Optional, с аргументами по умолчанию, с args, **kwargs.

Наше решение по перегрузке адаптировано под запросы команды, поэтому в нем можно использовать и другие наши техники: например, LazyImport. Это удобно, и коллеги довольны (:

Наша реализация состоит из двух ключевых классов — OverloadManager, OverloadFunction, и декоратора @overload. Давайте разберем, как они взаимодействуют и решают сложные задачи перегрузки.

Важно: Наша реализация overload — это не то же самое, что @overload из typing. Декоратор из typing используется только для статической типизации и не влияет на runtime-поведение, в то время как наша версия направлена на динамическую диспетчеризацию вызовов на основе типов и количества аргументов.

1. Регистрация функций и методов (OverloadManager.register)

Когда вы применяете декоратор @overload к функции, она регистрируется в OverloadManager. Здесь происходит первый важный шаг — определение, является ли объект обычной функцией или методом класса.

Отличие функции от метода

Мы используем атрибут __qualname__, который возвращает полное имя функции или метода. Например:

  • Для обычной функции: __qualname__ = "process".

  • Для метода класса: __qualname__ = "MyClass.process".

Если в __qualname__ есть точка (.), это означает, что функция — метод класса. Тогда мы извлекаем имя класса и сохраняем метод в словаре self.methods с ключом (module, class_name, method_name). Для обычных функций используется просто имя в словаре self.functions.

Зачем это нужно? Это позволяет различать перегрузку на уровне функций и методов, а также поддерживать перегрузку методов с учетом наследования (через __mro__, о чем ниже).

Процесс регистрации функций и методов
Процесс регистрации функций и методов

2. Анализ сигнатуры (OverloadFunction.register)

После определения типа объекта мы анализируем его сигнатуру, чтобы зарегистрировать конкретную перегрузку. Анализ разбит на этапы:

  • Извлечение типов параметров

Используется inspect.signature, который возвращает объект Signature. Мы проходимcя по всем параметрам и извлекаем их аннотации типов, исключая *args и **kwargs (переменное число аргументов), так как они не участвуют в строгой перегрузке. Результат — кортеж типов, например: (int, str).

  • Нормализация аннотаций (_normalize_annotation)

Аннотации могут быть сложными (например, Union[int, str], List[str], Optional[float]), и их нужно привести к удобному виду:

  • Обработка Union: Если тип — Union, мы вызываем get_origin (возвращает Union) и get_args (возвращает (int, str)), сохраняя подтипы для последующей проверки.

  • Обработка generic-типов: Для List[str] get_origin вернет list, а get_args(str,).

  • Ленивые импорты: Если аннотация — объект LazyImport, мы оборачиваем её в LazyTypeWrapper, чтобы отложить разрешение типа до момента вызова.

  • Результат

Каждый вариант перегрузки сохраняется в словаре self.overloads с ключом — кортежем типов, а значением — самой функцией.

class OverloadFunction:
    def __init__(self, name: str):
        self.name = name
        self.overloads = {}

    def register(self, func) -> None:
        # извлечение типов параметров
        # нормализация аннотаций
        # получение финального кортежа param_types
        
        self.overloads[param_types] = func 

3. Вызов функции (OverloadFunction.__call__)

Когда вызывается перегруженная функция, мы определяем, какая версия должна быть выполнена, по такому алгоритму:

  • Сбор типов аргументов

Для переданных аргументов (args) мы создаем кортеж их фактических типов с помощью tuple(type(arg) for arg in args). Например, вызов process(42, "hello") дает (int, str).

  • Сопоставление типов (_match_types)

Это сердце проверки, где сравниваются фактические типы аргументов с ожидаемыми типами параметров:

  • Проверка длины: Если аргументов больше, чем параметров, это сразу несовпадение.

  • Методы классов: Если первый параметр — self (пустая аннотация), он пропускается при сравнении, чтобы поддерживать методы.

  • Обработка Union: Если параметр имеет тип Union[int, str], мы используем get_args для извлечения (int, str) и проверяем, является ли тип аргумента подклассом хотя бы одного из них через issubclass. Если в Union есть None, он исключается из проверки, если аргумент не None.

  • Ленивые типы: Для LazyTypeWrapper мы пытаемся разрешить тип через resolve(). Если это не удается (например, из-за циклического импорта), сравниваем имена типов как запасной вариант.

  • Generic-типы: Если параметр — list, проверяем, является ли аргумент подклассом list (через issubclass).

  • Выбор функции

Если типы совпадают, мы используем inspect.signature(fn).bind_partial, чтобы привязать аргументы (включая значения по умолчанию), и вызываем функцию.

class OverloadFunction:
    ...

    # Вызов обертки над оригинальной функцией
    def __call__(self, *args, **kwargs) -> Any:
        arg_types = # получаем типы аргументов из args

        # Проходимся по всем элементам в self.overloads, если находим перегрузку,
        # то возвращаем результат оригинальной функции:
          return fn(*bound_args.args, **bound_args.kwargs)

        # иначе:
          raise TypeError(f"No match for types {arg_types}")

4. Поддержка наследования (OverloadManager.call)

Для методов классов мы учитываем иерархию наследования:

Если первый аргумент — объект класса, мы проверяем его тип через __class__ и проходим по цепочке базовых классов (__mro__). Например, если метод перегружен в базовом классе Base, а вызывается на объекте Derived, мы найдем подходящую версию.

class OverloadManager:
    ...

    # Метод вызывается из функции-декоратора
    def call(self, name, *args, **kwargs) -> Any:
      # сначала пытаемся найти перегрузку среди функций
      # если нашлась, то возвращаем:
        return self.functions[name](*args, **kwargs)
      
      # если первый аргумент это __class__,
      # то ищем метод в классе и среди родительских классов,
      # извлекаем название модуля и класса, затем возвращаем:
        key = (module_name, class_name, name)
        return self.methods[key](*args, **kwargs)
      
      # иначе:
        raise TypeError(f"No overloaded function '{name}' found.")

Итоговая картинка вызова перегруженной функции

Вызов функции
Вызов функции

Что получается в итоге

from typing import Union, Optional
from copy import deepcopy

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

from my_library.overloading import overload  # функция-декоратор перегрузки

DataFrame = Union[pd.DataFrame, np.ndarray]
DataArray = Union[pd.Series, pd.Index, np.ndarray, list, tuple]

class Dataset:
    "Кастомный класс для работы с данными"

    def __init__(self, data: DataFrame, target_column: str):
        self.data = data
        self.target_column = target_column
        ...

class Model:
    "Кастомный класс для моделей"
    def __init__(self, model):
        self.model = deepcopy(model)
        self.features = None
    
    @overload
    def fit(self, dataset: Dataset, features: Optional[DataArray] = None, **kwargs):
        self.features = (
            self._get_features(dataset.data, dataset.target_column)
            if features is None
            else features
        )
        return self.fit(dataset.data[self.features], dataset.data[dataset.target_column], **kwargs)

    @overload
    def fit(self, X: DataFrame, y: DataArray, **kwargs):
        self.features = self._get_features(X, None)
        self.model.fit(X, y, **kwargs)
        return self

    @overload
    def predict(self, dataset: Dataset, **kwargs):
        return self.predict(dataset.data[self.features], **kwargs)

    @overload
    def predict(self, X: DataFrame, **kwargs):
        return self.model.predict(X, **kwargs)
    
    def _get_features(self, data: DataFrame, target_col: Optional[str]) -> list:
        if hasattr(data, 'columns'):
            if target_col is not None:
                return data.columns.drop(target_col).tolist()
            return data.columns.tolist()
        else:
            return list(range(data.shape[1] - (1 if target_col is None else 0)))


X, y = load_iris(as_frame=True, return_X_y=True)
dataset = Dataset(pd.concat([X, y], axis=1), 'target')
sklearn_model = LogisticRegression(solver='liblinear', multi_class='ovr', random_state=42)

model1 = Model(sklearn_model)
model2 = Model(sklearn_model)
model3 = Model(sklearn_model)

# Перегрузка сработает в зависимости от типов аргументов
features = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
model1.fit(dataset, features)
model2.fit(dataset)
model3.fit(X, y)

# Проверка на равенство
np.array_equal(
    model1.predict(dataset), 
    model2.predict(dataset)
) # True
np.array_equal(
    model2.predict(dataset), 
    model3.predict(dataset)
) # True
np.array_equal(
    model1.predict(dataset), 
    model1.predict(X)
) # True

Теперь посмотрим на содержание нашего менеджера.

В коде нашей библиотеки нам достаточно объявить один объект manager = OverloadManager(), далее в ней функция-декоратор overload будет регистрировать все перегрузки.

>>> from my_library import overloading

>>> overloading.manager.methods
{('__main__',
  'Model',
  'fit'): <my_library.overloading.OverloadFunction at 0x12fa82ec0>,
 ('__main__',
  'Model',
  'predict'): <my_library.overloading.OverloadFunction at 0x14f8b0b80>}

# название модуля здесь `__main__` так как мы тестировали наши
# перегрузки выше в jupyter notebook.

Посмотрим теперь на конкретные перегрузки. Так как мы работаем с методами, а не функциями, то ключом к каждому из них будет кортеж, который состоит из названия модуля, класса и самой функции.

>>> overloading.manager.methods[('__main__', 'Model', 'fit')].overloads
{(inspect._empty,
  __main__.Dataset,
  typing.Union[pandas.core.series.Series, pandas.core.indexes.base.Index, numpy.ndarray, list, tuple, NoneType]): <function __main__.Model.fit(self, dataset: __main__.Dataset, features: Union[pandas.core.series.Series, pandas.core.indexes.base.Index, numpy.ndarray, list, tuple, NoneType] = None, **kwargs)>,
 (inspect._empty,
  typing.Union[pandas.core.frame.DataFrame, numpy.ndarray],
  typing.Union[pandas.core.series.Series, pandas.core.indexes.base.Index, numpy.ndarray, list, tuple]): <function __main__.Model.fit(self, X: Union[pandas.core.frame.DataFrame, numpy.ndarray], y: Union[pandas.core.series.Series, pandas.core.indexes.base.Index, numpy.ndarray, list, tuple], **kwargs)>}

>>> overloading.manager.methods[('__main__', 'Model', 'predict')].overloads
{(inspect._empty,
  __main__.Dataset): <function __main__.Model.predict(self, dataset: __main__.Dataset, **kwargs)>,
 (inspect._empty,
  typing.Union[pandas.core.frame.DataFrame, numpy.ndarray]): <function __main__.Model.predict(self, X: Union[pandas.core.frame.DataFrame, numpy.ndarray], **kwargs)>}

Технические моменты в реализации overload

1. Работа с Union и Optional

  • Union[int, str] разбирается на подтипы (int, str), и проверка проходит для каждого аргумента отдельно.

  • Optional[float] (то есть Union[float, None]) обрабатывается так, чтобы None не мешал, если аргумент — float. Это делает перегрузку интуитивной.

2. Ленивые импорты

Если тип импортируется лениво (в нашем случае с LazyImport), мы не загружаем его сразу, а откладываем до момента вызова. Это решает проблему циклических зависимостей, так как использование from future import annotations ломает нашу реализацию.

LazyImport: Наш кастомный класс для отложенного разрешения типов. В основном используется для решения проблем с циклическими импортами и также откладывает импорт тяжелых библиотек. При вызове resolve() он импортирует модуль и возвращает тип.

3. Значения по умолчанию

Благодаря bind_partial и apply_defaults мы корректно обрабатываем параметры с дефолтными значениями, даже если они не переданы в вызове.

4. Обработка ошибок

Если подходящей перегрузки не найдено, выбрасывается информативное исключение с указанием типов аргументов, что упрощает отладку.

5. Работа с *args и **kwargs

Для пропуска переменного числа аргументов, мы используем inspect.Parameter.VAR_POSITIONAL и inspect.Parameter.VAR_KEYWORD, сравнивая с ним все переданные аргументы.

6. Пропуск self в методах

Если первый параметр — self, мы пропускаем его при сравнении типов, чтобы поддерживать методы классов. Вычислить, является ли аргумент self, помогает сравнение с inspect.Parameter.empty.

7. Различия в работе с методами и функциями

При обработке функций достаточно хранить только их имена в менеджере, тогда как для методов требуется использовать кортеж, включающий имя метода, а также имена модуля и класса. Это обусловлено тем, что методы с одинаковыми именами в разных классах (за исключением случаев наследования) выполняют различные задачи. Например, в пользовательских классах Model для обучения и GridSearch для подбора гиперпараметров может быть метод fit(), но его назначение и реализация в каждом случае будет различным.

Что ещё почитать по теме

  1. typing — Support for type hints — Python documentation — официальная документация библиотеки typing для использования аннотаций типов (get_type_hints, get_origin, get_args).

  2. inspect — Inspect live objects — Python documentation — официальная документация модуля inspect для работы с объектами во время исполнения.

  3. PEP 484 — Type Hints — официальное описание синтаксиса аннотаций типов в Python.

  4. PEP 563 — postponed evaluation of type annotations — официальное описание отложенного разрешения аннотаций типов.

Теги:
Хабы:
+5
Комментарии1

Публикации

Информация

Сайт
beeline.ru
Дата регистрации
Дата основания
Численность
свыше 10 000 человек
Местоположение
Россия