
How linear algebra is applied in machine learning


When you study an abstract subject like linear algebra, you may wonder: why do you need all these vectors and matrices? How are you going to apply all these inversions, transpositions, eigenvectors and eigenvalues for practical purposes?


Well, if you study linear algebra in order to do machine learning, here is your answer.


In brief, you can use linear algebra for machine learning on 3 different levels:


  • application of a model to data;
  • training the model;
  • understanding how it works or why it does not work.


I assume that you, the reader, have at least a vague idea of linear algebra concepts (such as vectors, matrices, their products, inverse matrices, eigenvectors and eigenvalues) and of machine learning problems (such as regression, classification and dimensionality reduction). If not, now may be a good time to read about them on Wikipedia, or even to sign up for a MOOC on these subjects.


Application


What machine learning usually does is fit some function $f_W(X)=H$, where $X$ is the input data, $H$ is some useful representation of this data, and $W$ are additional parameters on which our function depends and which have to be learned. Once we have this representation $H$, we can use it, e.g., to reconstruct the original data $X$ (as in unsupervised learning) or to predict some value of interest $Y$ (as in supervised learning).


All of $X$, $H$, $W$ and $Y$ are usually numeric arrays, so they can at least be stored as vectors and matrices. But storage alone is not the point. The important thing is that our function $f$ is often linear, that is, $H=XW$. Examples of such linear algorithms are:


  • linear regression, where $Y=H$. It is a sensible baseline for regression problems, and a popular tool for answering questions like "does $x$ affect $y$, other things being equal?"
  • logistic regression, where $Y=\mathrm{softmax}(H)$. It is a good baseline for classification problems, and sometimes this baseline is difficult to beat.
  • principal component analysis, where $H$ is just a low-dimensional representation of high-dimensional $X$, from which $X$ can be restored with high precision. You can think of it as a compression algorithm.
  • other PCA-like algorithms (matrix decompositions), which are widely used in recommender systems to turn a very sparse matrix of "which products were purchased by which users" into compact, dense representations of users and products that can then be used to predict new transactions.
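To make $H=XW$ concrete, here is a minimal NumPy sketch. It assumes that the rows of `X` are samples and the columns are features; the data and weights are random placeholders rather than a trained model, and the `softmax` helper is written by hand for illustration. Linear and logistic regression share the same matrix product and differ only in what happens to $H$ afterwards.

```python
import numpy as np

def softmax(h):
    """Row-wise softmax; subtracting each row's max keeps np.exp from overflowing."""
    z = h - h.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))    # 5 samples (rows) with 3 features (columns)
W = rng.normal(size=(3, 2))    # placeholder weights; training would learn these

H = X @ W                      # the shared linear step H = XW, shape (5, 2)

Y_linear = H                   # linear regression: the prediction is H itself
Y_logistic = softmax(H)        # logistic regression: each row becomes class probabilities

print(Y_logistic.sum(axis=1))  # every row sums to 1
```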

Other algorithms, like neural networks, learn nonlinear transformations but still rely heavily on linear operations (that is, matrix-matrix or matrix-vector multiplication). A simple neural network may look like $Y=\sigma(W_2\sigma(W_1X))$: it uses two matrix multiplications and applies a nonlinear transformation $\sigma$ after each of them.
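A minimal sketch of this two-layer network in NumPy follows the formula's convention that `W1 @ X` comes first, so each column of `X` is one sample. The layer sizes and random weights are arbitrary, and the logistic sigmoid is just one possible choice of $\sigma$.

```python
import numpy as np

def sigmoid(z):
    """Elementwise logistic function, used here as the nonlinearity sigma."""
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
n_in, n_hidden, n_out, n_samples = 4, 8, 2, 10

W1 = rng.normal(size=(n_hidden, n_in))   # first-layer weights (random, untrained)
W2 = rng.normal(size=(n_out, n_hidden))  # second-layer weights (random, untrained)
X = rng.normal(size=(n_in, n_samples))   # one sample per column, as in W1 @ X

hidden = sigmoid(W1 @ X)   # first matrix product, then the nonlinearity
Y = sigmoid(W2 @ hidden)   # second matrix product, then the nonlinearity
print(Y.shape)             # (2, 10)
```

Without $\sigma$, the two products would collapse into a single matrix, because $W_2(W_1X)=(W_2W_1)X$. The nonlinearity is what makes the network more than a linear model.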


Training


To train an algorithm, you usually define a loss function and try to optimize it. The loss itself is sometimes convenient to write in terms of linear algebra. For example, the quadratic loss (used in the least squares method) can be written as a dot product $(Y-\hat{Y})^T(Y-\hat{Y})$, where $\hat{Y}$ is the vector of your predictions and $Y$ is the ground truth you are trying to predict. This representation is useful because it lets us derive ways to minimize the loss. For example, if you fit linear regression by least squares, the optimal solution (provided $X^TX$ is invertible) is $W=(X^TX)^{-1}X^TY$. Lots of linear operations in one place!
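Here is a small sketch of both formulas on synthetic data; the coefficients in `W_true` are made up for the demo. It calls `np.linalg.solve` on $X^TX$ instead of forming $(X^TX)^{-1}$ explicitly, a common choice because it is cheaper and numerically safer, and cross-checks the answer against NumPy's own least squares solver, `np.linalg.lstsq`.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 100, 3
X = rng.normal(size=(n, d))
W_true = np.array([2.0, -1.0, 0.5])          # made-up coefficients for the demo
Y = X @ W_true + 0.1 * rng.normal(size=n)    # noisy synthetic targets

def quadratic_loss(W):
    """(Y - Y_hat)^T (Y - Y_hat), written as a dot product."""
    residual = Y - X @ W
    return residual @ residual

# Closed-form least squares, W = (X^T X)^{-1} X^T Y, computed by solving a linear system.
W_normal = np.linalg.solve(X.T @ X, X.T @ Y)

# NumPy's least squares solver should give the same answer.
W_lstsq, *_ = np.linalg.lstsq(X, Y, rcond=None)

print(np.allclose(W_normal, W_lstsq))  # True
print(W_normal)                        # close to W_true
```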


Another example of a linear solution is PCA, where the parameters of interest $W$ are the first $k$ eigenvectors of the matrix $X^TX$ (with the columns of $X$ centered), corresponding to the $k$ largest eigenvalues.
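A minimal sketch of PCA done exactly this way, via `np.linalg.eigh` on $X^TX$, assuming synthetic data that lies close to a 2-dimensional plane inside 5-dimensional space. Note that `eigh` returns eigenvalues in ascending order, so they have to be re-sorted before taking the top $k$.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic data that mostly varies along 2 directions in 5-D space.
latent = rng.normal(size=(200, 2))
mixing = rng.normal(size=(2, 5))
X = latent @ mixing + 0.01 * rng.normal(size=(200, 5))
X = X - X.mean(axis=0)                      # center the columns first

k = 2
eigvals, eigvecs = np.linalg.eigh(X.T @ X)  # symmetric matrix: eigh, ascending order
order = np.argsort(eigvals)[::-1]           # indices of eigenvalues, largest first
W = eigvecs[:, order[:k]]                   # top-k eigenvectors, shape (5, k)

H = X @ W                                   # compressed representation, shape (200, k)
X_restored = H @ W.T                        # reconstruction from the compressed form
rel_error = np.linalg.norm(X - X_restored) / np.linalg.norm(X)
print(rel_error)                            # small, since the data is nearly 2-D
```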


If you train neural networks, there is usually no analytical solution for the optimal parameters, and you have to use gradient descent. To do this, you need to differentiate the loss w.r.t. the parameters, and while doing so, you again have to multiply matrices: if $loss=f(g(h(w)))$ (a composite function), then $\frac{\partial loss}{\partial w} = f' \times g' \times h'$, and all these derivatives are matrices or vectors, because $g$ and $h$ are multidimensional.
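To see the chain rule turn into a matrix product, here is a sketch with a hypothetical composite loss: $h(w)=Aw$ (with `A` a random matrix), $g$ is an elementwise $\tanh$, and $f(v)=v^Tv$. The derivatives $f'$, $g'$ and $h'$ are multiplied exactly as in the formula, and the result is checked against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))   # hypothetical h(w) = A @ w, mapping R^3 -> R^4
w = rng.normal(size=3)

def loss(w):
    u = A @ w         # h: linear map
    v = np.tanh(u)    # g: elementwise nonlinearity
    return v @ v      # f: sum of squares

def grad(w):
    v = np.tanh(A @ w)
    df_dv = 2 * v                  # f': gradient of v @ v, shape (4,)
    dg_du = np.diag(1 - v ** 2)    # g': Jacobian of tanh, diagonal (4, 4)
    dh_dw = A                      # h': Jacobian of A @ w, shape (4, 3)
    return df_dv @ dg_du @ dh_dw   # chain rule as a product of matrices, shape (3,)

# Central finite differences along each coordinate direction.
eps = 1e-6
numeric = np.array([(loss(w + eps * e) - loss(w - eps * e)) / (2 * eps)
                    for e in np.eye(3)])
print(np.allclose(grad(w), numeric, atol=1e-6))  # True
```

Building the diagonal Jacobian $g'$ as a full matrix is done here only to mirror the formula; multiplying elementwise by `1 - v ** 2` gives the same result more cheaply.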


Simple gradient descent works, but it can be slow. You can speed it up by applying Newton's (second-order) optimization methods. The basic method is $W_{t+1}=W_t - A^{-1}B$, where $B$ is the vector of first derivatives (the gradient) and $A$ is the matrix of second derivatives (the Hessian) of your loss w.r.t. the parameters $W$. But this method can be unstable and/or computationally expensive, so you may need approximations of it (quasi-Newton methods like L-BFGS) that use even more involved linear algebra for quick and cheap optimization.
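As an illustration, here is the basic Newton step applied to binary logistic regression (our choice of example, not something the text prescribes), on synthetic labels drawn from made-up coefficients `w_true`. For this loss (the negative log-likelihood), $B = X^T(p - y)$ and $A = X^T \mathrm{diag}(p(1-p))\,X$, where $p$ is the vector of predicted probabilities; the sketch solves $A\Delta = B$ rather than inverting $A$.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
n, d = 200, 3
X = rng.normal(size=(n, d))
w_true = np.array([1.0, -2.0, 0.5])                      # made-up coefficients
y = (rng.random(n) < sigmoid(X @ w_true)).astype(float)  # noisy 0/1 labels

w = np.zeros(d)
for _ in range(15):
    p = sigmoid(X @ w)
    B = X.T @ (p - y)                         # gradient of the negative log-likelihood
    A = X.T @ (X * (p * (1 - p))[:, None])    # Hessian: X^T diag(p(1-p)) X
    w = w - np.linalg.solve(A, B)             # Newton step W - A^{-1} B

grad_norm = np.linalg.norm(X.T @ (sigmoid(X @ w) - y))
print(grad_norm)   # close to zero after a handful of steps
```

On perfectly separable classes the optimum lies at infinity and these plain Newton steps keep growing `w`, which is one concrete form of the instability mentioned above.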


Analysis


You have seen that linear algebra helps you apply and train your models. But the real science (or magic) starts when your model refuses to train or predict well. The learning may get stuck at a bad point, or suddenly go wild. In deep learning, this often happens due to vanishing or exploding gradients: when you calculate the gradient, you multiply lots of matrices, strange things happen, and you need to know what is going on, why, and how to overcome it. One way to inspect what is happening is to keep track of the eigenvalues of the matrices you are trying to invert. If some of them are close to 0, or if they differ by many orders of magnitude, then inverting the matrix can give numerically unstable results. If you multiply many matrices with large eigenvalues, the product explodes; when these eigenvalues are small, the result fades to zero.
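Both effects are easy to reproduce in a toy setting. The sketch below uses arbitrary, hypothetical matrices. It first multiplies a vector 50 times by `scale * Q`, where `Q` is a random orthogonal matrix, so every eigenvalue has absolute value exactly `scale`. Then it solves a linear system whose matrix is nearly singular, with its smallest eigenvalue close to 0.

```python
import numpy as np

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(4, 4)))   # random orthogonal matrix

def product_norm(scale, depth=50):
    """Norm of (scale*Q)^depth @ v; every eigenvalue of scale*Q has modulus `scale`."""
    M = scale * Q
    v = np.ones(4)                # starting vector with norm 2
    for _ in range(depth):
        v = M @ v
    return np.linalg.norm(v)

print(product_norm(1.1))   # 2 * 1.1**50, about 235: the product explodes
print(product_norm(0.9))   # 2 * 0.9**50, about 0.01: the product fades away

# A nearly singular matrix: its smaller eigenvalue is about 5e-11.
A = np.array([[1.0, 1.0],
              [1.0, 1.0 + 1e-10]])
print(np.linalg.eigvalsh(A))
print(np.linalg.cond(A))   # condition number around 4e10

x = np.linalg.solve(A, np.array([2.0, 2.0]))
x_perturbed = np.linalg.solve(A, np.array([2.0, 2.0 + 1e-8]))
print(x, x_perturbed)      # a 1e-8 change in the input moves each coordinate by roughly 100
```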


Different techniques, like L1/L2 regularization, batch normalization, and LSTMs, were invented to fight these convergence problems. If you want to apply any of those techniques, you need a way to measure whether they actually help with your particular problem. And if you want to invent such a technique yourself, you need a way to prove that it can work at all. This again involves lots of manipulation with vectors, matrices, their decompositions, etc.
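For example, one way to measure what L2 regularization does to a least squares problem is to look at the eigenvalues of $X^TX$. Adding the penalty $\lambda\|W\|^2$ to the quadratic loss turns the normal equations into $(X^TX+\lambda I)W = X^TY$, which shifts every eigenvalue up by $\lambda$ and keeps the smallest one away from 0. The sketch below assumes a deliberately ill-conditioned toy dataset with two almost identical features and an arbitrary $\lambda=1$.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50
x1 = rng.normal(size=n)
x2 = x1 + 1e-6 * rng.normal(size=n)    # an almost duplicated feature
X = np.column_stack([x1, x2])
Y = x1 + 0.1 * rng.normal(size=n)

lam = 1.0                              # L2 (ridge) strength, chosen arbitrarily
G = X.T @ X
G_ridge = G + lam * np.eye(2)          # every eigenvalue of G shifted up by lam

print(np.linalg.eigvalsh(G), np.linalg.eigvalsh(G_ridge))
print(np.linalg.cond(G), np.linalg.cond(G_ridge))  # huge vs. modest

W_plain = np.linalg.solve(G, X.T @ Y)        # may contain huge, offsetting coefficients
W_ridge = np.linalg.solve(G_ridge, X.T @ Y)  # ridge solution (X^T X + lam I)^{-1} X^T Y
print(W_plain, W_ridge)
```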


Conclusion


You can see that the deeper you dive into machine learning, the more linear algebra you see there. To apply pre-trained models, you have to at least convert your data into a format compatible with linear algebra (e.g., numpy.array in Python). If you need to implement a training algorithm, or even invent a new one, be prepared to multiply, invert, and decompose lots of matrices.


In this text, I have referenced some concepts you may be unfamiliar with. That's okay. This article encourages you to look up the unfamiliar terms and broaden your horizons.


By the way, it would be interesting to hear your stories in the comments about how you have encountered applications of linear algebra in your own work or study.


P.S. In one of my articles, I argued that you don't have to learn maths in order to become successful (it's still a popular stereotype in Russia), even if you work in IT. However, I never said that maths is useless (otherwise, I wouldn't be teaching it all the time). Usually it is not the key to success, but in many cases, it helps, and in a few (like developing deep learning models), it is essential.


P.P.S. Why in English?! Well, simply because I can. The original question was put to me in this language, and I answered it in English too. Then I decided the answer could be polished into a small public article.


Why Habr, then, and not, say, Medium? First, unlike Medium, formulas are properly supported here. Second, Habr itself seemed to be planning to enter international markets, so why not try posting a bit of English-language content here?


Let's see what comes of it.
