Про умножение матриц или как курс по вычислительной линейной алгебре проигрывает жестокой реальности
Мы умеем умножать матрицы быстрее, чем за O(N^3)! По крайней мере, так рассказывают на курсе по алгоритмам. Потом теория сталкивается с "железом", и выясняется, что в DL этим почти никто не пользуется. Но почему?
Для начала вспомним базовые факты про умножение матриц:
У нас есть матрицы A (B x D) и B (D x K);
При их умножении нам нужно сделать одно сложение и одно умножение для каждого элемента в паре "строка–столбец";
Получается B x D x K таких троек для каждой операции;
Итого 2 B x D x K троек;
Для квадратных матриц это упрощается до 2 * n^3, то есть O(n^3).
Умный дядька Штрассен когда-то предложил алгоритм, который уменьшает число умножений за счёт рекурсивного разбиения матриц. В сухом остатке теоретическая сложность падает примерно до O(N^2.7).
Сегодня я смотрел лекции "LLM from Scratch" и заметил, что они считают FLOPs что называется "в лоб" - будто в PyTorch используется наивное умножение матриц (скрин из лекции ниже). Сначала подумал, что это просто упрощение, чтобы не уходить в численные методы линейной алгебры, но решил копнуть глубже.

Выяснилось, что в DL практически никто не использует алгоритм Штрассена (и его современные, ещё более эффективные аналоги)!
Во-первых, он менее численно устойчив из-за сложений и вычитаний промежуточных подматриц.
Во-вторых, он плохо стыкуется со специализированными тензорными ядрами, которые выполняют операции Matrix Multiply-Accumulate (MMA, D = A * B + C) на маленьких матрицах фиксированного размера.
В-третьих, из-за рекурсивной структуры он сильно менее эффективен с точки зрения работы с памятью и кэшем.
Реальность vs теория — 1:0
