Коротко про рекомендательные системы

Глобально существует два подхода в создании рекомендательных систем. Контентно-ориентированная и коллаборативная фильтрация. Основополагающее предположение подхода коллаборативной фильтрации заключается в том, что если А и В покупают аналогичные продукты, А, скорее всего, купит продукт, который купил В, чем продукт, который купил случайный человек. В отличие от контентно-ориентированного подхода, здесь нет признаков, соответствующих пользователям или предметам. Рекомендательная система базируется на матрице взаимодействий пользователей. Контентно-ориентированная система базируется на знаниях о предметах. Например если пользователь смотрит шелковые футболки возможно ему будет интересно посмотреть на другие шелковые футболки.

В этой статье я хочу рассказать о подходе который основан на поиске схожих изображений. Зачем подготавливать дополнительнительные данные если почти все основные характеристики некоторых товаров, например одежда, можно отобразить на изображении.

исключение softmax

Суть подхода заключается в извлчении признаков из изображений товаров. С помощью сверточной сети, в своем примере я использовал Resnet50, так как вектор признаков resnet имеет относительно небольшую размерность. Извлечь вектор признаков с помощью обученой сети очень просто. Нужно просто исключить softmax классификатор именно он определяет к какому классу относится изображение и мы получим на выходе вектор признаков. Далее необходимо сравнивать векторы и искать похожие. Чем более схожи изображения тем меньше евклидово расстояние между векторами.

Код и датасет

Датасет можно скачать отсюда: ссылка на датасет.

Инициализации обученой restnet50 из библиотеки pytorch и извлечении признаков из датасета:

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
import torch
import glob
import pickle
from tqdm import tqdm
from PIL import Image

def pil_loader(path):
    # Некоторые изображения из датасета представленны не в RGB формате, необходимо их конверитровать в RGB
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# Инициализация модели обученой на датасете imagenet
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
preprocess = weights.transforms()

use_precomputed_embeddings = True
emb_filename = 'fashion_images_embs.pickle'
if use_precomputed_embeddings: 
    with open(emb_filename, 'rb') as fIn:
        img_names, img_emb_tensors = pickle.load(fIn)  
    print("Images:", len(img_names))
else:
    img_names  = list(glob.glob('images/*.jpg'))
    img_emb = []
    # извлечение признаков из изображений в датасете. У меня на CPU заняло около часа
    for image in tqdm(img_names):
        img_emb.append(
            model(preprocess(pil_loader(image)).unsqueeze(0)).squeeze(0).detach().numpy()
        )
    img_emb_tensors = torch.tensor(img_emb)
    
    with open(emb_filename, 'wb') as handle:
        pickle.dump([img_names, img_emb_tensors], handle, protocol=pickle.HIGHEST_PROTOCOL)

Функция которая создает поисковый индекс с помощью faiss и уменьшает размерность векторов признаков:

# Для сравнения векторов используется faiss
import faiss                   
from sklearn.decomposition import PCA

def build_compressed_index(n_features):
    pca = PCA(n_components=n_features)
    pca.fit(img_emb_tensors)
    compressed_features = pca.transform(img_emb_tensors)
    dataset = np.float32(compressed_features)
    d = dataset.shape[1]
    nb = dataset.shape[0]
    xb = dataset

    index_compressed = faiss.IndexFlatL2(d)
    index_compressed.add(xb)
    return [pca, index_compressed]

Хэлперы для отображения результатов:

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def main_image(img_path, desc):
    plt.imshow(mpimg.imread(img_path))
    plt.xlabel(img_path.split('.')[0] + '_Original Image',fontsize=12)
    plt.title(desc,fontsize=20)
    plt.show()

def similar_images(indices, suptitle):
    plt.figure(figsize=(15,10), facecolor='white')
    plotnumber = 1    
    for index in indices[0:4]:
        if plotnumber<=len(indices) :
            ax = plt.subplot(2,2,plotnumber)
            plt.imshow(mpimg.imread(img_names[index]))
            plt.xlabel(img_names[index],fontsize=12)
            plotnumber+=1
    plt.suptitle(suptitle,fontsize=15)
    plt.tight_layout()

Сама функция поиска. Принимает на вход количество признаков, что бы можно было поэксперементировать с достаточным количеством признаков:

import numpy as np
# поиск, можно искать по индексу из предварительно извлеченных изображений или передать новое изображение
def search(query, factors):
    if(type(query) == str):
        img_path = query
    else:
        img_path = img_names[query]
    one_img_emb = torch.tensor(model(preprocess(read_image(img_path)).unsqueeze(0)).squeeze(0).detach().numpy())
    main_image(img_path, 'Query')
    compressor, index_compressed = build_compressed_index(factors)
    D, I = index_compressed.search(np.float32(compressor.transform([one_img_emb.detach().numpy()])),5)
    similar_images(I[0][1:], "faiss compressed " + str(factors))

Виновник торжества. Вызов поиска:

search(100,300)
search("t-shirt.jpg", 500)

Выводы

В итоге за пару часов можно собрать довольно качественную рекомендательную систему основаную на схожести изображений, чего достаточно для некоторых случаев. Изображения не требуют предварительной подготовки, разметки и какой то метаинформации что значительно упрощает процесс.

Для повышения качества рекомендаций можно дообучить некторые слои сети на используемом датасете.