
В статье «Ускоряем анализ данных в 180 000 раз с помощью Rust» показано, как неоптимизированный код на Python, после переписывания и оптимизации на Rust, ускоряется в 180 000 раз. Автор отмечает: «есть множество способов сделать код на Python быстрее, но смысл этого поста не в том, чтобы сравнить высокооптимизированный Python с высокооптимизированным Rust. Смысл в том, чтобы сравнить "стандартный-Jupyter-notebook" Python с высокооптимизированным Rust».
Возникает вопрос: какого ускорения мы могли бы достичь, если бы остановились на Python?
Под катом разработчик Сидни Рэдклифф* проходит путь профилирования и итеративного ускорения кода на Python, чтобы выяснить это.
*Обращаем ваше внимание, что позиция автора может не всегда совпадать с мнением МойОфис.
Повторяем исходные бенчмарки
Как и в упомянутой выше статье, мы используем M1 Macbook, и по тем же бенчмаркам получаем сопоставимые показатели:
Среднее время итерации исходного неоптимизированного кода, измеренное за 1000 итераций, — 35 мс. В оригинальной статье — 36 мс.
После полной оптимизации код на Rust ускорен в 180,081 раз. В оригинальной статье сообщается о 182 450-кратном ускорении.
Изначальный код на Python
Вот неоптимизированный код на Python из ранее упомянутой статьи.
from itertools import combinations import pandas as pd from pandas import IndexSlice as islice def k_corrset(data, K): all_qs = data.question.unique() q_to_score = data.set_index(['question', 'user']) all_grand_totals = data.groupby('user').score.sum().rename('grand_total') # Inner loop corrs = [] for qs in combinations(all_qs, K): qs_data = q_to_score.loc[islice[qs,:],:].swaplevel() answered_all = qs_data.groupby(level=[0]).size() == K answered_all = answered_all[answered_all].index qs_totals = qs_data.loc[islice[answered_all,:]] \ .groupby(level=[0]).sum().rename(columns={'score': 'qs'}) r = qs_totals.join(all_grand_totals).corr().qs.grand_total corrs.append({'qs': qs, 'r': r}) corrs = pd.DataFrame(corrs) return corrs.sort_values('r', ascending=False).iloc[0].qs data = pd.read_json('scores.json') print(k_corrset(data, K=5))
А вот первые две строки DataFrame (далее — датафрейм), data.
user | question | score |
e213cc2b-387e-4d7d-983c-8abc19a586b1 | d3bdb068-7245-4521-ae57-d0e9692cb627 | 1 |
951ffaee-6e17-4599-a8c0-9dfd00470cd9 | d3bdb068-7245-4521-ae57-d0e9692cb627 | 0 |
Для проверки корректности нашего оптимизированного кода, мы можем использовать вывод исходного кода.
Поскольку мы пытаемся оптимизировать внутренний цикл, поместим его в собственную функцию, чтобы профилировать с помощью line_profiler.
Avg time per iteration: 35 ms Speedup over baseline: 1.0x % Time Line Contents ===================== def compute_corrs( qs_iter: Iterable, q_to_score: pd.DataFrame, grand_totals: pd.DataFrame ): 0.0 result = [] 0.0 for qs in qs_iter: 13.5 qs_data = q_to_score.loc[islice[qs, :], :].swaplevel() 70.1 answered_all = qs_data.groupby(level=[0]).size() == K 0.4 answered_all = answered_all[answered_all].index 0.0 qs_total = ( 6.7 qs_data.loc[islice[answered_all, :]] 1.1 .groupby(level=[0]) 0.6 .sum() 0.3 .rename(columns={"score": "qs"}) ) 7.4 r = qs_total.join(grand_totals).corr().qs.grand_total 0.0 result.append({"qs": qs, "r": r}) 0.0 return result
Мы видим значения, которые пытаемся оптимизировать (среднее время итерации/ускорение), а также долю времени, потраченного на выполнение каждой строки.
Это позволяет оптимизировать код следующим образом:
Запускаем профилировщик
Определяем самые медленные строки
Пробуем сделать медленные строки более быстрыми
Повторяем
В приведённом выше коде мы видим, что есть наиболее медленная строка, которая занимает ~70% времени.
Однако есть еще один важный шаг, который предшествует вышеупомянутым:
Проверяем вывод на корректность
Запускаем профилировщик
Определяем самые медленные строки
Пробуем сделать медленные строки более быстрыми
Повторяем
Проверки корректности вывода помогают экспериментировать, пробовать различные методы, библиотеки и т.д., зная при этом, что любые случайные изменения в вычисляемой информации будут отслежены.
Оптимизация 1. Словарь множеств пользователей, ответивших на вопросы: users_who_answered_q
Наш базовый код выполняет различные тяжёлые операции Pandas, выясняя, какие пользователи ответили на заданный набор вопросов — qs. В частности, для этого он проверяет каждую строку датафрейма, чтобы определить, какие пользователи отвечали на вопросы. В первой оптимизации вместо полноценного датафрейма мы можем использовать словарь множеств пользователей. Это позволит нам быстро выяснить, какие пользователи ответили на каждый вопрос qs, и использовать пересечение множеств в Python, чтобы выявить пользователей, ответивших на все вопросы.
Avg time per iteration: 10.0 ms Speedup over baseline: 3.5x % Time Line Contents ===================== def compute_corrs(qs_iter, users_who_answered_q, q_to_score, grand_totals): 0.0 result = [] 0.0 for qs in qs_iter: 0.0 user_sets_for_qs = [users_who_answered_q[q] for q in qs] 3.6 answered_all = set.intersection(*user_sets_for_qs) 40.8 qs_data = q_to_score.loc[islice[qs, :], :].swaplevel() 0.0 qs_total = ( 22.1 qs_data.loc[islice[list(answered_all), :]] 3.7 .groupby(level=[0]) 1.9 .sum() 1.1 .rename(columns={"score": "qs"}) ) 26.8 r = qs_total.join(grand_totals).corr().qs.grand_total 0.0 result.append({"qs": qs, "r": r}) 0.0 return result
Так мы значительно ускоряем вычисление строки answered_all, которая вместо 70 % теперь занимает 4 %, и наш код становится быстрее в 3 раза.
Оптимизация 2. Словарь score_dict
Если сложить время, затрачиваемое на каждую строку, участвующую в вычислении qs_total (включая строку qs_data), то получится ~65%; таким образом, наша следующая задача по оптимизации ясна. Нужно снова заменить тяжёлые операции над полным датафреймом (индексирование, группировка и т. д.) быстрым поиском по словарю. Для этого вводим score_dict, словарь, который позволяет проводить оценку для заданной пары вопрос-пользователь.
Avg time per iteration: 690 μs Speedup over baseline: 50.8x % Time Line Contents ===================== def compute_corrs(qs_iter, users_who_answered_q, score_dict, grand_totals): 0.0 result = [] 0.0 for qs in qs_iter: 0.1 user_sets_for_qs = [users_who_answered_q[q] for q in qs] 35.9 answered_all = set.intersection(*user_sets_for_qs) 3.4 qs_total = {u: sum(score_dict[q, u] for q in qs) for u in answered_all} 8.6 qs_total = pd.DataFrame.from_dict(qs_total, orient="index", columns=["qs"]) 0.1 qs_total.index.name = "user" 51.8 r = qs_total.join(grand_totals).corr().qs.grand_total 0.0 result.append({"qs": qs, "r": r}) 0.0 return result
Это помогает нам ускорить код в 50 раз.
Оптимизация 3. Словарь grand_totals и np.corrcoef
Самая медленная строка в коде выше делает несколько вещей: сперва Pandas join, чтобы объединить grand_totals и qs_total, а затем вычисляет для этого коэффициент корреляции. Опять же, мы можем ускорить процесс, используя поиск по словарю вместо join, и поскольку у нас больше нет объектов Pandas, используем np.corrcoef вместо Pandas corr.
Avg time per iteration: 380 μs Speedup over baseline: 91.6x % Time Line Contents ===================== def compute_corrs(qs_iter, users_who_answered_q, score_dict, grand_totals): 0.0 result = [] 0.0 for qs in qs_iter: 0.2 user_sets_for_qs = [users_who_answered_q[q] for q in qs] 83.9 answered_all = set.intersection(*user_sets_for_qs) 7.2 qs_total = [sum(score_dict[q, u] for q in qs) for u in answered_all] 0.5 user_grand_total = [grand_totals[u] for u in answered_all] 8.1 r = np.corrcoef(qs_total, user_grand_total)[0, 1] 0.1 result.append({"qs": qs, "r": r}) 0.0 return result
Получаем ~90-кратное ускорение кода.
Оптимизация 4. Преобразование строк uuid в ints
Эта оптимизация не вносит никаких изменений в код внутреннего цикла. Но она ускоряет некоторые операции. Мы заменяем длинные uuid пользователя/вопроса (например, e213cc2b-387e-4d7d-983c-8abc19a586b1) на гораздо более короткие целочисленные данные. Как это делается:
data.user = data.user.map({u: i for i, u in enumerate(data.user.unique())}) data.question = data.question.map( {q: i for i, q in enumerate(data.question.unique())} )
Теперь измеряем:
Avg time per iteration: 210 μs Speedup over baseline: 168.5x % Time Line Contents ===================== def compute_corrs(qs_iter, users_who_answered_q, score_dict, grand_totals): 0.0 result = [] 0.1 for qs in qs_iter: 0.4 user_sets_for_qs = [users_who_answered_q[q] for q in qs] 71.6 answered_all = set.intersection(*user_sets_for_qs) 13.1 qs_total = [sum(score_dict[q, u] for q in qs) for u in answered_all] 0.9 user_grand_total = [grand_totals[u] for u in answered_all] 13.9 r = np.corrcoef(qs_total, user_grand_total)[0, 1] 0.1 result.append({"qs": qs, "r": r}) 0.0 return result
Оптимизация 5. Массив np.bool_ вместо множеств пользователей
Видно, что операция с множествами пользователей в коде выше по-прежнему самая медленная. Вместо использования наборов ints мы переходим к использованию массива пользователей np.bool_ и применяем np.logical_and.reduce для поиска пользователей, ответивших на все вопросы qs. (Обратите внимание, что np.bool_ использует целый байт для каждого элемента, но np.logical_and.reduce все равно довольно быстр.) Это даёт нам значительное ускорение:
Avg time per iteration: 75 μs Speedup over baseline: 466.7x % Time Line Contents ===================== def compute_corrs(qs_iter, users_who_answered_q, score_dict, grand_totals): 0.0 result = [] 0.1 for qs in qs_iter: 12.0 user_sets_for_qs = users_who_answered_q[qs, :] # numpy indexing 9.9 answered_all = np.logical_and.reduce(user_sets_for_qs) 10.7 answered_all = np.where(answered_all)[0] 33.7 qs_total = [sum(score_dict[q, u] for q in qs) for u in answered_all] 2.6 user_grand_total = [grand_totals[u] for u in answered_all] 30.6 r = np.corrcoef(qs_total, user_grand_total)[0, 1] 0.2 result.append({"qs": qs, "r": r}) 0.0 return result
Оптимизация 6. score_matrix вместо словаря
Теперь самая медленная строка — вычисление qs_total. Следуя примеру из оригинальной статьи, мы переходим к использованию плотного массива np.array для поиска оценок вместо словаря, и используем быструю индексацию NumPy для получения оценок.
Avg time per iteration: 56 μs Speedup over baseline: 623.7x % Time Line Contents ===================== def compute_corrs(qs_iter, users_who_answered_q, score_matrix, grand_totals): 0.0 result = [] 0.2 for qs in qs_iter: 16.6 user_sets_for_qs = users_who_answered_q[qs, :] 14.0 answered_all = np.logical_and.reduce(user_sets_for_qs) 14.6 answered_all = np.where(answered_all)[0] 7.6 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1) 3.9 user_grand_total = [grand_totals[u] for u in answered_all] 42.7 r = np.corrcoef(qs_total, user_grand_total)[0, 1] 0.4 result.append({"qs": qs, "r": r}) 0.0 return result
Оптимизация 7. Реализация corrcoef
Самая медленная строка в коде — np.corrcoef... Мы всеми силами пытаемся оптимизировать код, поэтому вот наша собственная реализация corrcoef, которая в данном случае будет в два раза быстрее:
def corrcoef(a: list[float], b: list[float]) -> float | None: """same as np.corrcoef(a, b)[0, 1]""" n = len(a) sum_a = sum(a) sum_b = sum(b) sum_ab = sum(a_i * b_i for a_i, b_i in zip(a, b)) sum_a_sq = sum(a_i**2 for a_i in a) sum_b_sq = sum(b_i**2 for b_i in b) num = n * sum_ab - sum_a * sum_b den = sqrt(n * sum_a_sq - sum_a**2) * sqrt(n * sum_b_sq - sum_b**2) if den == 0: return None return num / den
Получаем приличное ускорение:
Avg time per iteration: 43 μs Speedup over baseline: 814.6x % Time Line Contents ===================== def compute_corrs(qs_iter, users_who_answered_q, score_matrix, grand_totals): 0.0 result = [] 0.2 for qs in qs_iter: 21.5 user_sets_for_qs = users_who_answered_q[qs, :] # numpy indexing 18.7 answered_all = np.logical_and.reduce(user_sets_for_qs) 19.7 answered_all = np.where(answered_all)[0] 10.0 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1) 5.3 user_grand_total = [grand_totals[u] for u in answered_all] 24.1 r = corrcoef(qs_total, user_grand_total) 0.5 result.append({"qs": qs, "r": r}) 0.0 return result
Оптимизация 8. Преждевременное внедрение Numba
Мы ещё не закончили оптимизацию структур данных в приведённом выше коде, но давайте посмотрим, что нам даст внедрение на текущем этапе Numba? Речь о библиотеке в экосистеме Python, которая «переводит подмножество кода Python и NumPy в быстрый машинный код».
Чтобы иметь возможность использовать Numba, выполним два изменения:
Модификация 1. Передаем qs_combinations как массив numpy, вместо qs_iter
Numba не очень хорошо работает с itertools или генераторами, поэтому мы заранее превращаем qs_iter в массив NumPy, чтобы передать его функции. Влияние этого изменения на скорость выполнения (до добавления Numba) показано ниже.
Avg time per iteration: 42 μs Speedup over baseline: 829.2x
Модификация 2. Массив результатов вместо списка
Вместо добавления в список, мы инициализируем массив и помещаем в него результаты. Вот как это изменение повлияло на скорость.
Avg time per iteration: 42 μs Speedup over baseline: 833.8x
В итоге наш код выглядит так:
import numba @numba.njit(parallel=False) def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals): result = np.empty(len(qs_combinations), dtype=np.float64) for i in numba.prange(len(qs_combinations)): qs = qs_combinations[i] user_sets_for_qs = users_who_answered_q[qs, :] # numba doesn't support np.logical_and.reduce answered_all = user_sets_for_qs[0] for j in range(1, len(user_sets_for_qs)): answered_all *= user_sets_for_qs[j] answered_all = np.where(answered_all)[0] qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1) user_grand_total = grand_totals[answered_all] result[i] = corrcoef_numba(qs_total, user_grand_total) return result
(Обратите внимание, что мы также дополнили corrcoef с помощью Numba, потому что функции, вызываемые внутри функции Numba, тоже должны быть скомпилированы.)
Результаты с параметром parallel=False:
Avg time per iteration: 47 μs Speedup over baseline: 742.2x
Результаты с параметром parallel=True:
Avg time per iteration: 8.5 μs Speedup over baseline: 4142.0x
Видно, что при значении parallel=False код Numba работает немного медленнее, чем наш предыдущий код на Python, но когда мы включаем параллелизм, то начинаем использовать все ядра процессора (10 на нашей рабочей машине) — и это даёт хороший множитель скорости.
Однако мы теряем возможность использовать line_profiler на JIT-компилированном коде; (возможно, мы захотим обратиться к сгенерированному LLVM IR / сборке).
Оптимизация 9. Bitsets, без Numba
Пока отвлечёмся от Numba. В оригинальной статье для быстрого вычисления пользователей, ответивших на текущий qs, используются bitsets — проверим, применим ли такой подход в нашем случае. Для реализации bitsets мы можем использовать массивы NumPy np.int64 и np.bitwise_and.reduce. В отличие от использования массива np.bool_, теперь мы используем отдельные биты в байте для представления сущностей в множестве. Обратите внимание, что для данного bitset может понадобиться несколько байтов, в зависимости от максимального количества элементов, которые нам нужны. Мы можем использовать быстрый bitwise_and для байтов каждого вопроса в qs, чтобы найти пересечение множеств и, следовательно, количество пользователей, ответивших на все qs.
Вот функции bitset, которые мы будем использовать:
def bitset_create(size): """Initialise an empty bitset""" size_in_int64 = int(np.ceil(size / 64)) return np.zeros(size_in_int64, dtype=np.int64)
def bitset_add(arr, pos): """Add an element to a bitset""" int64_idx = pos // 64 pos_in_int64 = pos % 64 arr[int64_idx] |= np.int64(1) << np.int64(pos_in_int64)
def bitset_to_list(arr): """Convert a bitset back into a list of ints""" result = [] for idx in range(arr.shape[0]): if arr[idx] == 0: continue for pos in range(64): if (arr[idx] & (np.int64(1) << np.int64(pos))) != 0: result.append(idx * 64 + pos) return np.array(result)
И мы можем инициализировать bitsets следующим образом:
users_who_answered_q = np.array( [bitset_create(data.user.nunique()) for _ in range(data.question.nunique())] ) for q, u in data[["question", "user"]].values: bitset_add(users_who_answered_q[q], u)
Посмотрим, какое ускорение мы получим:
Avg time per iteration: 550 μs Speedup over baseline: 64.2x % Time Line Contents ===================== def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals): 0.0 num_qs = qs_combinations.shape[0] 0.0 bitset_size = users_who_answered_q[0].shape[0] 0.0 result = np.empty(qs_combinations.shape[0], dtype=np.float64) 0.0 for i in range(num_qs): 0.0 qs = qs_combinations[i] 0.3 user_sets_for_qs = users_who_answered_q[qs_combinations[i]] 0.4 answered_all = np.bitwise_and.reduce(user_sets_for_qs) 96.7 answered_all = bitset_to_list(answered_all) 0.6 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1) 0.0 user_grand_total = grand_totals[answered_all] 1.9 result[i] = corrcoef(qs_total, user_grand_total) 0.0 return result
Как видно, мы замедлились, поскольку операция bitset_to_list занимает слишком много времени.
Оптимизация 10. Numba на bitset_to_list
Преобразуем bitset_to_list в скомпилированный код. Для этого мы можем добавить декоратор Numba:
@numba.njit def bitset_to_list(arr): result = [] for idx in range(arr.shape[0]): if arr[idx] == 0: continue for pos in range(64): if (arr[idx] & (np.int64(1) << np.int64(pos))) != 0: result.append(idx * 64 + pos) return np.array(result)
Измерим скорость:
Benchmark #14: bitsets, with numba on bitset_to_list Using 1000 iterations... Avg time per iteration: 19 μs Speedup over baseline: 1801.2x % Time Line Contents ===================== def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals): 0.0 num_qs = qs_combinations.shape[0] 0.0 bitset_size = users_who_answered_q[0].shape[0] 0.0 result = np.empty(qs_combinations.shape[0], dtype=np.float64) 0.3 for i in range(num_qs): 0.6 qs = qs_combinations[i] 8.1 user_sets_for_qs = users_who_answered_q[qs_combinations[i]] 11.8 answered_all = np.bitwise_and.reduce(user_sets_for_qs) 7.7 answered_all = bitset_to_list(answered_all) 16.2 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1) 1.1 user_grand_total = grand_totals[answered_all] 54.1 result[i] = corrcoef(qs_total, user_grand_total) 0.0 return result
Мы получили ускорение в 1800 раз по сравнению с исходным кодом. Вспомните, что оптимизация 7, до введения Numba, дала 814x. (Оптимизация 8 дала 4142x, но это было с parallel=True во внутреннем цикле, так что показатель здесь нерелевантен.)
Оптимизация 11. Numba на corrcoef
Строчка с corrcoef снова выделяется как слишком медленная. Навесим на corrcoef декоратор Numba.
@numba.njit def corrcoef_numba(a, b): """same as np.corrcoef(a, b)[0, 1]""" n = len(a) sum_a = sum(a) sum_b = sum(b) sum_ab = sum(a * b) sum_a_sq = sum(a * a) sum_b_sq = sum(b * b) num = n * sum_ab - sum_a * sum_b den = math.sqrt(n * sum_a_sq - sum_a**2) * math.sqrt(n * sum_b_sq - sum_b**2) return np.nan if den == 0 else num / den
Смотрим результаты:
Avg time per iteration: 11 μs Speedup over baseline: 3218.9x % Time Line Contents ===================== def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals): 0.0 num_qs = qs_combinations.shape[0] 0.0 bitset_size = users_who_answered_q[0].shape[0] 0.0 result = np.empty(qs_combinations.shape[0], dtype=np.float64) 0.7 for i in range(num_qs): 1.5 qs = qs_combinations[i] 15.9 user_sets_for_qs = users_who_answered_q[qs_combinations[i]] 26.1 answered_all = np.bitwise_and.reduce(user_sets_for_qs) 16.1 answered_all = bitset_to_list(answered_all) 33.3 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1) 2.0 user_grand_total = grand_totals[answered_all] 4.5 result[i] = corrcoef_numba(qs_total, user_grand_total) 0.0 return result
Прекрасно, очередное значительное ускорение!
Оптимизация 12. Numba на bitset_and
Вместо использования np.bitwise_and.reduce мы вводим функцию bitwise_and и применяем к ней jit-компиляцию.
@numba.njit def bitset_and(arrays): result = arrays[0].copy() for i in range(1, len(arrays)): result &= arrays[i] return result
Benchmark #16: numba also on bitset_and Using 1000 iterations... Avg time per iteration: 8.9 μs Speedup over baseline: 3956.7x % Time Line Contents ===================== def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals): 0.1 num_qs = qs_combinations.shape[0] 0.0 bitset_size = users_who_answered_q[0].shape[0] 0.1 result = np.empty(qs_combinations.shape[0], dtype=np.float64) 1.0 for i in range(num_qs): 1.5 qs = qs_combinations[i] 18.4 user_sets_for_qs = users_who_answered_q[qs_combinations[i]] 16.1 answered_all = bitset_and(user_sets_for_qs) 17.9 answered_all = bitset_to_list(answered_all) 37.8 qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1) 2.4 user_grand_total = grand_totals[answered_all] 4.8 result[i] = corrcoef_numba(qs_total, user_grand_total) 0.0 return result
Оптимизация 13. Numba на всю функцию
Код стал значительно быстрее исходного, причём вычисления распределены довольно равномерно между несколькими строками цикла. Похоже, самая медленная строка выполняет индексацию NumPy, которая и так довольно быстрая. Давайте скомпилируем всю функцию с помощью Numba.
@numba.njit(parallel=False) def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals): result = np.empty(len(qs_combinations), dtype=np.float64) for i in numba.prange(len(qs_combinations)): qs = qs_combinations[i] user_sets_for_qs = users_who_answered_q[qs, :] answered_all = user_sets_for_qs[0] # numba doesn't support np.logical_and.reduce for j in range(1, len(user_sets_for_qs)): answered_all *= user_sets_for_qs[j] answered_all = np.where(answered_all)[0] qs_total = score_matrix[answered_all, :][:, qs].sum(axis=1) user_grand_total = grand_totals[answered_all] result[i] = corrcoef_numba(qs_total, user_grand_total) return result
Avg time per iteration: 4.2 μs Speedup over baseline: 8353.2x
А теперь с параметром parallel=True:
Avg time per iteration: 960 ns Speedup over baseline: 36721.4x
Отлично, наш код уже в 36 000 раз быстрее исходного.
Оптимизация 14. Numba, встраивание с накоплением вместо массивов
Куда двигаться дальше?... Ну, в коде все ещё достаточно много значений помещается в массивы, а затем передаётся по ним. Если мы посмотрим, как вычисляется corrcoef, то поймём, что нам не нужно создавать массивы answered_all и user_grand_total, мы можем накапливать значения по мере выполнения цикла.
Вот код (мы также включили некоторые оптимизации компилятора, например, отключили boundschecking для массивов и включили fastmath):
@numba.njit(boundscheck=False, fastmath=True, parallel=False, nogil=True) def compute_corrs(qs_combinations, users_who_answered_q, score_matrix, grand_totals): num_qs = qs_combinations.shape[0] bitset_size = users_who_answered_q[0].shape[0] corrs = np.empty(qs_combinations.shape[0], dtype=np.float64) for i in numba.prange(num_qs): # bitset will contain users who answered all questions in qs_array[i] bitset = users_who_answered_q[qs_combinations[i, 0]].copy() for q in qs_combinations[i, 1:]: bitset &= users_who_answered_q[q] # retrieve stats for the users to compute correlation n = 0.0 sum_a = 0.0 sum_b = 0.0 sum_ab = 0.0 sum_a_sq = 0.0 sum_b_sq = 0.0 for idx in range(bitset_size): if bitset[idx] != 0: for pos in range(64): if (bitset[idx] & (np.int64(1) << np.int64(pos))) != 0: user_idx = idx * 64 + pos score_for_qs = 0.0 for q in qs_combinations[i]: score_for_qs += score_matrix[user_idx, q] score_for_user = grand_totals[user_idx] n += 1.0 sum_a += score_for_qs sum_b += score_for_user sum_ab += score_for_qs * score_for_user sum_a_sq += score_for_qs * score_for_qs sum_b_sq += score_for_user * score_for_user num = n * sum_ab - sum_a * sum_b den = np.sqrt(n * sum_a_sq - sum_a**2) * np.sqrt(n * sum_b_sq - sum_b**2) corrs[i] = np.nan if den == 0 else num / den return corrs
Посмотрим со значением parallel=False.
Avg time per iteration: 1.7 μs Speedup over baseline: 20850.5x
Результат можно сравнить с оптимизацией 12 с parallel=False, которая показала 8353x.
Теперь с параметром parallel=True.
Avg time per iteration: 210 ns Speedup over baseline: 170476.3x
Мы достигли ускорения в 170 000 по сравнению с исходным кодом!
Вывод
Благодаря Numba и NumPy мы получили большинство тех инструментов, которые сделали оптимизированный код Rust быстрым: в частности, bitsets, SIMD и параллелизм на уровне циклов. Сперва мы значительно ускорили оригинальный код на Python с помощью нескольких вспомогательных функций с JIT-компиляцией, в итоге использовали JIT-компиляцию повсеместно и оптимизировали код для этого. Мы использовали подход проб и ошибок, применяя профилирование, чтобы сосредоточить усилия на самых медленных строках кода. Мы показали, что можем использовать Numba для постепенного добавления JIT-компилированного кода в нашу кодовую базу Python. Мы можем сразу же добавить этот код в существующую кодовую базу Python. Однако мы не достигли 180 000-кратного ускорения оптимизированного кода Rust, и развернули собственную реализацию корреляции и bitsets, в то время как код Rust смог использовать библиотеки для них, оставаясь при этом быстрым.
Это было забавное упражнение, которое, надеюсь, продемонстрировало вам некоторые полезные инструменты в экосистеме Python.
Стал бы я рекомендовать один подход вместо другого? Нет, все зависит от конкретной ситуации.
