Как быстро написать и выкатить в продакшн алгоритм машинного обучения

Сейчас анализ данных все шире используется в самых разных, зачастую далеких от ИТ, областях и задачи, стоящие перед специалистом на ранних этапах проекта радикально отличаются от тех, с которыми сталкиваются крупные компании с развитыми отделами аналитики. В этой статье я расскажу о том, как быстро сделать полезный прототип и подготовить простой API для его использования прикладным программистом.

Для примера рассмотрим задачу предсказания цены на трубы размещенную на платформе для соревнований Kaggle. Описание и данные можно найти здесь. На самом деле на практике очень часто встречаются задачи в которых надо быстро сделать прототип имея очень небольшое количество данных, а то и вообще не имея реальных данных до момента первого внедрения. В этих случаях приходится подходить к задаче творчески, начинать с несложных эвристик и ценить каждый запрос или размеченный объект. Но в нашей модельной ситуации таких проблем, к счастью, нет и поэтому мы можем сразу начать с обзора данных, определения задачи и попыток применения алгоритмов.

Итак, распаковав архив с данными, можно обнаружить, что в нем содержится приблизительно пара десятков csv-файлов. Согласно описанию на странице соревнования, основными из них являются train_set.csv и test_set.csv. Они содержат базовую информацию касающуюся цен. Остальные файлы содержат вспомогательные, но немногим менее важные данные. Давайте рассмотрим их поподробнее.

Поместив архив в поддиректорию data корневой директории проекта и разархивировав его командами

$ cd data/
$ unzip data.zip
$ cd ..

Мы можем посмотреть, что находится в интересующих нас файлах:

$ head data/competition_data/train_set.csv
tube_assembly_id,supplier,quote_date,annual_usage,min_order_quantity,bracket_pricing,quantity,cost
TA-00002,S-0066,2013-07-07,0,0,Yes,1,21.9059330191461
TA-00002,S-0066,2013-07-07,0,0,Yes,2,12.3412139792904
TA-00002,S-0066,2013-07-07,0,0,Yes,5,6.60182614356538
TA-00002,S-0066,2013-07-07,0,0,Yes,10,4.6877695119712
TA-00002,S-0066,2013-07-07,0,0,Yes,25,3.54156118026073
TA-00002,S-0066,2013-07-07,0,0,Yes,50,3.22440644770007
TA-00002,S-0066,2013-07-07,0,0,Yes,100,3.08252143576504
TA-00002,S-0066,2013-07-07,0,0,Yes,250,2.99905966403855
TA-00004,S-0066,2013-07-07,0,0,Yes,1,21.9727024365273

Видим столбцы с данными, соответствующими объекту (instance), цену которого мы будем предсказывать. А именно — идентификатор сборки (пока совершенно непонятно что это такое, но волшебство машинного обучения, помимо прочего, состоит в том, что для эффективного использования данных иногда совершенно необязательно понимать, что они означают), номер поставщика, дата и так далее. Отдельно обратим внимание на предпоследний столбец с количеством единиц поставки. Наконец последний столбец — метка (label), его цена. Чем больше больше единиц товара поставляется, тем ниже его цена, что вполне согласуется с нашими представления о происходящем в реальном мире.

Посмотрим теперь на файл с тестовыми данными.

$ head data/competition_data/test_set.csv
id,tube_assembly_id,supplier,quote_date,annual_usage,min_order_quantity,bracket_pricing,quantity
1,TA-00001,S-0066,2013-06-23,0,0,Yes,1
2,TA-00001,S-0066,2013-06-23,0,0,Yes,2
3,TA-00001,S-0066,2013-06-23,0,0,Yes,5
4,TA-00001,S-0066,2013-06-23,0,0,Yes,10
5,TA-00001,S-0066,2013-06-23,0,0,Yes,25
6,TA-00001,S-0066,2013-06-23,0,0,Yes,50
7,TA-00001,S-0066,2013-06-23,0,0,Yes,100
8,TA-00001,S-0066,2013-06-23,0,0,Yes,250
9,TA-00003,S-0066,2013-07-07,0,0,Yes,1

Те же самые столбцы за исключением последнего — а именно предсказываемой цены. Что логично для соревнования — именно на этих данных надо сформировать ответ для отправки на сайт соревнования и получения промежуточного результата.

Как мы уже отметили выше, помимо двух основных файлов с данными в нашем распоряжении есть еще и немало дополнительных. Имеет смысл заглянуть хотя бы в парочку из них для того, чтобы уловить основную идею данных в них содержащихся.

$ head data/competition_data/tube.csv
tube_assembly_id,material_id,diameter,wall,length,num_bends,bend_radius,end_a_1x,end_a_2x,end_x_1x,end_x_2x,end_a,end_x,num_boss,num_bracket,other
TA-00001,SP-0035,12.7,1.65,164,5,38.1,N,N,N,N,EF-003,EF-003,0,0,0
TA-00002,SP-0019,6.35,0.71,137,8,19.05,N,N,N,N,EF-008,EF-008,0,0,0
TA-00003,SP-0019,6.35,0.71,127,7,19.05,N,N,N,N,EF-008,EF-008,0,0,0
TA-00004,SP-0019,6.35,0.71,137,9,19.05,N,N,N,N,EF-008,EF-008,0,0,0
TA-00005,SP-0029,19.05,1.24,109,4,50.8,N,N,N,N,EF-003,EF-003,0,0,0
TA-00006,SP-0029,19.05,1.24,79,4,50.8,N,N,N,N,EF-003,EF-003,0,0,0
TA-00007,SP-0035,12.7,1.65,202,5,38.1,N,N,N,N,EF-003,EF-003,0,0,0
TA-00008,SP-0039,6.35,0.71,174,6,19.05,N,N,N,N,EF-008,EF-008,0,0,0
TA-00009,SP-0029,25.4,1.65,135,4,63.5,N,N,N,N,EF-003,EF-003,0,0,0

Ага, это просто соответствие между неким «идентификатором сборки», упоминавшимся выше и материалом, диаметром и тому подобными сведениями. Информация, вероятно, вполне полезная для обучения алгоритма.

Наконец, посмотрим, что содержится в файле посвященном материалам из которых сделаны трубы.

$ head data/competition_data/bill_of_materials.csv

tube_assembly_id,component_id_1,quantity_1,component_id_2,quantity_2,component_id_3,quantity_3,component_id_4,quantity_4,component_id_5,quantity_5,component_id_6,quantity_6,component_id_7,quantity_7,component_id_8,quantity_8

TA-00001,C-1622,2,C-1629,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00002,C-1312,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00003,C-1312,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00004,C-1312,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00005,C-1624,1,C-1631,1,C-1641,1,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00006,C-1624,1,C-1631,1,C-1641,1,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00007,C-1622,2,C-1629,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00008,C-1312,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00009,C-1625,2,C-1632,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA

Опять видим просто список соответствий между «идентификатором сборки» и компонентами. Обратим внимание на обилие полей с значением «NA». Как правило они обозначают пропуски в данных, однако в данном примере они соответствуют случаям, когда компонент просто недостаточно много для того, чтобы заполнить все 8 имеющихся в шапке позиций.

Казалось бы данные в общих чертах изучены, пора переходить непосредственно к настройке алгоритма. А вот и нет — прежде чем взяться за настройку алгоритма стоит понять, чего мы от него собираемся добиться, как мы будем сравнивать между собой два алгоритма, какой из них хуже, а какой — лучше. Во-первых нам придется оставить часть размеченных данных для проверки качества обученного алгоритма, в машинном обучении эта часть данных называется валидационной выборкой. Во-вторых нам придется исключить эту часть данных из тех, на которых мы будем проводить обучение, иначе модель предсказывающая ровно ту метку (в нашем случае — цену), какая была в обучающем объекте и случайную — в любом другом случае, будет давать идеальные предсказания на валидационной выборке, но при этом совершенно ужасные — при реальном применении. Исходные данные после исключения из них валидационной выборки называются тренировочной выборкой. Процесс обучения модели на тренировочной выборке и проверки качества ее работы на валидационной называется валидационной процедурой. Но и это еще не все — помимо выбора валидационной процедуры (которая, кстати говоря, может меняться в процессе экспериментов) необходимо выбрать еще и способ оценки качества предсказания при имеющихся экспериментальных результатах. А именно — функцию, которая возьмет на вход соответствующие два массива и даст на выходе оценку того, насколько предсказания соответствуют эксперименту. Такая функция называется метрикой качества и от ее выбора в реальных ситуациях зачастую зависит намного больше, чем от того, какой алгоритм мы выберем и как будем его настраивать. Не вдаваясь пока в детали этого тонкого процесса воспользуемся предложенной организаторами соревнования Root Mean Squared Logarithic Error (RMSLE).

Итак, принципиальные вопросы решены и мы можем приступить к написанию кода, загружающего данные и обучающего наш алгоритм. Для того, чтобы протестировать базовый пайплайн разделим выборку на тренировочную и валидационную части в соотношении 70/30, используем две простые фичи и один из простейших алгоритмов — линейную регрессию.

Напишем базовую функцию для загрузки данных и проверим ее работоспособность:

def load_data():
  list_of_instances = []
  list_of_labels = []

  with open('./data/competition_data/train_set.csv') as input_stream:
    header_line = input_stream.readline()
    columns = header_line.strip().split(',')
    for line in input_stream:
      new_instance = dict(zip(columns[:-1], line.split(',')[:-1]))
      new_label = float(line.split(',')[-1])
      list_of_instances.append(new_instance)
      list_of_labels.append(new_label)
  return list_of_instances, list_of_labels

>>> list_of_instances, list_of_labels = load_data()
>>> print(len(list_of_instances), len(list_of_labels))
30213 30213
>>> print(list_of_instances[:3])
[{'annual_usage': '0', 'quote_date': '2013-07-07', 'tube_assembly_id': 'TA-00002', 'min_order_quantity': '0', 'bracket_pricing': 'Yes', 'quantity': '1', 'supplier': 'S-0066'}, {'annual_usage': '0', 'quote_date': '2013-07-07', 'tube_assembly_id': 'TA-00002', 'min_order_quantity': '0', 'bracket_pricing': 'Yes', 'quantity': '2', 'supplier': 'S-0066'}, {'annual_usage': '0', 'quote_date': '2013-07-07', 'tube_assembly_id': 'TA-00002', 'min_order_quantity': '0', 'bracket_pricing': 'Yes', 'quantity': '5', 'supplier': 'S-0066'}]
>>> print(list_of_labels[:3])
[21.9059330191461, 12.3412139792904, 6.60182614356538]

Результаты вполне соответствуют ожиданиям. Теперь напишем заготовку функции, переводящей объекты (instances) в фичевекторы (samples).

def is_bracket_pricing(instance):
  if instance['bracket_pricing'] == 'Yes':
    return [1]
  elif instance['bracket_pricing'] == 'No':
    return [0]
  else:
    raise ValueError

def get_quantity(instance):
  return [int(instance['quantity'])]

def to_sample(instance):
  return is_bracket_pricing(instance) + get_quantity(instance)

>>> print(list(map(to_sample, list_of_instances[:3])))
[[1, 1], [1, 2], [1, 5]]

Позже, когда разных фичей станет много, они все переедут в специально отведенный для этого файл features.py, где будут обрастать вариациями, вспомогательными функциями, а в отдельных, особо запущенных случаях — еще и юнит-тестами.

Теперь у нас уже в принципе есть все необходимое для того, чтобы обучить первую, самую простую модель машинного обучения. Маленький (но важный) момент — мы договорились, что будем оптимизироваться под предложенную в условии соревнования метрику

$RMSLE = \sqrt{\frac1n\sum_{i=1}^n(\log(p_i + 1) - \log(a_i + 1))^2},$

где $p_i$ — предсказанные моделью значение, а $a_i$ — фактические. Для того, чтобы наши усилия по подбору фичей и настроек модели в большей степени соответствовали этой цели, применим к меткам (labels) функцию $f(x) = \log(x + 1)$. Дело в том, что большинство регрессионных методов машинного обучения предназначены для минимизации квадратичной ошибки (MSE) — и именно к оптимизации такой метрики качества мы сводим нашу задачу применив к меткам вышеуказанную функцию.

import math

def to_interim_label(label):
  return math.log(label + 1)

def to_final_label(interim_label):
  return math.exp(interim_label) - 1

>>> print(to_final_label(to_interim_label(42)))
42.0

Похоже на то, что мы не ошиблись с функциями перехода от исходных меток к более удобным для оптимизации и обратно и эти функции в самом деле взаимообратны. Теперь инициализируем модель и обучим ее на полученных фичевекторах и промежуточных метках.

>>> model = LinearRegression()
>>> list_of_samples = list(map(to_sample, list_of_instances))
>>> TRAIN_SAMPLES_NUM = 20000
>>> train_samples = list_of_samples[:TRAIN_SAMPLES_NUM]
>>> train_labels = list_of_labels[:TRAIN_SAMPLES_NUM]
>>> model.fit(train_samples, train_labels)

Теперь модель обучена и мы можем проверить, насколько она хорошо работает, сравнив результат ее предсказания и действительные значения на валидационной выборке.

>>> validation_samples = list_of_samples[TRAIN_SAMPLES_NUM:]
>>> validation_labels = list(map(to_interim_label, list_of_labels[TRAIN_SAMPLES_NUM:]))

>>> squared_errors = []
>>> for sample, label in zip(validation_samples, validation_labels):
>>>   prediction = model.predict(numpy.array(sample).reshape(1, -1))[0]
>>>   squared_errors.append((prediction - label) ** 2)

>>> mean_squared_error = sum(squared_errors) / len(squared_errors)
>>> print('Mean Squared Error: {0}'.format(mean_squared_error))
Mean Squared Error: 0.8251727558694294

Велика или мала полученная ошибка? Априори это сказать невозможно, однако с учетом того, что мы используем всего две фичи и одну из самых простых моделей — скорее всего наше предсказание далеко от оптимального. Давайте попробуем поэкспериментировать, например добавив новых фичей. Например в базовом файле с данными помимо уже использованных двух полей (bracket_pricing и quantity) имеется (пусть и с не очень часто отличающимися от 0 значениями) поле min_order_quantity. Давайте попробуем использовать его значение в качестве новой фичи для обучения алгоритма.

def get_min_order_quantity(instance):
  return [int(instance['min_order_quantity'])]

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
             + get_min_order_quantity(instance))

Mean Squared Error: 0.8234554779286141

Как видим, ошибка немного уменьшилась, а значит — наш алгоритм улучшился. Не будем останавливатся на достигнутом и добавим одну за одной еще несколько фичей.

def get_annual_usage(instance):
  return [int(instance['annual_usage'])]

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
             + get_min_order_quantity(instance) + get_annual_usage(instance))

Mean Squared Error: 0.8227852260998361

Следующее неиспользованное пока поле — quote_date, не имеет простой интерпретации в виде числа или набора чисел фиксированной длины. Поэтому придется немного подумать о том, как использовать его значение в качестве числового входа для алгоритма. Конечно, и год и месяц и день могут помочь в качестве новой фичи, но самым логичным первым приближением выглядит количество дней начиная с определенной даты — например с 1 января нулевого года, как дня, наступившего заведомо раньше самой ранней из встретившихся в файле дат. В первом приближении можно пока посчитать, что в году всегда 365 дней, а в каждом из 12 месяцев — 30. И пусть нас не смутит кажущаяся математическая некорректность этого предположения — мы всегда сможем уточнить формулы позже и посмотреть, улучшит ли соответствующая фича качество предсказания на валидационной выборке.

def get_absolute_date(instance):
  return [365 * int(instance['quote_date'].split('-')[0])
          + 12 * int(instance['quote_date'].split('-')[1])
          + int(instance['quote_date'].split('-')[2])]

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
          + get_min_order_quantity(instance) + get_annual_usage(instance)
          + get_absolute_date(instance))

Mean Squared Error: 0.8216646342919645

Как видим, даже не вполне математически и астрономически корректная фича тем не менее помогла улучшить качество предсказания нашей модели. Теперь перейдем к пока еще новому для нас типу признаков, содержащихся в полях tube_assembly_id и supplier. Каждое из этих полей содержит значения идентификаторов предприятия-изготовителя и вендора. Они имеют не бинарную и не количественную природу, а описывают тип объекта из фиксированного списка. В машинном обучении такие свойства объектов и соответствующие им фичи наывают категориальными. Впрочем, категория предприятия-изготовителя сама по себе нам наврядли поможет, так как они не повторяются в файле test_set.csv, и мы совершенно правильно разделили размеченную выборку таким образом чтобы между тренировочной и валидационной частями (практически) не было соответствующего пересечения. Попробуем, тем не менее, извлечь что-то полезное из значения поля supplier. Посмотрим для начала какие вообще соответствующие коды встречаются в файле с размеченными данными.

>>> with open('./data/competition_data/train_set.csv') as input_stream:
...   header_line = input_stream.readline()
...   suppliers = set()
...   for line in input_stream:
...     new_supplier = line.split(',')[1]
...     suppliers.add(new_supplier)
...
>>> print(len(suppliers))
57
>>> print(suppliers)
{'S-0058', 'S-0013', 'S-0050', 'S-0011', 'S-0070', 'S-0104', 'S-0012', 'S-0068', 'S-0041', 'S-0023', 'S-0092', 'S-0095', 'S-0029', 'S-0051', 'S-0111', 'S-0064', 'S-0005', 'S-0096', 'S-0062', 'S-0004', 'S-0059', 'S-0031', 'S-0078', 'S-0106', 'S-0060', 'S-0090', 'S-0072', 'S-0105', 'S-0087', 'S-0080', 'S-0061', 'S-0108', 'S-0042', 'S-0027', 'S-0074', 'S-0081', 'S-0025', 'S-0024', 'S-0030', 'S-0022', 'S-0014', 'S-0054', 'S-0015', 'S-0008', 'S-0007', 'S-0009', 'S-0056', 'S-0026', 'S-0107', 'S-0066', 'S-0018', 'S-0109', 'S-0043', 'S-0046', 'S-0003', 'S-0006', 'S-0097'}

Как видим — их не так уж и много. Для начала можно попробовать поступить самым простым стандартным для таких случаев образом, а именно сопоставить каждому значению поля массив с одной единицей, стоящей на месте соответствующем идентификатору и всеми остальными значениями равными 0.

SUPPLIERS_LIST = ['S-0058', 'S-0013', 'S-0050', 'S-0011', 'S-0070', 'S-0104', 'S-0012', 'S-0068', 'S-0041', 'S-0023', 'S-0092', 'S-0095', 'S-0029', 'S-0051', 'S-0111', 'S-0064', 'S-0005', 'S-0096', 'S-0062', 'S-0004', 'S-0059', 'S-0031', 'S-0078', 'S-0106', 'S-0060', 'S-0090', 'S-0072', 'S-0105', 'S-0087', 'S-0080', 'S-0061', 'S-0108', 'S-0042', 'S-0027', 'S-0074', 'S-0081', 'S-0025', 'S-0024', 'S-0030', 'S-0022', 'S-0014', 'S-0054', 'S-0015', 'S-0008', 'S-0007', 'S-0009', 'S-0056', 'S-0026', 'S-0107', 'S-0066', 'S-0018', 'S-0109', 'S-0043', 'S-0046', 'S-0003', 'S-0006', 'S-0097']

def get_supplier(instance):
  if instance['supplier'] in SUPPLIERS_LIST:
    supplier_index = SUPPLIERS_LIST.index(instance['supplier'])
    result = [0] * supplier_index + [1] + [0] * (len(SUPPLIERS_LIST) - supplier_index - 1)
  else:
    result = [0] * len(SUPPLIERS_LIST)
  return result

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
          + get_min_order_quantity(instance) + get_annual_usage(instance)
          + get_absolute_date(instance) + get_supplier(instance))

Mean Squared Error: 0.7992338454746866

Как мы видим, произошло значительное снижение средней ошибки. Скорее всего это означает, что в значении поля, которое мы добавили, заложена ценная для алгоритма информация и стоит попробовать воспользоваться ей позже еще несколько раз, но каким-нибудь менее банальным способом. Мы уже обсудили, что поле tube_assembly вряд ли поможет нам напрямую, однако стоит все-таки попробовать проверить это экспериментально.

def get_assembly(instance):
  assembly_id = int(instance['tube_assembly_id'].split('-')[1])
  result = [0] * assembly_id + [1] + [0] * (25000 - assembly_id - 1)
  return result

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
          + get_min_order_quantity(instance) + get_annual_usage(instance)
          + get_absolute_date(instance) + get_supplier(instance)
          + get_assembly(instance))

Разумеется, конкретный способ конвертации идентификатора производителя в фичевектор выглядит несколько неуклюже, особенно с учетом количества возможных значений, однако если у читателя есть более конструктивный вариант использования этого поля напрямую и без привлечения дополнительных данных, то можно обсудить его в комментариях и даже попробовать применить в рамках имеющегося на данный момент кода. А мы вспомним о том, что помимо основного файла с обучающей выборкой в нашем распоряжении имеется еще несколько файлов с вспомогательными данными и попробуем поэкспериментировать с их привлечением. Например, чтобы нам было не очень обидно за относительную (хотя и предсказуемую) неудачу с прямым использованием значения поля tube_assembly_id, можно попытаться взять реванш, воспользовавшись данными, содержащимися в файле specs.csv и в анонимизированной форме описывающими соответствующие этим значениям спецификаторы.

head data/competition_data/specs.csv -n 20
tube_assembly_id,spec1,spec2,spec3,spec4,spec5,spec6,spec7,spec8,spec9,spec10
TA-00001,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00002,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00003,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00004,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00005,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00006,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00007,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00008,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00009,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00010,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00011,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00012,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00013,SP-0004,SP-0069,SP-0080,NA,NA,NA,NA,NA,NA,NA
TA-00014,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00015,SP-0063,SP-0069,SP-0080,NA,NA,NA,NA,NA,NA,NA
TA-00016,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00017,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00018,SP-0007,SP-0058,SP-0070,SP-0080,NA,NA,NA,NA,NA,NA
TA-00019,SP-0080,NA,NA,NA,NA,NA,NA,NA,NA,NA

Похоже на то, что имеет место простое перечисление спецификаторов соответствующих каждому значению. Соответственно разумным выглядит вариант в качестве фичевектора выбрать массив из единиц и нулей, в котором единицы стоят на позициях, соответствующих имеющимся спецификаторам, а нули — всем остальным. Для того, чтобы использовать данные из дополнительного файла, нам придется сделать небольшое изменение в структуре кода.

def get_assembly_specs(instance, assembly_to_specs):
  result = [0] * 100
  for spec in assembly_to_specs[instance['tube_assembly_id']]:
    result[int(spec.split('-')[1])] = 1
  return result

def to_sample(instance, additional_data):
  return (is_bracket_pricing(instance) + get_quantity(instance)
          + get_min_order_quantity(instance) + get_annual_usage(instance)
          + get_absolute_date(instance) + get_supplier(instance)
          + get_assembly_specs(instance, additional_data['assembly_to_specs']))

def load_additional_data():
  result = dict()
  assembly_to_specs = dict()
  with open('data/competition_data/specs.csv') as input_stream:
    header_line = input_stream.readline()
    for line in input_stream:
      tube_assembly_id = line.split(',')[0]
      specs = []
      for spec in line.strip().split(',')[1:]:
        if spec != 'NA':
          specs.append(spec)
      assembly_to_specs[tube_assembly_id] = specs

  result['assembly_to_specs'] = assembly_to_specs
  return result

additional_data = load_additional_data()
list_of_samples = list(map(lambda x:to_sample(x, additional_data), list_of_instances))

Mean Squared Error: 0.7754770419953809

Наши труды не пропали зря и целевая метрика улучшилась на 0.024, что на текущем этапе уже совсем неплохо. На этом, пожалуй, можно пока что остановиться с оптимизацией алгоритма и обсудить вопрос простого предоставления удобного API к обученным алгоритмам для прикладного программиста.

Сперва сохраним обученную модель на диск.

with open('./data/model.mdl', 'wb') as output_stream:
  output_stream.write(pickle.dumps(model))

Теперь создадим скрипт generate_response.py, в котором воспользуемся полученными ранее наработками.

import pickle
import numpy
import research

class FinalModel(object):
  def __init__(self, model, to_sample, additional_data):
    self._model = model
    self._to_sample = to_sample
    self._additional_data = additional_data
  def process(self, instance):
    return self._model.predict(numpy.array(self._to_sample(
                                instance, self._additional_data)).reshape(1, -1))[0]

if __name__ == '__main__':
  with open('./data/model.mdl', 'rb') as input_stream:
    model = pickle.loads(input_stream.read())
  additional_data = research.load_additional_data()
  final_model = FinalModel(model, research.to_sample, additional_data)
  print(final_model.process({'tube_assembly_id':'TA-00001', 'supplier':'S-0066',
                             'quote_date':'2013-06-23', 'annual_usage':'0',
                             'min_order_quantity':'0', 'bracket_pricing':'Yes', 
                             'quantity':'1'}))

2.357692493326624

Теперь аналогичным образом прикладной программист может подгрузить модель и сопутствующие переменные в любом скрипте и использовать для предсказания значения на новых входящих данных.

В следующей серии — дальнейшая оптимизация модели, больше фичей, более сложные алгоритмы и настройка их гиперпараметров, при необходимости — продвинутые валидационные процедуры.
  • +11
  • 10.4k
  • 4
Share post
AdBlock has stolen the banner, but banners are not teeth — they will be back

More
Ads

Comments 4

    0
    Зачем обучаться с предположением, что у вас 100 видов материалов, когда в обучающей выборке их 57? У оставшихся 43 фич ведь будут случайные веса и, если при использовании отсутствующие при обучении типы материалов возникнут, они могут непредсказуемо повлиять на качество предсказания. Не лучше было бы при подгрузке материалов трубы проверять, присутствовали ли эти материалы при обучении?
      0
      Ну это первый, пристрелочный вариант. Нужный для того, чтобы показать общий стиль мышления и то, как выкатить потом все это дело в прод. Планируется вторая (а может быть и дальнейшие) части — там буду уже заниматься более тонкой оптимизацией, глубже экспериментировать с фичами и моделями. В частности и избавляться от лишних неинформативных значений. Спасибо большое за замечание, оно сэкономит мне немного времени в будущем :)
      0
      Спасибо за интересную статью. Возможно, лучше было бы использовать модуль csv из стандартной библиотеки для работы с csv файлами,
        0
        Я рад, что Вам понравилось. Буду писать еще.

        Я предпочитаю контролировать код, который пишу насколько это возможно и использую библиотеки типа pandas или csv только в случаях, когда вручную быстро написать нужный участок кода проблематично — например в случае очень больших файлов или текстовых значений с запятыми. Довольно печально потратить несколько часов (а то и дней) на поиски причины несоответствия метрики на валидации и в продакшене и, наконец, обнаружить, что проблема состояла в том, что та или иная библиотечная функция работает «немного» не так, как предполагалось.

      Only users with full accounts can post comments. Log in, please.