На Хабре было много упоминаний pgvector в обзорах Postgresso. И каждый раз новость была про место которое где‑то за границей и далеко. Многие коммерческие решения для хранения и поиска векторов в базе данных нынче не доступны, а pgvector доступен любому, тем более в самой популярной базе в России. Применим pgvector для задачи поиска похожих домов по инфраструктуре для детей в Москве.
В этой статье покажу на этом практическом примере как хранить, кластеризовать алгоритмом DBSCAN
векторы и искать по ним в базе данных. В примере задача с векторами на грани типичного хранения и обработки результатов работы нейросетевых моделей в базе данных.
Прежде всего надо установить pgvector в PostgreSQL, он доступен в виде расширения. Поскольку я работаю с базой данных из Docker, то могу просто добавить в Dockerfile строчки и пересобрать образ:
RUN git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
RUN cd pgvector && make && make install
А в самой базе данных, нужно загрузить расширение:
osmworld=# CREATE EXTENSION vector;
CREATE EXTENSION
Time: 32,606 ms
Данные для векторов можно получить, например, из модели машинного обучения в python скрипте или ML модели в spark и вставить в таблицу с колонкой типа vector. А можно создать в SQL как гистограмму определенных категорий. В этом случае можно значения в массивах real[]
, integer[]
или double precision[]
, numeric[]
привести к типу ::vector
.
Данными для примера послужат гистограммы числа объектов детской инфраструктуры в окрестностях жилых домов в Москве. Про то как рассчитать эти данные я рассказывал здесь, а в этой публикации я просто возьму готовые данные и создам из них таблицу с колонкой типа одинадцатимерный vector:
create table infrastructure_for_children_features as
select (row_number() over ())::integer id, null::integer cluster,
district, street, housenumber,
ARRAY[kindergarten::integer, school::integer,college::integer, university::integer, language_school::integer, music_school::integer,training::integer,sports_centre::integer,community_centre::integer,playground::integer,clinic::integer]
::vector(11) feature
from infrastructure_for_children;
alter table infrastructure_for_children_features add primary key(id);
Так в базе создал таблицу на 30237 записей со структурой:
osmworld=# \d infrastructure_for_children_features
Table "public.infrastructure_for_children_features"
Column | Type |
-------------+------------|
id | integer |
cluster | integer |
district | text |
street | text |
housenumber | text |
feature | vector(11) |
Indexes:
"infrastructure_for_children_features_pkey" PRIMARY KEY, btree (id)
Теперь хотелось бы объединить их в группы по близости векторов. Опять же можно использовать нейросети, а можно использовать классические алгоритмы кластеризации — метод k‑средних(k‑means) или основанную на плотности пространственную кластеризацию для приложений с шумами (DBSCAN). Для метрики близости использую Евклидово расстояние между векторами из колонки feature.
Поскольку число кластеров мне не известно, то я выберу метод DBSCAN и прогоню этот набор данных через него чтобы посмотреть зависимость от epsilon числа групп и число элементов не попавших в группы:
eps|clusters|not_in_cluster
0.0 75 29667
0.5 75 29667
1.0 202 28648
1.5 475 26904
2.0 928 22630
2.5 1227 17620
3.0 1173 11778
3.5 856 7605
4.0 601 4562
4.5 364 2760
5.0 232 1574
5.5 138 972
6.0 77 604
6.5 51 377
7.0 29 265
7.5 14 168
8.0 10 96
8.5 4 54
9.0 4 37
9.5 2 30
На свой субъективный взгляд выберу eps=5.5 и запущу Java программу, которая заполнит колонку cluster значениями алгоритма DBSCAN для minPoints=3 и eps=5.5:
package com.github.isuhorukov;
import com.pgvector.PGvector;
import org.apache.commons.math3.ml.clustering.Cluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.DBSCANClusterer;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
public class Main {
public static void main(String[] args) throws Exception {
try (Connection connection = DriverManager.getConnection(
System.getenv("jdbc_url"), System.getenv("user"), System.getenv("password"))) {
connection.setAutoCommit(false);
PGvector.addVectorType(connection);
float eps = Float.parseFloat(System.getenv("eps"));
int minPoints = Integer.parseInt(System.getenv("minPoints"));
DBSCANClusterer<Feature> dbscanClusterer = new DBSCANClusterer<>(eps,minPoints,new EuclideanDistance());
List<Feature> features = fetchFeatures(connection,
"select id,feature from infrastructure_for_children_features");
List<Cluster<Feature>> cluster = dbscanClusterer.cluster(features);
saveClusters(connection, cluster);
}
}
private static void saveClusters(Connection connection, List<Cluster<Feature>> cluster) throws SQLException {
try (PreparedStatement clusterPs = connection.prepareStatement(
"update infrastructure_for_children_features set cluster = ? where id = ?")){
for (int idx = 0; idx < cluster.size(); idx++) {
List<Feature> featureCluster = cluster.get(idx).getPoints();
for (Feature feature : featureCluster) {
clusterPs.setInt(1, idx);
clusterPs.setInt(2, feature.id);
clusterPs.addBatch();
}
clusterPs.executeBatch();
}
connection.commit();
} catch (Exception e) {
connection.rollback();
throw new RuntimeException(e);
}
}
private static List<Feature> fetchFeatures(Connection connection, String query) {
List<Feature> features = new ArrayList<>();
try (Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery(query))
{
while (resultSet.next()) {
int id = resultSet.getInt(1);
float[] feature = ((PGvector) resultSet.getObject(2)).toArray();
features.add(new Feature(id, feature));
}
} catch (Exception e) {
throw new RuntimeException(e);
}
return features;
}
static class Feature implements Clusterable {
public int id;
public double[] feature;
public Feature(int id, float[] feature) {
this.id = id;
this.feature = new double[feature.length];
for (int i = 0; i < feature.length; i++) {
this.feature[i] = feature[i];
}
}
@Override
public double[] getPoint() {
return feature;
}
}
}
Для компиляции программы нужен pom.xml для maven:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.github.igor-suhorukov</groupId>
<artifactId>vectors</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
<dependency>
<groupId>com.pgvector</groupId>
<artifactId>pgvector</artifactId>
<version>0.1.3</version>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>42.6.0</version>
</dependency>
</dependencies>
</project>
Внутри этой программы происходит следующее:
Устанавливается соединение с базой данных используя URL соединения, имя пользователя и пароль, полученные из переменных окружения. Отключается режим автоматической фиксации транзакций соединения.
Регистрируется пользовательский тип вектора
PGvector
в драйвере данных.Вызывается метод
fetchFeatures
для извлечения списка характеристик из базы данных из созданной нами ранее таблицы infrastructure_for_children_features.Значения
eps
иminPoints
извлекаются из переменных окружения и преобразуются в типыfloat
иint
соответственно.Создается экземпляр
DBSCANClusterer
из commons-math3 с использованиемeps
,minPoints
и метрики расстояния (EuclideanDistance
). Вызывается методcluster
объектаDBSCANClusterer
для кластеризации записей, полученных ранее из базы в методеfetchFeatures
.Вызывается метод
saveClusters
для сохранения кластеров в базе данных в поле cluster таблицы infrastructure_for_children_features, учитывая номер кластера и id каждой записи из результатов алгоритма кластеризации. Если данные успешно записались, то транзакция фиксируется, в случае ошибки - транзакция откатывается.
Найдем в базе данных дом со скриншота по идентификатору:
osmworld=# select * from infrastructure_for_children_features where id=831;
id | cluster | district | street | housenumber | feature
-----+---------+-------------+--------------------+-------------------+-----------------------+-------------+---------------------------------
831 | 4 | Пресненский район | Большая Бронная улица | 19 | [37,28,3,23,6,4,17,23,9,148,92]
(1 row)
А теперь в PostgreSQL найдем 10 ближайших к нему домов по значению вектора:
osmworld=# select id, cluster, district,street,housenumber from infrastructure_for_children_features order by feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
id | cluster | district | street | housenumber
------+---------+-------------------+------------------------------+-------------
831 | 4 | Пресненский район | Большая Бронная улица | 19
1011 | 4 | Пресненский район | Большая Бронная улица | 17
897 | 4 | Пресненский район | Сытинский переулок | 5/10 с3
827 | 4 | Пресненский район | Богословский переулок | 8/15
1019 | 4 | Пресненский район | Большая Бронная улица | 16
823 | 4 | Тверской район | Малый Палашёвский переулок | 4
893 | 4 | Пресненский район | Большой Козихинский переулок | 4
631 | 4 | Пресненский район | Большая Бронная улица | 25 с3
821 | 4 | Пресненский район | Сытинский переулок | 5/10 с4
1117 | 4 | Пресненский район | Богословский переулок | 5
(10 rows)
Time: 25,777 ms
osmworld=# explain select id, cluster, district,street,housenumber from infrastructure_for_children_features order by feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
QUERY PLAN
--------------------------------------------------------------------------------------------------------
Limit (cost=2375.37..2375.40 rows=10 width=90)
-> Sort (cost=2375.37..2450.97 rows=30237 width=90)
Sort Key: ((feature <-> '[37,28,3,23,6,4,17,23,9,148,92]'::vector))
-> Seq Scan on infrastructure_for_children_features (cost=0.00..1721.96 rows=30237 width=90)
(4 rows)
То есть, можно найти в городе похожие дома на указанный в запросе, по доступной в окрестностях инфраструктуре.
Можно ли ускорить поиск по векторам? Да, расширение поддерживает индексы IVFFlat и HNSW. Попробуем HNSW он более быстрый и точный, если верить научным публикациям:
osmworld=# CREATE INDEX ON infrastructure_for_children_features USING hnsw (feature vector_l2_ops) WITH (m = 16, ef_construction = 64);
CREATE INDEX
Time: 12705,611 ms (00:12,706)
osmworld=# \d infrastructure_for_children_features
Table "public.infrastructure_for_children_features"
Column | Type | Collation | Nullable | Default
-------------+------------------+-----------+----------+---------
id | integer | | not null |
cluster | integer | | |
lon | double precision | | |
lat | double precision | | |
district | text | | |
street | text | | |
housenumber | text | | |
feature | vector(11) | | |
Indexes:
"infrastructure_for_children_features_pkey" PRIMARY KEY, btree (id)
"infrastructure_for_children_features_feature_idx" hnsw (feature vector_l2_ops) WITH (m='16', ef_construction='64')
osmworld=# explain select id, cluster, district,street,housenumber from infrastructure_for_children_features order by feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
QUERY PLAN
-----------------------------------------------------------------------------------------------------------------------------------------------------------
Limit (cost=5.00..5.61 rows=10 width=90)
-> Index Scan using infrastructure_for_children_features_feature_idx on infrastructure_for_children_features (cost=5.00..1861.36 rows=30237 width=90)
Order By: (feature <-> '[37,28,3,23,6,4,17,23,9,148,92]'::vector)
(3 rows)
Time: 1,022 ms
Запрос стал использовать этот индекс и поиск по векторам стал быстрее на порядок по сравнению с seqscan, и в этом конкретном примере угадал с параметрами для DBSCAN:
osmworld=# select id, cluster, district,street,housenumber from infrastructure_for_children_features order by feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
id | cluster | district | street | housenumber
------+---------+-------------------+------------------------------+-------------
831 | 4 | Пресненский район | Большая Бронная улица | 19
1011 | 4 | Пресненский район | Большая Бронная улица | 17
827 | 4 | Пресненский район | Богословский переулок | 8/15
897 | 4 | Пресненский район | Сытинский переулок | 5/10 с3
1019 | 4 | Пресненский район | Большая Бронная улица | 16
823 | 4 | Тверской район | Малый Палашёвский переулок | 4
893 | 4 | Пресненский район | Большой Козихинский переулок | 4
631 | 4 | Пресненский район | Большая Бронная улица | 25 с3
821 | 4 | Пресненский район | Сытинский переулок | 5/10 с4
1117 | 4 | Пресненский район | Богословский переулок | 5
(10 rows)
Time: 1,644 ms
Вывод
Расширение pgvector PostgreSQL оказалось простым в использовании и с ним можно работать не только алгоритмами машинного обучения, но и классическими алгоритмами кластеризации из программы по JDBC, а так же быстро искать используя поиск по близости векторов и специализированный индекс HNSW.