Рекомендательная система через поиск схожих изображений с помощью Resnet50
Коротко про рекомендательные системы
Глобально существует два подхода в создании рекомендательных систем. Контентно-ориентированная и коллаборативная фильтрация. Основополагающее предположение подхода коллаборативной фильтрации заключается в том, что если А и В покупают аналогичные продукты, А, скорее всего, купит продукт, который купил В, чем продукт, который купил случайный человек. В отличие от контентно-ориентированного подхода, здесь нет признаков, соответствующих пользователям или предметам. Рекомендательная система базируется на матрице взаимодействий пользователей. Контентно-ориентированная система базируется на знаниях о предметах. Например если пользователь смотрит шелковые футболки возможно ему будет интересно посмотреть на другие шелковые футболки.
В этой статье я хочу рассказать о подходе который основан на поиске схожих изображений. Зачем подготавливать дополнительнительные данные если почти все основные характеристики некоторых товаров, например одежда, можно отобразить на изображении.
Суть подхода заключается в извлчении признаков из изображений товаров. С помощью сверточной сети, в своем примере я использовал Resnet50, так как вектор признаков resnet имеет относительно небольшую размерность. Извлечь вектор признаков с помощью обученой сети очень просто. Нужно просто исключить softmax классификатор именно он определяет к какому классу относится изображение и мы получим на выходе вектор признаков. Далее необходимо сравнивать векторы и искать похожие. Чем более схожи изображения тем меньше евклидово расстояние между векторами.
Код и датасет
Датасет можно скачать отсюда: ссылка на датасет.
Инициализации обученой restnet50 из библиотеки pytorch и извлечении признаков из датасета:
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
import torch
import glob
import pickle
from tqdm import tqdm
from PIL import Image
def pil_loader(path):
# Некоторые изображения из датасета представленны не в RGB формате, необходимо их конверитровать в RGB
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
# Инициализация модели обученой на датасете imagenet
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()
preprocess = weights.transforms()
use_precomputed_embeddings = True
emb_filename = 'fashion_images_embs.pickle'
if use_precomputed_embeddings:
with open(emb_filename, 'rb') as fIn:
img_names, img_emb_tensors = pickle.load(fIn)
print("Images:", len(img_names))
else:
img_names = list(glob.glob('images/*.jpg'))
img_emb = []
# извлечение признаков из изображений в датасете. У меня на CPU заняло около часа
for image in tqdm(img_names):
img_emb.append(
model(preprocess(pil_loader(image)).unsqueeze(0)).squeeze(0).detach().numpy()
)
img_emb_tensors = torch.tensor(img_emb)
with open(emb_filename, 'wb') as handle:
pickle.dump([img_names, img_emb_tensors], handle, protocol=pickle.HIGHEST_PROTOCOL)
Функция которая создает поисковый индекс с помощью faiss и уменьшает размерность векторов признаков:
# Для сравнения векторов используется faiss
import faiss
from sklearn.decomposition import PCA
def build_compressed_index(n_features):
pca = PCA(n_components=n_features)
pca.fit(img_emb_tensors)
compressed_features = pca.transform(img_emb_tensors)
dataset = np.float32(compressed_features)
d = dataset.shape[1]
nb = dataset.shape[0]
xb = dataset
index_compressed = faiss.IndexFlatL2(d)
index_compressed.add(xb)
return [pca, index_compressed]
Хэлперы для отображения результатов:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
def main_image(img_path, desc):
plt.imshow(mpimg.imread(img_path))
plt.xlabel(img_path.split('.')[0] + '_Original Image',fontsize=12)
plt.title(desc,fontsize=20)
plt.show()
def similar_images(indices, suptitle):
plt.figure(figsize=(15,10), facecolor='white')
plotnumber = 1
for index in indices[0:4]:
if plotnumber<=len(indices) :
ax = plt.subplot(2,2,plotnumber)
plt.imshow(mpimg.imread(img_names[index]))
plt.xlabel(img_names[index],fontsize=12)
plotnumber+=1
plt.suptitle(suptitle,fontsize=15)
plt.tight_layout()
Сама функция поиска. Принимает на вход количество признаков, что бы можно было поэксперементировать с достаточным количеством признаков:
import numpy as np
# поиск, можно искать по индексу из предварительно извлеченных изображений или передать новое изображение
def search(query, factors):
if(type(query) == str):
img_path = query
else:
img_path = img_names[query]
one_img_emb = torch.tensor(model(preprocess(read_image(img_path)).unsqueeze(0)).squeeze(0).detach().numpy())
main_image(img_path, 'Query')
compressor, index_compressed = build_compressed_index(factors)
D, I = index_compressed.search(np.float32(compressor.transform([one_img_emb.detach().numpy()])),5)
similar_images(I[0][1:], "faiss compressed " + str(factors))
Виновник торжества. Вызов поиска:
search(100,300)
search("t-shirt.jpg", 500)
Выводы
В итоге за пару часов можно собрать довольно качественную рекомендательную систему основаную на схожести изображений, чего достаточно для некоторых случаев. Изображения не требуют предварительной подготовки, разметки и какой то метаинформации что значительно упрощает процесс.
Для повышения качества рекомендаций можно дообучить некторые слои сети на используемом датасете.