Как стать автором
Обновить
588.3
OTUS
Цифровые навыки от ведущих экспертов

Кратко про язык программирования Triton

Время на прочтение5 мин
Количество просмотров4.7K

Привет, Хабр!

Triton был разработан специально для выполнения на GPU и предоставляет удобную Python-ориентированную среду.

Triton позволяет использовать модель программирования, основанную на блоках, которая значительно отличается от традиционной модели CUDA. Вместо управления потоками на уровне скалярных инструкций, Triton оперирует блоками данных, что в целом дает более лучшую производительность.

В отличие от стандартного подхода CUDA, где исполнение кода организуется через взаимодействие множества потоков, Triton структурирует выполнение на уровне программ. Т.е каждый блок программы может быть исполнен независимо, с возможностью обращения к глобальной памяти GPU и выполнения асинхронных операций без явного управления синхронизацией потоков.

Компилятор Triton применяет сложные стратегии оптимизации, к примеру такие как анализ потока данных и управление памятью на уровне блоков. Это включает в себя: автоматическая векторизация, предварительная выборка данных, и использование тензорных ядер, где это возможно. Такой подход позволяет максимально использовать возможности GPU.

Установим

Самый простой и доступный способ установить Triton — это через pip:

pip install triton

Triton можно также собрать из исходников. Для этого клонируем репозиторий и устанавливаем необходимые зависимости:

git clone https://github.com/openai/triton.git
cd triton/python
pip install ninja cmake wheel
pip install -e .

Учтите, что если на системе не установлен llvm, скрипт setup.py скачает официальные статические библиотеки LLVM и свяжет их.

Рассмотрим весь основной синтаксис

Triton использует декоратор @triton.jit для компиляции Python-функций в GPU ядра. Пример определения функции:

@triton.jit
def my_kernel(x_ptr, y_ptr, z_ptr, N):
    # тело функции

Можно использовать операции с тензорами, наподобие тех, что есть в NumPy. Пример создания тензора и выполнения операции:

x = triton.testing.randn((N,), dtype=torch.float32, device=device)
y = triton.testing.randn((N,), dtype=torch.float32, device=device)

Также есть поддержка операций индексации и срезов, аналогичные Python:

x = x_ptr + tl.arange(0, N, dtype=tl.int64)

Есть различные функции для управления памятью, включая загрузку tl.load и сохранение tl.store данных:

x = tl.load(x_ptr + idx)
tl.store(z_ptr + idx, x + 5)

Можно контролировать распределение памяти и выполнение потоков на GPU таким образом:

BLOCK_SIZE = 128
grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)

Есть также множество встроенных математических функций:

  1. abs - вычисляет поэлементное абсолютное значение x.

  2. cdiv - вычисляет потолок от деления x на div.

  3. clamp - ограничивает тензор x в пределах указанного диапазона [min, max].

  4. cos - вычисляет поэлементный косинус x.

  5. div_rn - вычисляет поэлементное точное деление x на y с округлением к ближайшему целому.

  6. erf - вычисляет поэлементную функцию ошибок x.

  7. exp - вычисляет поэлементную экспоненту x.

  8. exp2 - вычисляет поэлементную экспоненту x по основанию 2.

  9. fma - вычисляет поэлементно слияние умножения и сложения для x, y и z (x * y + z).

  10. fdiv - вычисляет поэлементное быстрое деление x на y.

  11. floor - вычисляет поэлементное округление x вниз.

  12. log - вычисляет поэлементный натуральный логарифм x.

  13. log2 - вычисляет поэлементный логарифм x по основанию 2.

  14. maximum - вычисляет поэлементный максимум из x и y.

  15. minimum - вычисляет поэлементный минимум из x и y.

  16. sigmoid - вычисляет поэлементную сигмоидную функцию x.

  17. sin - вычисляет поэлементный синус x.

  18. softmax - вычисляет поэлементную softmax функцию x.

  19. sqrt - вычисляет поэлементный быстрый квадратный корень x.

  20. sqrt_rn - вычисляет поэлементный точный квадратный корень x с округлением к ближайшему.

  21. umulhi - вычисляет поэлементно старшие N бит из 2N-битного произведения x и y.

Юзать их достаточно просто, к примеру:

z = tl.maximum(x, y)

После определения функции её можно скомпилировать и выполнить, передав соответствующие параметры:

my_kernel[grid](x_ptr, y_ptr, z_ptr, N)

Есть поддержка векторизации:

x = tl.load(x_ptr + tl.arange(0, BLOCK_SIZE))

Есть стандартные условные операторы, такие как if, else :

if idx < N:
    x = x_ptr[idx]
    y = y_ptr[idx]
    z_ptr[idx] = x + y

Когда нужно, чтобы несколько потоков должны безопасно обновлять одни и те же данные, можно юзать атомарные операции, такие как tl.atomic_add:

tl.atomic_add(z_ptr[idx], x + y)

Здесь мы выполнили атомарное сложение по указанному адресу памяти.

Прочие атомарные операции:

  1. atomic_cas

    • Выполняет атомарную операцию сравнения и замены по указанному адресу памяти.

  2. atomic_max

    • Выполняет атомарное нахождение максимума по указанному адресу памяти.

  3. atomic_min

    • Выполняет атомарное нахождение минимума по указанному адресу памяти.

  4. atomic_or

    • Выполняет атомарную логическую операцию ИЛИ по указанному адресу памяти.

  5. atomic_xchg

    • Выполняет атомарный обмен значениями по указанному адресу памяти.

  6. atomic_xor

    • Выполняет атомарную логическую операцию исключающее ИЛИ по указанному адресу памяти.

Также есть операции сканирования и сортировки:

  1. associative_scan

    • Применяет функцию combine_fn к каждому элементу с сохранением промежуточного значения carry в тензорах вдоль указанной оси и обновляет carry.

  2. cumprod

    • Возвращает кумулятивное произведение всех элементов в тензоре вдоль указанной оси.

  3. cumsum

    • Возвращает кумулятивную сумму всех элементов в тензоре вдоль указанной оси.

  4. histogram

    • Вычисляет гистограмму на основе входного тензора с заданным числом корзин num_bins, корзины имеют ширину 1 и начинаются с 0.

  5. sort

    • Сортирует элементы тензора. Может быть вызвана как метод тензора x.sort(...).

И стандартные операции редукции:

  1. argmax

    • Возвращает индекс максимального элемента в тензоре вдоль указанной оси.

  2. argmin

    • Возвращает индекс минимального элемента в тензоре вдоль указанной оси.

  3. max

    • Возвращает максимальное значение среди всех элементов тензора вдоль указанной оси.

  4. min

    • Возвращает минимальное значение среди всех элементов тензора вдоль указанной оси.

  5. reduce

    • Применяет функцию комбинирования combine_fn ко всем элементам входных тензоров вдоль указанной оси.

Итак, как все это применять?

Можно юзать atomic_add для эффективной реализации градиентного спуска в среде с высокой степенью параллелизма. Например, при обучении больших нейронных сетей на GPU, атомарные операции могут обеспечить обновление весов без конфликтов между потоками:

@triton.jit
def update_weights(grads, weights, learning_rate):
    pid = tl.program_id(0)
    for i in range(pid, grads.shape[0], tl.num_programs()):
        tl.atomic_add(weights, i, -learning_rate * grads[i])

Можно также юзать sort и reduce для реализации алгоритмов кластеризации, таких как K-means, где необходимо сортировать данные и вычислять центроиды:

@triton.jit
def k_means_update(x, centroids, labels, num_clusters):
    # вычисление расстояний и присвоение меток
    block_start = tl.program_id(0) * tl.num_programs()
    distances = tl.zeros((num_clusters,), dtype=tl.float32)
    for i in range(block_start, min(block_start + BLOCK_SIZE, x.shape[0])):
        for c in range(num_clusters):
            distances[c] = tl.sum(tl.pow(x[i] - centroids[c], 2))
        labels[i] = tl.argmin(distances)
    # обновление центроидов
    for c in range(num_clusters):
        assigned_pts = x[labels == c]
        centroids[c] = tl.sum(assigned_pts, axis=0) / len(assigned_pts)

А вот exp, max, и log можно юзать например, для фунций softmax:

@triton.jit
def softmax(x):
    max_val = tl.max(x, axis=1, keepdims=True)
    exp_x = tl.exp(x - max_val)
    sum_exp_x = tl.sum(exp_x, axis=1, keepdims=True)
    return exp_x / sum_exp_x

Можно использовать cumsum для расчёта скользящего среднего (что применимо во временных рядах):

@triton.jit
def moving_average(data, window_size):
    cum_sum = tl.cumsum(data, axis=0)
    return (cum_sum[window_size:] - cum_sum[:-window_size]) / window_size

associative_scan подходит для реализации алгоритма прямого распространения в RNN, где последовательные зависимости могут быть эффективно обработаны с помощью этой операции:


@triton.jit
def rnn_step(hidden, input, weights):
    carry = tl.zeros_like(hidden)
    for t in range(input.shape[0]):
        carry = tl.associative_scan(lambda h, x: tl.tanh(x @ weights + h), input[t], axis=0)
        hidden[t] = carry
    return hidden

Также можно применить associative_scan для реализации параллельного алгоритма быстрого преобразования Фурье:

@triton.jit
def fft_step(data, step_size):
    idx = tl.program_id(0) * step_size * 2 + tl.arange(0, step_size)
    t = exp(-2j * pi / (2 * step_size) * idx)
    u = data[idx]
    v = data[idx + step_size] * t
    data[idx] = u + v
    data[idx + step_size] = u - v

Более подробно с Triton можно ознакомиться здесь.

Материал подготовлен в преддверии старта специализации "Machine Learning".

Теги:
Хабы:
Всего голосов 6: ↑6 и ↓0+7
Комментарии1

Публикации

Информация

Сайт
otus.ru
Дата регистрации
Дата основания
Численность
101–200 человек
Местоположение
Россия
Представитель
OTUS