Pull to refresh

Пишем умный поиск по коду с Open AI

Level of difficultyMedium
Reading time6 min
Views9K

В этой статье мы кратко рассмотрим технологию, которая лежит в основе ChatGPT — эмбеддинги, и напишем простой интеллектуальный поиск по кодовой базе проекта.

Эмбеддинг (от англ. embedding) — это процесс преобразования слов или текста в набор чисел – числовой вектор. Векторы можно сравнивать между собой, чтобы определить насколько два текста или слова похожи по смыслу.

К примеру, возьмем два числовых вектора (эмбеддинга) слов «отдать» и «подарить». Слова разные, но смысл схож, т.е. они взаимосвязаны, и результатом обоих будет передача чего-то кому-то. 

В контексте кодовой базы проекта эмбеддинги можно использовать для поиска по коду или документации. Например, можно векторизовать функции, классы, методы и документацию, а затем сравнивать их векторы с вектором запроса поиска, чтобы находить релевантные функции или классы.

Нам понадобится аккаунт Open AI и токен. Если у вас еще нет аккаунта, то можете зарегистрироваться на официальном сайте Open AI. После регистрации и подтверждения аккаунта пройдите в разделе профиля API Keys и сгенерируйте API токен.

На старте дают $18 — мне этого хватило, чтобы сделать пример для этой статьи (ниже) и провести дальнейшее тестирование сервиса.

Выберите какой-нибудь проект на TypeScript в качестве кодовой базы. Рекомендую взять небольшой, чтобы не томить себя в ожиданиях генерации векторов. Или можете воспользоваться уже готовым. Еще нам нужен Python 3+ версии и библиотека от Open AI. Не пугайтесь, если не знаете какой-то язык — примеры будут простыми и не требуют глубокого понимания TypeScript и Python.

Приступим. Для начала напишем код для извлечения различных фрагментов кода из проекта, например, функции. TypeScript предоставляет удобный API компилятор для работы с AST деревом, что упрощает задачу. Установим csv-stringify библиотеку для генерации CSV:

$ npm install csv-stringify

Пишем извлечение информации из кода:

const path = require('path');
const ts = require('typescript');
const csv = require('csv-stringify/sync');
 
const cwd = process.cwd();
const configJSON = require(path.join(cwd, 'tsconfig.json'));
const config = ts.parseJsonConfigFileContent(configJSON, ts.sys, cwd);
const program = ts.createProgram(
    config.fileNames, 
    config.options, 
    ts.createCompilerHost(config.options)
);
const checker = program.getTypeChecker();

const rows = [];

const addRow = (fileName, name, code, docs = '') => rows.push({
    file_name: path.relative(cwd, fileName),
    name,
    code,
    docs
});

function addFunction(fileName, node) {
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = symbol.getName();
        const docs = getDocs(symbol);
        const code = node.getText();
        addRow(fileName, name, code, docs);
    }
}

function addClass(fileName, node) {
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = symbol.getName();
        const docs = getDocs(symbol);
        const code = `class ${name} {}`;
        addRow(fileName, name, code, docs);
        node.members.forEach(m => addClassMember(fileName, name, m));
    }
}

function addClassMember(fileName, className, node) {
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = className + ':' + symbol.getName();
        const docs = getDocs(symbol);
        const code = node.getText();
        addRow(fileName, name, code, docs);
    }
}

function addInterface(fileName, node) {
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = symbol.getName();
        const docs = getDocs(symbol);
        const code = `interface ${name} {}`;
        addRow(fileName, name, code, docs);
        node.members.forEach(m => addInterfaceMember(fileName, name, m));
    }
}

function addInterfaceMember(fileName, interfaceName, node) {
    if (!ts.isPropertySignature(node) || !ts.isMethodSignature(node)) {
        return;
    }
    const symbol = checker.getSymbolAtLocation(node.name);
    if (symbol) {
        const name = interfaceName + ':' + symbol.getName();
        const docs = getDocs(symbol);
        const code = node.getText();
        addRow(fileName, name, code, docs);
    }
}

function getDocs(symbol) {
    return ts.displayPartsToString(symbol.getDocumentationComment(checker));
}

for (const fileName of config.fileNames) {
    const sourceFile = program.getSourceFile(fileName);
    const visitNode = node => {
        if (ts.isFunctionDeclaration(node)) {
            addFunction(fileName, node);
        } else if (ts.isClassDeclaration(node)) {
            addClass(fileName, node);
        } else if (ts.isInterfaceDeclaration(node)) {
            addInterface(fileName, node);
        }
        ts.forEachChild(node, visitNode);
    };
    ts.forEachChild(sourceFile, visitNode);
}

for (const row of rows) {
    row.combined = '';
    if (row.docs) {
        row.combined += `Code documentation: ${row.docs}; `;
    }
    row.combined += `Code: ${row.code}; Name: ${row.name};`;
}

const output = csv.stringify(rows, {
    header: true
});

console.log(output);

Скрипт собирает все нужные нам фрагменты и выдает CSV таблицу в консоль. CSV файл состоит из колонок file_name, name, code, docs, combined.

  • file_name - здесь будет содержаться путь до файла в проекте,

  • name - название фрагмента, к примеру, «имя функции»,

  • code - код сущности,

  • docs - описание из комментариев к фрагменту,

  • combined - это сложение контента code и docs колонок — мы будем использовать эту колонку для генерации векторов.

Запускать его не нужно.

Переходим к Python.

Установим библиотеку от Open AI и утилиты для работы с эмбеддингами:

$ pip install openai[embeddings]

Создаем файл create_search_db.py со следующим кодом:

from io import StringIO
from subprocess import PIPE, run
from pandas import read_csv
from openai.embeddings_utils import get_embedding as _get_embedding
from tenacity import wait_random_exponential, stop_after_attempt

get_embedding = _get_embedding.retry_with(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(10))

if __name__ == '__main__':
	# 1
	result = run(['node', 'code-to-csv.js'], stdout=PIPE, stderr=PIPE, universal_newlines=True)
	if result.returncode != 0:
	    raise RuntimeError(result.stderr)
	# 2
	db = read_csv(StringIO(result.stdout))
	# 3
	db['embedding'] = db['combined'].apply(lambda x: get_embedding(x, engine='text-embedding-ada-002'))
	# 4
	db.to_csv("search_db.csv", index=False)

Скрипт запускается code-to-csv.js(1), загружается результат в датафрейм(2) и генерирует векторы для контента в колонке combined(3). Векторы записываются в embedding колонку. Итоговая таблица со всем нужным для поиска сохраняется в файл search_db.csv(4).

Для работы скрипта понадобится API токен. Библиотека openai автоматически может брать токен из переменных окружения, поэтому мы напишем удобный скрипт, чтобы не записывать токен в окружение вручную:

export OPENAI_API_KEY=ВашТокен

Сохранить его куда-нибудь, к примеру в env.sh, и запустим:

$ source env.sh

Все готово для генерации базы поиска.

Запускаем скрипт create_search_db.py и ждем пока появится CSV файл с базой. Это может занять пару минут. После можно приступать к написанию поисковика.

Создаем новый файл search.py и пишем следующее:

import sys
import numpy as np
from pandas import read_csv
from openai.embeddings_utils import cosine_similarity, get_embedding as _get_embedding
from tenacity import  stop_after_attempt, wait_random_exponential

get_embedding = _get_embedding.retry_with(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(10))

def search(db, query):
	# 4
    query_embedding = get_embedding(query, engine='text-embedding-ada-002')
	# 5
    db['similarities'] = db.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
    # 6
	db.sort_values('similarities', ascending=False, inplace=True)
    result = db.head(3)
    text = ""
    for row in result.itertuples(index=False):
        score=round(row.similarities, 3)
        if type(row.docs) == str:
            text += '/**\n * {docs}\n */\n'.format(docs='\n * '.join(row.docs.split('\n')))
        text += '{code}\n\n'.format(code='\n'.join(row.code.split('\n')[:7]))
        text += '[score={score}] {file_name}:{name}\n'.format(score=score, file_name=row.file_name, name=row.name)
        text += '-' * 70 + '\n\n'
    return text

if __name__ == '__main__':
	# 1
    db = read_csv('search_db.csv')
	# 2
    db['embedding'] = db.embedding.apply(eval).apply(np.array)
    query = sys.argv[1]
    print('')
	# 3
    print(search(db, query))

Разберем работу скрипта. Данные из search_db.csv загружаются в датафрейм(1), в объектно-ориентированное представление таблицы. Потом строки с векторами из таблицы конвертируются в массивы с числами(2), чтобы с ними можно было работать. В конце запускается функция поиска по базе со строкой запроса(3).

Функция поиска генерирует вектор для запроса(4), сравнивает этот вектор с каждым вектором из базы и сохраняет степень схожести в similarities колонку(5).

Степень схожести определяется числом от 0 до 1, где 1 означает максимальная подходящий вариант. Данные в таблице сортируются по similarities(6).

В заключении извлекаются первые три строки из базы и выводятся в консоль.

Поисковик готов, можно протестировать.

Для теста запускаем команду с запросом:

Пробуем ввести запрос на другом языке:

Как вы видите, поиск осуществляется с учетом значения слов в запросе, а не просто по ключевым словам.

Инструмент не ограничен только этим случаем и одним проектом. Можно организовать более масштабный поиск по всем проектам сразу. Это полезно, если вы каждый год разрабатываете по нескольку схожих в функционале приложений и хотели бы быстро находить готовые решения. Или у вас много документации, а поиск по ключевым словам не всегда эффективен. Все зависит от задач и сферы применения.

Благодарю за внимание!

Полезные ссылки:

Tags:
Hubs:
Total votes 10: ↑9 and ↓1+10
Comments3

Articles