
Привет, Хабр!
Triton был разработан специально для выполнения на GPU и предоставляет удобную Python-ориентированную среду.
Triton позволяет использовать модель программирования, основанную на блоках, которая значительно отличается от традиционной модели CUDA. Вместо управления потоками на уровне скалярных инструкций, Triton оперирует блоками данных, что в целом дает более лучшую производительность.
В отличие от стандартного подхода CUDA, где исполнение кода организуется через взаимодействие множества потоков, Triton структурирует выполнение на уровне программ. Т.е каждый блок программы может быть исполнен независимо, с возможностью обращения к глобальной памяти GPU и выполнения асинхронных операций без явного управления синхронизацией потоков.
Компилятор Triton применяет сложные стратегии оптимизации, к примеру такие как анализ потока данных и управление памятью на уровне блоков. Это включает в себя: автоматическая векторизация, предварительная выборка данных, и использование тензорных ядер, где это возможно. Такой подход позволяет максимально использовать возможности GPU.
Установим
Самый простой и доступный способ установить Triton — это через pip:
pip install triton
Triton можно также собрать из исходников. Для этого клонируем репозиторий и устанавливаем необходимые зависимости:
git clone https://github.com/openai/triton.git
cd triton/python
pip install ninja cmake wheel
pip install -e .
Учтите, что если на системе не установлен llvm, скрипт setup.py скачает официальные статические библиотеки LLVM и свяжет их.
Рассмотрим весь основной синтаксис
Triton использует декоратор @triton.jit
для компиляции Python-функций в GPU ядра. Пример определения функции:
@triton.jit
def my_kernel(x_ptr, y_ptr, z_ptr, N):
# тело функции
Можно использовать операции с тензорами, наподобие тех, что есть в NumPy. Пример создания тензора и выполнения операции:
x = triton.testing.randn((N,), dtype=torch.float32, device=device)
y = triton.testing.randn((N,), dtype=torch.float32, device=device)
Также есть поддержка операций индексации и срезов, аналогичные Python:
x = x_ptr + tl.arange(0, N, dtype=tl.int64)
Есть различные функции для управления памятью, включая загрузку tl.load
и сохранение tl.store
данных:
x = tl.load(x_ptr + idx)
tl.store(z_ptr + idx, x + 5)
Можно контролировать распределение памяти и выполнение потоков на GPU таким образом:
BLOCK_SIZE = 128
grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)
Есть также множество встроенных математических функций:
abs - вычисляет поэлементное абсолютное значение
x
.cdiv - вычисляет потолок от деления
x
наdiv
.clamp - ограничивает тензор
x
в пределах указанного диапазона[min, max]
.cos - вычисляет поэлементный косинус
x
.div_rn - вычисляет поэлементное точное деление
x
наy
с округлением к ближайшему целому.erf - вычисляет поэлементную функцию ошибок
x
.exp - вычисляет поэлементную экспоненту
x
.exp2 - вычисляет поэлементную экспоненту
x
по основанию 2.fma - вычисляет поэлементно слияние умножения и сложения для
x
,y
иz
(x * y + z).fdiv - вычисляет поэлементное быстрое деление
x
наy
.floor - вычисляет поэлементное округление
x
вниз.log - вычисляет поэлементный натуральный логарифм
x
.log2 - вычисляет поэлементный логарифм
x
по основанию 2.maximum - вычисляет поэлементный максимум из
x
иy
.minimum - вычисляет поэлементный минимум из
x
иy
.sigmoid - вычисляет поэлементную сигмоидную функцию
x
.sin - вычисляет поэлементный синус
x
.softmax - вычисляет поэлементную softmax функцию
x
.sqrt - вычисляет поэлементный быстрый квадратный корень
x
.sqrt_rn - вычисляет поэлементный точный квадратный корень
x
с округлением к ближайшему.umulhi - вычисляет поэлементно старшие N бит из 2N-битного произведения
x
иy
.
Юзать их достаточно просто, к примеру:
z = tl.maximum(x, y)
После определения функции её можно скомпилировать и выполнить, передав соответствующие параметры:
my_kernel[grid](x_ptr, y_ptr, z_ptr, N)
Есть поддержка векторизации:
x = tl.load(x_ptr + tl.arange(0, BLOCK_SIZE))
Есть стандартные условные операторы, такие как if
, else
:
if idx < N:
x = x_ptr[idx]
y = y_ptr[idx]
z_ptr[idx] = x + y
Когда нужно, чтобы несколько потоков должны безопасно обновлять одни и те же данные, можно юзать атомарные операции, такие как tl.atomic_add
:
tl.atomic_add(z_ptr[idx], x + y)
Здесь мы выполнили атомарное сложение по указанному адресу памяти.
Прочие атомарные операции:
atomic_cas
Выполняет атомарную операцию сравнения и замены по указанному адресу памяти.
atomic_max
Выполняет атомарное нахождение максимума по указанному адресу памяти.
atomic_min
Выполняет атомарное нахождение минимума по указанному адресу памяти.
atomic_or
Выполняет атомарную логическую операцию ИЛИ по указанному адресу памяти.
atomic_xchg
Выполняет атомарный обмен значениями по указанному адресу памяти.
atomic_xor
Выполняет атомарную логическую операцию исключающее ИЛИ по указанному адресу памяти.
Также есть операции сканирования и сортировки:
associative_scan
Применяет функцию
combine_fn
к каждому элементу с сохранением промежуточного значенияcarry
в тензорах вдоль указанной оси и обновляетcarry
.
cumprod
Возвращает кумулятивное произведение всех элементов в тензоре вдоль указанной оси.
cumsum
Возвращает кумулятивную сумму всех элементов в тензоре вдоль указанной оси.
histogram
Вычисляет гистограмму на основе входного тензора с заданным числом корзин
num_bins
, корзины имеют ширину 1 и начинаются с 0.
sort
Сортирует элементы тензора. Может быть вызвана как метод тензора
x.sort(...)
.
И стандартные операции редукции:
argmax
Возвращает индекс максимального элемента в тензоре вдоль указанной оси.
argmin
Возвращает индекс минимального элемента в тензоре вдоль указанной оси.
max
Возвращает максимальное значение среди всех элементов тензора вдоль указанной оси.
min
Возвращает минимальное значение среди всех элементов тензора вдоль указанной оси.
reduce
Применяет функцию комбинирования
combine_fn
ко всем элементам входных тензоров вдоль указанной оси.
Итак, как все это применять?
Можно юзать atomic_add
для эффективной реализации градиентного спуска в среде с высокой степенью параллелизма. Например, при обучении больших нейронных сетей на GPU, атомарные операции могут обеспечить обновление весов без конфликтов между потоками:
@triton.jit
def update_weights(grads, weights, learning_rate):
pid = tl.program_id(0)
for i in range(pid, grads.shape[0], tl.num_programs()):
tl.atomic_add(weights, i, -learning_rate * grads[i])
Можно также юзать sort
и reduce
для реализации алгоритмов кластеризации, таких как K-means, где необходимо сортировать данные и вычислять центроиды:
@triton.jit
def k_means_update(x, centroids, labels, num_clusters):
# вычисление расстояний и присвоение меток
block_start = tl.program_id(0) * tl.num_programs()
distances = tl.zeros((num_clusters,), dtype=tl.float32)
for i in range(block_start, min(block_start + BLOCK_SIZE, x.shape[0])):
for c in range(num_clusters):
distances[c] = tl.sum(tl.pow(x[i] - centroids[c], 2))
labels[i] = tl.argmin(distances)
# обновление центроидов
for c in range(num_clusters):
assigned_pts = x[labels == c]
centroids[c] = tl.sum(assigned_pts, axis=0) / len(assigned_pts)
А вот exp
, max
, и log
можно юзать например, для фунций softmax:
@triton.jit
def softmax(x):
max_val = tl.max(x, axis=1, keepdims=True)
exp_x = tl.exp(x - max_val)
sum_exp_x = tl.sum(exp_x, axis=1, keepdims=True)
return exp_x / sum_exp_x
Можно использовать cumsum
для расчёта скользящего среднего (что применимо во временных рядах):
@triton.jit
def moving_average(data, window_size):
cum_sum = tl.cumsum(data, axis=0)
return (cum_sum[window_size:] - cum_sum[:-window_size]) / window_size
associative_scan
подходит для реализации алгоритма прямого распространения в RNN, где последовательные зависимости могут быть эффективно обработаны с помощью этой операции:
@triton.jit
def rnn_step(hidden, input, weights):
carry = tl.zeros_like(hidden)
for t in range(input.shape[0]):
carry = tl.associative_scan(lambda h, x: tl.tanh(x @ weights + h), input[t], axis=0)
hidden[t] = carry
return hidden
Также можно применить associative_scan
для реализации параллельного алгоритма быстрого преобразования Фурье:
@triton.jit
def fft_step(data, step_size):
idx = tl.program_id(0) * step_size * 2 + tl.arange(0, step_size)
t = exp(-2j * pi / (2 * step_size) * idx)
u = data[idx]
v = data[idx + step_size] * t
data[idx] = u + v
data[idx + step_size] = u - v
Более подробно с Triton можно ознакомиться здесь.
Материал подготовлен в преддверии старта специализации "Machine Learning".