Pull to refresh

На грани ИИ: пример поиска и обработки векторов в PostgreSQL + pgvector

Level of difficultyMedium
Reading time9 min
Views10K

На Хабре было много упоминаний pgvector в обзорах Postgresso. И каждый раз новость была про место которое где‑то за границей и далеко. Многие коммерческие решения для хранения и поиска векторов в базе данных нынче не доступны, а pgvector доступен любому, тем более в самой популярной базе в России. Применим pgvector для задачи поиска похожих домов по инфраструктуре для детей в Москве.

В этой статье покажу на этом практическом примере как хранить, кластеризовать алгоритмом DBSCANвекторы и искать по ним в базе данных. В примере задача с векторами на грани типичного хранения и обработки результатов работы нейросетевых моделей в базе данных.

Прежде всего надо установить pgvector в PostgreSQL, он доступен в виде расширения. Поскольку я работаю с базой данных из Docker, то могу просто добавить в Dockerfile строчки и пересобрать образ:

RUN git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
RUN cd pgvector && make && make install

А в самой базе данных, нужно загрузить расширение:

osmworld=# CREATE EXTENSION vector;
CREATE EXTENSION
Time: 32,606 ms

Данные для векторов можно получить, например, из модели машинного обучения в python скрипте или ML модели в spark и вставить в таблицу с колонкой типа vector. А можно создать в SQL как гистограмму определенных категорий. В этом случае можно значения в массивах real[], integer[] или double precision[], numeric[] привести к типу ::vector.

Данными для примера послужат гистограммы числа объектов детской инфраструктуры в окрестностях жилых домов в Москве. Про то как рассчитать эти данные я рассказывал здесь, а в этой публикации я просто возьму готовые данные и создам из них таблицу с колонкой типа одинадцатимерный vector:

create table infrastructure_for_children_features as 
  select (row_number() over ())::integer id, null::integer cluster,
          district, street, housenumber,
          ARRAY[kindergarten::integer, school::integer,college::integer, university::integer, language_school::integer, music_school::integer,training::integer,sports_centre::integer,community_centre::integer,playground::integer,clinic::integer]
   ::vector(11) feature 
from infrastructure_for_children;
alter table infrastructure_for_children_features add primary key(id);

Так в базе создал таблицу на 30237 записей со структурой:

osmworld=# \d infrastructure_for_children_features
   Table "public.infrastructure_for_children_features"
   Column    |    Type    | 
-------------+------------|
 id          | integer    | 
 cluster     | integer    | 
 district    | text       | 
 street      | text       | 
 housenumber | text       | 
 feature     | vector(11) |
Indexes:
    "infrastructure_for_children_features_pkey" PRIMARY KEY, btree (id)

Теперь хотелось бы объединить их в группы по близости векторов. Опять же можно использовать нейросети, а можно использовать классические алгоритмы кластеризации — метод k‑средних(k‑means) или основанную на плотности пространственную кластеризацию для приложений с шумами (DBSCAN). Для метрики близости использую Евклидово расстояние между векторами из колонки feature.

Поскольку число кластеров мне не известно, то я выберу метод DBSCAN и прогоню этот набор данных через него чтобы посмотреть зависимость от epsilon числа групп и число элементов не попавших в группы:

eps|clusters|not_in_cluster
0.0	75	29667
0.5	75	29667
1.0	202	28648
1.5	475	26904
2.0	928	22630
2.5	1227	17620
3.0	1173	11778
3.5	856	7605
4.0	601	4562
4.5	364	2760
5.0	232	1574
5.5	138	972
6.0	77	604
6.5	51	377
7.0	29	265
7.5	14	168
8.0	10	96
8.5	4	54
9.0	4	37
9.5	2	30

На свой субъективный взгляд выберу eps=5.5 и запущу Java программу, которая заполнит колонку cluster значениями алгоритма DBSCAN для minPoints=3 и eps=5.5:

package com.github.isuhorukov;

import com.pgvector.PGvector;
import org.apache.commons.math3.ml.clustering.Cluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.DBSCANClusterer;
import org.apache.commons.math3.ml.distance.EuclideanDistance;

import java.sql.*;
import java.util.ArrayList;
import java.util.List;

public class Main {
    public static void main(String[] args) throws Exception {
        try (Connection connection = DriverManager.getConnection(
                System.getenv("jdbc_url"), System.getenv("user"), System.getenv("password"))) {
            connection.setAutoCommit(false);
            PGvector.addVectorType(connection);
            float eps = Float.parseFloat(System.getenv("eps"));
            int minPoints = Integer.parseInt(System.getenv("minPoints"));
            DBSCANClusterer<Feature> dbscanClusterer = new DBSCANClusterer<>(eps,minPoints,new EuclideanDistance());
            List<Feature> features = fetchFeatures(connection, 
                                                   "select id,feature from infrastructure_for_children_features");
            List<Cluster<Feature>> cluster = dbscanClusterer.cluster(features);
            saveClusters(connection, cluster);
        }
    }

    private static void saveClusters(Connection connection, List<Cluster<Feature>> cluster) throws SQLException {
        try (PreparedStatement clusterPs = connection.prepareStatement(
             "update infrastructure_for_children_features set cluster = ? where id = ?")){
            for (int idx = 0; idx < cluster.size(); idx++) {
                List<Feature> featureCluster = cluster.get(idx).getPoints();
                for (Feature feature : featureCluster) {
                    clusterPs.setInt(1, idx);
                    clusterPs.setInt(2, feature.id);
                    clusterPs.addBatch();
                }
                clusterPs.executeBatch();
            }
            connection.commit();
        } catch (Exception e) {
            connection.rollback();
            throw new RuntimeException(e);
        }
    }

    private static List<Feature> fetchFeatures(Connection connection, String query) {
        List<Feature> features = new ArrayList<>();
        try (Statement statement = connection.createStatement();
             ResultSet resultSet = statement.executeQuery(query))
        {
            while (resultSet.next()) {
                int id = resultSet.getInt(1);
                float[] feature = ((PGvector) resultSet.getObject(2)).toArray();
                features.add(new Feature(id, feature));
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return features;
    }

    static class Feature implements Clusterable {
        public int id;
        public double[] feature;
        public Feature(int id, float[] feature) {
            this.id = id;
            this.feature = new double[feature.length];
            for (int i = 0; i < feature.length; i++) {
                this.feature[i] = feature[i];
            }
        }

        @Override
        public double[] getPoint() {
            return feature;
        }
    }
}

Для компиляции программы нужен pom.xml для maven:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.github.igor-suhorukov</groupId>
    <artifactId>vectors</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-math3</artifactId>
            <version>3.6.1</version>
        </dependency>
        <dependency>
            <groupId>com.pgvector</groupId>
            <artifactId>pgvector</artifactId>
            <version>0.1.3</version>
        </dependency>
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <version>42.6.0</version>
        </dependency>
    </dependencies>
</project>

Внутри этой программы происходит следующее:

  • Устанавливается соединение с базой данных используя URL соединения, имя пользователя и пароль, полученные из переменных окружения. Отключается режим автоматической фиксации транзакций соединения.

  • Регистрируется пользовательский тип вектора PGvector в драйвере данных.

  • Вызывается метод fetchFeatures для извлечения списка характеристик из базы данных из созданной нами ранее таблицы infrastructure_for_children_features.

  • Значения eps и minPoints извлекаются из переменных окружения и преобразуются в типы float и int соответственно.

  • Создается экземпляр DBSCANClusterer из commons-math3 с использованием epsminPoints и метрики расстояния (EuclideanDistance). Вызывается метод cluster объекта DBSCANClusterer для кластеризации записей, полученных ранее из базы в методеfetchFeatures.

  • Вызывается метод saveClusters для сохранения кластеров в базе данных в поле cluster таблицы infrastructure_for_children_features, учитывая номер кластера и id каждой записи из результатов алгоритма кластеризации. Если данные успешно записались, то транзакция фиксируется, в случае ошибки - транзакция откатывается.

Чем же похожи эти районы, еще стоит выяснить или попробовать другие epsilon
Чем же похожи эти районы, еще стоит выяснить или попробовать другие epsilon
Один дом выбранный наугад в центре города
Один дом выбранный наугад в центре города

Найдем в базе данных дом со скриншота по идентификатору:

osmworld=# select * from infrastructure_for_children_features where id=831;
 id  | cluster |     district      |        street         | housenumber |             feature             
-----+---------+-------------+--------------------+-------------------+-----------------------+-------------+---------------------------------
 831 |       4 | Пресненский район | Большая Бронная улица | 19          | [37,28,3,23,6,4,17,23,9,148,92]
(1 row)

А теперь в PostgreSQL найдем 10 ближайших к нему домов по значению вектора:

osmworld=# select id, cluster, district,street,housenumber from infrastructure_for_children_features order by  feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
  id  | cluster |     district      |            street            | housenumber 
------+---------+-------------------+------------------------------+-------------
  831 |       4 | Пресненский район | Большая Бронная улица        | 19
 1011 |       4 | Пресненский район | Большая Бронная улица        | 17
  897 |       4 | Пресненский район | Сытинский переулок           | 5/10 с3
  827 |       4 | Пресненский район | Богословский переулок        | 8/15
 1019 |       4 | Пресненский район | Большая Бронная улица        | 16
  823 |       4 | Тверской район    | Малый Палашёвский переулок   | 4
  893 |       4 | Пресненский район | Большой Козихинский переулок | 4
  631 |       4 | Пресненский район | Большая Бронная улица        | 25 с3
  821 |       4 | Пресненский район | Сытинский переулок           | 5/10 с4
 1117 |       4 | Пресненский район | Богословский переулок        | 5
(10 rows)

Time: 25,777 ms

osmworld=# explain select id, cluster, district,street,housenumber from infrastructure_for_children_features order by  feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
                                               QUERY PLAN                                               
--------------------------------------------------------------------------------------------------------
 Limit  (cost=2375.37..2375.40 rows=10 width=90)
   ->  Sort  (cost=2375.37..2450.97 rows=30237 width=90)
         Sort Key: ((feature <-> '[37,28,3,23,6,4,17,23,9,148,92]'::vector))
         ->  Seq Scan on infrastructure_for_children_features  (cost=0.00..1721.96 rows=30237 width=90)
(4 rows)

То есть, можно найти в городе похожие дома на указанный в запросе, по доступной в окрестностях инфраструктуре.

Можно ли ускорить поиск по векторам? Да, расширение поддерживает индексы IVFFlat и HNSW. Попробуем HNSW он более быстрый и точный, если верить научным публикациям:

osmworld=# CREATE INDEX ON infrastructure_for_children_features USING hnsw (feature vector_l2_ops) WITH (m = 16, ef_construction = 64);
CREATE INDEX
Time: 12705,611 ms (00:12,706)

osmworld=# \d infrastructure_for_children_features
       Table "public.infrastructure_for_children_features"
   Column    |       Type       | Collation | Nullable | Default 
-------------+------------------+-----------+----------+---------
 id          | integer          |           | not null | 
 cluster     | integer          |           |          | 
 lon         | double precision |           |          | 
 lat         | double precision |           |          | 
 district    | text             |           |          | 
 street      | text             |           |          | 
 housenumber | text             |           |          | 
 feature     | vector(11)       |           |          | 
Indexes:
    "infrastructure_for_children_features_pkey" PRIMARY KEY, btree (id)
    "infrastructure_for_children_features_feature_idx" hnsw (feature vector_l2_ops) WITH (m='16', ef_construction='64')

osmworld=# explain select id, cluster, district,street,housenumber from infrastructure_for_children_features order by  feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
                                                                        QUERY PLAN                                                                         
-----------------------------------------------------------------------------------------------------------------------------------------------------------
 Limit  (cost=5.00..5.61 rows=10 width=90)
   ->  Index Scan using infrastructure_for_children_features_feature_idx on infrastructure_for_children_features  (cost=5.00..1861.36 rows=30237 width=90)
         Order By: (feature <-> '[37,28,3,23,6,4,17,23,9,148,92]'::vector)
(3 rows)
Time: 1,022 ms

Запрос стал использовать этот индекс и поиск по векторам стал быстрее на порядок по сравнению с seqscan, и в этом конкретном примере угадал с параметрами для DBSCAN:

osmworld=# select id, cluster, district,street,housenumber from infrastructure_for_children_features order by  feature<-> '[37,28,3,23,6,4,17,23,9,148,92]' limit 10;
  id  | cluster |     district      |            street            | housenumber 
------+---------+-------------------+------------------------------+-------------
  831 |       4 | Пресненский район | Большая Бронная улица        | 19
 1011 |       4 | Пресненский район | Большая Бронная улица        | 17
  827 |       4 | Пресненский район | Богословский переулок        | 8/15
  897 |       4 | Пресненский район | Сытинский переулок           | 5/10 с3
 1019 |       4 | Пресненский район | Большая Бронная улица        | 16
  823 |       4 | Тверской район    | Малый Палашёвский переулок   | 4
  893 |       4 | Пресненский район | Большой Козихинский переулок | 4
  631 |       4 | Пресненский район | Большая Бронная улица        | 25 с3
  821 |       4 | Пресненский район | Сытинский переулок           | 5/10 с4
 1117 |       4 | Пресненский район | Богословский переулок        | 5
(10 rows)

Time: 1,644 ms

Вывод

Расширение pgvector PostgreSQL оказалось простым в использовании и с ним можно работать не только алгоритмами машинного обучения, но и классическими алгоритмами кластеризации из программы по JDBC, а так же быстро искать используя поиск по близости векторов и специализированный индекс HNSW.

Tags:
Hubs:
Total votes 12: ↑12 and ↓0+12
Comments8

Articles