Многие из нас мечтали бы заглянуть в будущее — ведь это по-настоящему полезный навык. А что, если я скажу, что при помощи математики можно приблизиться к этой мечте? Да, с некоторыми оговорками, но в этой статье мы попробуем почувствовать себя настоящим Докторам Стрэнджам и предсказать какую кассу соберет фильм при определенном бюджете.
Сегодня мы простыми словами разберёмся, что такое линейная регрессия и напишем код на Python, который демонстрирует работу линейной регрессии.
Интуиция
Давайте сначала разберёмся интуитивно: как вообще люди предсказывают будущее? Если это не случайное предположение, взятое из головы, то, как правило, мы опираемся на закономерности. Например, можно приблизительно оценить успех фильма по его бюджету: чем он больше, тем выше вероятность, что фильм получится качественным и соберёт кассу.
Другими словами, для предсказания часто достаточно понять, как одна величина зависит от другой. А что, если бы мы могли отобразить визуально эту зависимость? Именно этим и занимается линейная регрессия.
Что такое линейная регрессия
Постараюсь без сложных формулировок, честно.
Линейная регрессия - это статистический метод, который часто применяется в data science и статистике. Она позволяет визуально отобразить зависимость между данными для дальнейшего анализа и прогнозов на будущее при помощи линейной зависимости(на графике - это прямая). Говоря более неформально, то линейная регрессия, по сути, это правильная подгонка прямой под данные.
Особо внимательные уловили связь с линейной функцией, уж многие определения на нее походят и вы правы, так как формула линейной регрессии выглядит так:
где:
y - предсказание
m - коэффициент
b - свободный член
Фактически нахождение линейной регрессии сводится к идеальному подбору m и b, но об оптимизации прямой мы поговорим в следующих статьях.
Механизм работы линейной регрессии
Представьте, что вы хотите открыть свой бизнес, например, ларёк с шаурмой, но не знаете, где его разместить. Однако вы знакомы линейной регрессией. Вы находите таблицу с двумя колонками: улица, на которой расположен ларёк, и его месячная выручка.
Вы решаете построить график:
По оси X откладываете улицы. Чем ближе улица к центру, тем левее она находится.
По оси Y располагаете месячную выручку.

Хм… Можно заметить закономерность: чем ближе улица к центру, тем выше выручка. Мы определили это на глаз, так как данных мало и они удобно расположены. Давайте теперь проведём прямую.

На этом графике можно увидеть синие точки — это так называемые остатки. Они показывают, насколько линейная регрессия ошиблась при предсказании. При построении графика нужно минимизировать расстояние между точками и линией.
Линейная регрессия на Python
Перед тем, как приступить к делу, я оставлю ссылку на GitHub, с которого брал таблицу.
Сейчас мы будем обучать линейную регрессию на CSV-таблице, которая содержит данные о фильмах, а именно, бюджет и итоговую кассу. Ну что же, приступим!
import pandas as pd
from pandas import DataFrame
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
Импортируем библиотеки для работы с файлами, графиками и линейной регрессией соответственно.
data = pd.read_csv('cost_revenue_clean.csv')
Читаем данные с CSV-файла.
x = DataFrame(data, columns=['production_budget_usd'])
y = DataFrame(data, columns=['worldwide_gross_usd'])
Извлекаем данные из таблиц бюджета и кассы и сохраняем их в x и y соответственно.
fit = LinearRegression().fit(x,y)
Обучаем нашу линейную регрессию на данных x и y, которые мы извлекли ранее.
m = fit.coef_.flatten()
b = fit.intercept_.flatten()
В первой строке мы сохраняем значения атрибута fit.coef_. который хранит в себе двухмерный массив, а далее преобразовываем его в одномерный массив.
По такой же логике поступаем и с переменной b, но только в этом случае данные хранятся в атрибуте fit.intercept_.
plt.figure(figsize=(8, 6))
plt.scatter(x, y, alpha=0.5, label="Данные")
plt.plot(x, m * X + b, color='red', label="Линейная регрессия")
plt.xlabel("Бюджет фильма (USD)")
plt.ylabel("Мировые сборы (USD)")
plt.title("Линейная регрессия: бюджет и сборы")
plt.legend()
plt.show()
Говоря вкратце, настраиваем визуализацию графика.

Такая красота у нас получилась, но давайте сделаем предсказание.
Заключение
Сегодня мы узнали, что такое линейная регрессия, в чем её основной механизм, к чему нужно стремиться при построении её графика, а также написали несложный код на Python, с помощью которого предсказали кассу фильма при определённом бюджете.