Как стать автором
Обновить
650.87
OTUS
Цифровые навыки от ведущих экспертов

Кратко про библиотеку Rumale для машинного обучения на Ruby

Уровень сложностиПростой
Время на прочтение5 мин
Количество просмотров916

Привет, Хабр!

Библиотека Rumale создана для того, чтобы сделать машинное обучение доступным и удобным для разрабов на Ruby. Она имеет большой выбор алгоритмов и инструментов, аналогичных тем, что можно найти в Scikit-learn для Python.

Краткий формат статьи выбран из-за сходств с Sckit learn.

Установим

Открываем Gemfile и добавляем строку:

gem 'rumale'

После этого юзаем bundle install для установки библиотеки:

$ bundle install

Если хочется установить Rumale без Bundler, можно сделать это напрямую через команду gem install:

$ gem install rumale

После установки библиотеки, подключаем в проект:

require 'rumale'

Построение и обучение моделей в Rumale

Загружать данные будем с библиотеками Daru и RDatasets.

Линейная регрессия

Линейная регрессия — это база для предсказания числовых значений. В Rumale для этой цели используется класс Rumale::LinearModel::LinearRegression:

require 'daru'
require 'rumale'

# создание набора данных
data = Daru::DataFrame.from_csv('housing_prices.csv')
x = data['size'].to_a
y = data['price'].to_a

# преобразование данных в формат, подходящий для Rumale
x = Numo::DFloat[x].reshape(x.size, 1)
y = Numo::DFloat[y]

# построение и обучение модели линейной регрессии
model = Rumale::LinearModel::LinearRegression.new
model.fit(x, y)

# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"

Данные о размерах домов и их ценах загружаются из CSV-файла, преобразуются в массивы, а затем используются для обучения модели линейной регрессии.

Метод опорных векторов (SVM)

Метод опорных векторов — это алгоритм для задач классификации. В Rumale он представлен классом Rumale::LinearModel::SVC:

require 'daru'
require 'rumale'
require 'rdatasets'

# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
y = iris['Species'].map { |species| species == 'setosa' ? 0 : 1 }

# преобразование данных в формат Numo::NArray
x = Numo::DFloat[*x.to_a]
y = Numo::Int32[*y]

# построение и обучение модели SVM
model = Rumale::LinearModel::SVC.new(kernel: 'linear', reg_param: 1.0)
model.fit(x, y)

# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"

SVM моделька классифицирует цветы как setosa или нет.

Кластеризация с использованием K-Means

K-Means — это алгоритм кластеризации, который группирует данные на основе их схожести. В Rumale используется класс Rumale::Clustering::KMeans:

require 'daru'
require 'rumale'
require 'rdatasets'

# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix

# преобразование данных в формат Numo::NArray
x = Numo::DFloat[*x.to_a]

# построение и обучение модели K-Means
model = Rumale::Clustering::KMeans.new(n_clusters: 3, max_iter: 300)
model.fit(x)

# предсказание кластеров
labels = model.predict(x)
puts "Кластеры: #{labels.to_a}"

Используем данные Iris для кластеризации их на три группы с помощью K-Means.

Прочие алгоритмы

Random Forest:

require 'daru'
require 'rumale'
require 'rdatasets'

# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
y = iris['Species'].map { |species| species == 'setosa' ? 0 : 1 }

# преобразование данных в формат Numo::NArray
x = Numo::DFloat[*x.to_a]
y = Numo::Int32[*y]

# построение и обучение модели Random Forest
model = Rumale::Ensemble::RandomForestClassifier.new(n_estimators: 10, max_depth: 3)
model.fit(x, y)

# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"

Gradient Boosting:

require 'daru'
require 'rumale'
require 'rdatasets'

# загрузка набора данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
y = iris['Species'].map { |species| species == 'setosa' ? 0 : 1 }

# преобразование данных в формат Numo::NArray
x = Numo::DFloat[*x.to_a]
y = Numo::Int32[*y]

# построение и обучение модели Gradient Boosting
model = Rumale::Ensemble::GradientBoostingClassifier.new(n_estimators: 100, learning_rate: 0.1, max_depth: 3)
model.fit(x, y)

# предсказание на новых данных
predicted = model.predict(x)
puts "Предсказанные значения: #{predicted.to_a}"

Оценка и валидация моделей

Метрики оценки качества моделей

Среднеквадратичная ошибка (MSE): измеряет среднее значение квадратов ошибок, т.е разницу между предсказанными и фактическими значениями:

require 'numo/narray'
require 'rumale'

# пример данных
y_true = Numo::DFloat[3.0, -0.5, 2.0, 7.0]
y_pred = Numo::DFloat[2.5, 0.0, 2.0, 8.0]

# расчет MSE
mse = Rumale::EvaluationMeasure::MeanSquaredError.new
mse_value = mse.score(y_true, y_pred)
puts "MSE: #{mse_value}"

Коэффициент детерминации (R²): измеряет долю дисперсии, объясненную моделью. Значение R² варьируется от 0 до 1, где 1 означает идеальное соответствие:

# расчет R²
r2 = Rumale::EvaluationMeasure::RSquared.new
r2_value = r2.score(y_true, y_pred)
puts "R²: #{r2_value}"

Кросс-валидации

Кросс-валидация позволяет оценить обобщающую способность модели. Одним из самых частых методов - K-Fold кросс-валидация.

K-Fold кросс-валидация:

require 'rumale'
require 'daru'
require 'rdatasets'

# загрузка данных Iris
iris = RDatasets.load(:datasets, :iris)
x = iris[0..3].to_matrix
y = iris['Species'].map { |species| species == 'setosa' ? 0 : 1 }

x = Numo::DFloat[*x.to_a]
y = Numo::Int32[*y]

# определение модели
model = Rumale::LinearModel::LogisticRegression.new

# определение метрики оценки
mse = Rumale::EvaluationMeasure::MeanSquaredError.new

# настройка K-Fold кросс-валидации
kf = Rumale::ModelSelection::KFold.new(n_splits: 5, shuffle: true, random_seed: 1)

# проведение кросс-валидации
cv = Rumale::ModelSelection::CrossValidation.new(estimator: model, splitter: kf, evaluator: mse)
report = cv.perform(x, y)

# вывод результатов
mean_score = report[:test_score].sum / kf.n_splits
puts "5-CV MSE: #{mean_score}"

После выполнения кросс-валидации или других методов оценки, очень важно не забывать о том, что нужно еще и правильно интерпретировать полученные результаты.

Среднее значение и стандартное отклонение: эти показатели дают представление о стабильности и надежности модели. Например, низкое ср. значение ошибки и низкое стандартное отклонение указывают на стабильную и точную модель:

mean_score = report[:test_score].mean
std_score = report[:test_score].std
puts "Mean MSE: #{mean_score}, Standard Deviation: #{std_score}"

Можно еще подключить gnuplot, чтобы визуализировать и помогает понять производительность модельки на различных наборах данных:

require 'gnuplot'

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "K-Fold Cross Validation Scores"
    plot.ylabel "MSE"
    plot.xlabel "Fold"

    plot.data << Gnuplot::DataSet.new(report[:test_score]) do |ds|
      ds.with = "linespoints"
      ds.title = "Fold MSE"
    end
  end
end

Подробнее с этой замечательной библиотекой можно ознакомиться здесь.

А с другими инструментами и библиотеками вы всегда можете познакомиться в рамках практических онлайн-курсов от моих коллег из OTUS.

Теги:
Хабы:
Всего голосов 5: ↑5 и ↓0+7
Комментарии3

Публикации

Информация

Сайт
otus.ru
Дата регистрации
Дата основания
Численность
101–200 человек
Местоположение
Россия
Представитель
OTUS