Обновить

Про умножение матриц или как курс по вычислительной линейной алгебре проигрывает жестокой реальности

Мы умеем умножать матрицы быстрее, чем за O(N^3)! По крайней мере, так рассказывают на курсе по алгоритмам. Потом теория сталкивается с "железом", и выясняется, что в DL этим почти никто не пользуется. Но почему?

Для начала вспомним базовые факты про умножение матриц:

  • У нас есть матрицы A (B x D) и B (D x K);

  • При их умножении нам нужно сделать одно сложение и одно умножение для каждого элемента в паре "строка–столбец";

  • Получается B x D x K таких троек для каждой операции;

  • Итого 2 B x D x K троек;

Для квадратных матриц это упрощается до 2 * n^3, то есть O(n^3).

Умный дядька Штрассен когда-то предложил алгоритм, который уменьшает число умножений за счёт рекурсивного разбиения матриц. В сухом остатке теоретическая сложность падает примерно до O(N^2.7).

Сегодня я смотрел лекции "LLM from Scratch" и заметил, что они считают FLOPs что называется "в лоб" - будто в PyTorch используется наивное умножение матриц (скрин из лекции ниже). Сначала подумал, что это просто упрощение, чтобы не уходить в численные методы линейной алгебры, но решил копнуть глубже.

Выяснилось, что в DL практически никто не использует алгоритм Штрассена (и его современные, ещё более эффективные аналоги)!

Во-первых, он менее численно устойчив из-за сложений и вычитаний промежуточных подматриц.

Во-вторых, он плохо стыкуется со специализированными тензорными ядрами, которые выполняют операции Matrix Multiply-Accumulate (MMA, D = A * B + C) на маленьких матрицах фиксированного размера.

В-третьих, из-за рекурсивной структуры он сильно менее эффективен с точки зрения работы с памятью и кэшем.

Реальность vs теория — 1:0

Теги:
0
Комментарии2

Публикации

Ближайшие события