Статья является продолжением «Пишем агента на Kotlin: KOSMOS», но может читаться независимо. Мотивация к написанию — сохранить читателю время на возню с фреймворками для решения относительно простой задачи.
Автор подразумевает у читателя теоретическое понимание того, что такое агент. Иначе лучше прочесть хотя бы начало предыдущей части.
Как и везде, в программирование важен маркетинг, поэтому обертку над HTTP-запросами в цикле называют революцией:
From Python to Kotlin: How JetBrains Revolutionized AI Agent Development. — reddit, medium.
Но в этом нет ничего революционного. Ниже хочу показать, как самостоятельно написать аналог Koog или Langchain4j. У вас не будет всех их фичей, зато будет очень простая и расширяемая система.
Содержание
Введение
Проблемы использования фреймворков
- Мета проблемы
- Сложная ментальная модель
- Запутанный синтаксис
Реализация агента на основе графов
- Упрощенная реализация Агента на основе графов
- Детальная реализация Агента на основе графов
- Добавление RAG
Когда использовать фреймворк, а не самописное решение?
Предисловие для тех, кто читал первую статью
Единственная часть, которую было сложно расширять и поддерживать в прошлой статье — сам агент. Тут мы напишем такое решение, чтобы агент собирался, как конструктор, и чтобы любую часть можно было легко вынести и переиспользовать в других агентах.
Если нет времени читать, можете глянуть PR с реализацией агента на графах и PR с добавлением RAG.
Предисловие для мобильных разработчиков
С 2015, когда я начал карьеру, постоянно появляются библиотеки-решения для организации UI-архитектуры на основе паттернов: MVC, MVP, MVVM, TEA, VIPER, Flux/Redux. Я пробовал все паттерны из перечисленных и смело могу сказать, что особой разницы нет, пока вся команда придерживается одного подхода. Но каждый раз использование чьей-то библиотеки приводило к страданиям. Потому что находился кейс, который библиотека упускала и не давала решить легко. Были баги, которые не починить без форка. Код библиотек запутан и сложен. Всегда проще было самому написать основу и жить с ней.
То же самое и с фреймворками по написанию ИИ-агентов. Лучше разберитесь с основами и напишите под себя легковесное решение, которое вы будете понимать и контролировать.
Предисловие для бэкендеров
Я встречал java-бэкенд разработчиков, которые несколько лет работали с Postgres, но не знают, как взять лок. Разгадка — фреймворк Hibernate. Такой разработчик может ходить в базу в цикле или внутри транзакций к базе выполнять HTTP-запросы. Конечно, есть исключения, но при прочих равных, фреймворк лишает понимания, задерживает развитие, способствует написанию неоптимального кода и даже негативно влияет на экологию (сколько энергии было потрачено на разработку и компиляцию этого фреймворка? SW).
То же самое и с фреймворками по написанию ИИ-агентов. Под капотом происходит всего лишь вызов нескольких ручек HTTP, а ваш фреймворк падает с out of memory.
Проблемы использования фреймворков
Приведу несколько соображений на примере Koog. Если вы и так понимаете, что фреймворк — это усложнение на ровном месте, пропускайте.
Frameworks are one of the hugest anti-patterns in software development. — Peter Krumins
Мета проблемы
Существуют проблемы, не относящиеся к сложности непосредственного использования API фреймворка. Вот несколько примеров:
Мы взяли Koog в KMP-проект, и он сломался о колено по конфликту версий kotlinx-datetime (issue, PR висит с начал сентября). Это решить сложнее, чем кажется, так как апи kotlinx-datetime еще в альфе и меняется, не заботясь об обратной совместимости.
Посмотрите, сколько всего вы затащите с Koog — libs.versions.toml. Всё что начинается с «0.», а это 7 библиотек на момент написания статьи, может пойти по пути из пункта 1.
Чтобы использовать Гигачат, придется писать клиента самому. Гигачат умеет работать с примерами для функций (тулов) — few_shot_examples. У Koog ни Annotation-based tools, ни Class-based tools это не умеют.
Из-за того, что фреймворк пытается поддерживать все популярные LLM, он допускает случайные ошибки в конкретных реализациях. В примере до этого фреймворк не учел возможности некоторых API, вроде Гигачат. В других случаях он просто падает в рантайме (Issue).
Фреймворки могут скрывать грязные приемы внутри. Вот тут Koog подхачивает промпты пользователя, чтобы стратегия работала, как полагается: «Don't chat with plain text! Call one of the available tools, instead: ${tools.joinToString(", ")». Я не хочу, чтобы фреймворк менял промпты, потому что это и так делает API LLM. Видимо, тут следствие проблемы из пункта 4.
Сложная ментальная модель
Агент — это про цикл взаимодействия между LLM, пользователем системы и вызовом функций (тулов). Где-то рядом можно прикрутить трейсинг, кеши и RAG. И в общем-то всё.
Но Koog настолько сложен, что падает с OOM на компиляции (Issue).
Если вы захотите быстренько разобраться с Koog, вам придется погрузиться в их концепцию и терминологию:
Agent, LLM, Message, Prompt, Attachment, System prompt, Context, Session,
Event, EventContext, EventHandler,
OpenTelemetry, LoggingSpanExporter, Sampler,
AgentMemory, Concept, Fact, FactType, Subject,
ToolArgs, ToolSet, Class-based tool, Function-based tool
Strategy, Graph, Subgraph, Node, Edge, Conditions
LLMEmbedder, JVMTextDocumentEmbedder, EmbeddingBasedDocumentStorage
McpTool, McpToolDescriptorParser, McpToolRegistryProvider, ProcessBuilder
Модель можно значительно упростить:
Всё, что касается эвентов, можно реализовать на стороне пользователя. Нужен лишь способ передать callback на события перехода между нодами графа.
Зачем нам и граф, и сабграф. Уже выглядит грязно, ведь хочется, чтобы любой граф мог выступить в качестве сабграфа.
Зачем и граф, и стратегия? Можно было бы оставить только граф.
Всё что касается памяти, пользователь может решить сам — дополнить контекст в своем туле или в ноде графа.
Запутанный синтаксис
Вот пример создания ребра графа между sourceNode и targetNode в Koog.
edge(sourceNode forwardTo targetNode onCondition {input -> input.length > 10})
Что не так с этим кодом?
Мы не можем читать код слева направо. До того, как выполнится
forwardTo, запускаетсяonCondition.Мы не можем доверять тому, что infix функции выполнятся в том порядке, в котором мы ожидаем. Смысла в infix-функциях тут нет никакого.
Излишний синтаксис. Зачем нам и
edge, иforwardTo— это бесцельное дублирование.
На мой взгляд, код ниже читается лучше:
sourceNode.edgeTo {input -> if (input.length > 10) transformerNode else null}
Реализация агента на основе графов
Цикл работы агента ровно такой же, как и в прошлой статье (рисунок из нее), но реализация будет на основе графов.
По горячим следам прошлого параграфа, давайте сделаем набросок того, как должен выглядеть агент:
val agent = buildGraph { nodeInput.edgeTo(nodeLLM) nodeLLM.edgeTo { context -> if (context.isToolUse()) nodeToolUse else nodeFinish } nodeToolUse.edgeTo(nodeLLM) }
Каждый Node должен уметь принять Input и вернуть Output. Например, nodeLLM может выглядеть как-то так:
val nodeLLM = suspend fun (input: String, context: AgentContext): Pair { val request = buildRequestFrom(input, context) // Запрос к АПИ val response = GigaHttpClient.chat(request) // Добавление истории в контекст val newContext = context.copyWith(appendToHistory = response.messages) return response.messages.last to newContext }
А nodeInput — это просто ожидание System.in от юзера:
val nodeInput = suspend fun (input: String, context: AgentContext): Pair { println("> ") val userMessage = readlnOrNull() val newContext = context.copyWith(appendToHistory = userMessage) return userMessage to newContext }
В AgentContext мы можем положить всё, что нужно между нодами. Например, историю общения с агентом и предыдущий output.
Цикл общения с агентом можно построить двумя способами. Первый — сохранять контекст вовне:
val agent = ... var seed = AgentContext("Агент готов") // тут будет сохраняться история while (true) { val result = graph.start(seed) println("Agent said: $result") seed = agent.currentContext }
Второй способ — можно сделать граф зацикленным (заменить nodeFinish на nodeInput в buildGraph выше).
Пока всё должно выглядеть очень просто, давайте реализуем Node.
Упрощенная реализация Агента на основе графов
Давайте напишем первую реализацию, чтобы нащупать нужные абстракции.
В контексте агента важны input и history — ввод для вершины графа (Node) и история. Удобно иметь историю в контексте. Например, если агент упадет, мы можем запустить нового с уже имеющейся историей. Всю мутабельность можно спрятать на этом уровне в будущем.
data class AgentContext( val input: String, val history: List )
Контекст меняется, переходя от одной вершины графа к другой. Давайте опишем вершины и переходы:
interface Node { val name: String // для логов suspend fun execute(ctx: AgentContext): AgentContext } /** Create new [Node] implementation based on [op] */ fun Node( name: String, op: suspend (AgentContext) -> AgentContext, ): Node = object : Node { override val name: String = "Node $name; ${Integer.toHexString(hashCode())}" override suspend fun execute(ctx: AgentContext) = op(ctx) } /** Ребра графа */ sealed interface Transition { class Static(val target: Node) : Transition class Dynamic(val router: suspend (AgentContext) -> Node) : Transition }
Ребра (Transition) могут быть статическими и динамическими. Пример динамического перехода был в предыдущем разделе:
nodeLLM.edgeTo { context -> if (context.isToolUse()) nodeToolUse else nodeFinish }
Теперь надо решить, где хранить ребра графа (переходы). Не хочется, чтобы Node был мутабельным — бывшие Clojure-коллеги не поймут. Да и сами посудите, вдруг понадобится переиспользовать один Node в разных графах. Мутабельность всё испортит.
Пусть пока мутабельным будет Graph, чуть позже мы проведем рефакторинг:
class Graph { val transitions = HashMap<Node, ArrayList<Transition>>() val nodeEnter: Node = Node("enter") { it } fun Node.edgeTo(target: Node): Node { registerTransition(this, Transition.Static(target)) return target } fun Node.edgeTo(router: suspend (AgentContext) -> Node) { registerTransition(this, Transition.Dynamic(router)) } private fun registerTransition(from: Node, transition: Transition) { val bucket = transitions.getOrPut(from) { arrayListOf() } bucket += transition } }
Теперь мы можем создать агента:
suspend fun main() { val nodeInput = Node("NodeInput") { ctx -> val userMessage = readln() ctx.copy( input = userMessage, history = ArrayList(ctx.history).apply { add(userMessage) } ) } val nodeLLM = Node("NodeLLM") { ctx -> val llmResult = "I can't do much, just a mock" ctx.copy( input = llmResult, history = ArrayList(ctx.history).apply { add(llmResult) } ) } val agent = Graph().apply { nodeEnter.edgeTo(nodeInput) nodeInput.edgeTo(nodeLLM) nodeLLM.edgeTo(nodeEnter) } /* agent.run(AgentContext("start")) { node, ctx -> println(node.name + ": " + ctx.input) } */ }
Чтобы граф можно было «запустить», реализуем функцию Graph.run(я быврал BFS, можно в будущем это контралировать через контекст):
suspend fun Graph.nextNodes(node: Node, ctx: AgentContext): List<Node> { val registered = transitions[node] as? List<Transition> ?: emptyList() if (registered.isEmpty()) return emptyList() val next = ArrayList<Node>(registered.size) for (transition in registered) { when (transition) { is Transition.Static -> next.add(transition.target) is Transition.Dynamic -> next.add(transition.router(ctx)) } } return next } suspend fun Graph.run( seed: AgentContext, onStep: (Node, AgentContext) -> Unit ): AgentContext { val queue = ArrayDeque<Pair<Node, AgentContext>>() .apply { add(nodeEnter to seed) } var lastCtx: AgentContext = seed while (queue.isNotEmpty() && currentCoroutineContext().isActive) { val (node, ctx) = queue.removeFirst() val outCtx = node.execute(ctx) onStep(node, outCtx) lastCtx = outCtx val nextNodes = nextNodes(node, outCtx) if (nextNodes.isNotEmpty()) { for (child in nextNodes) { queue.add(child to outCtx) } } } return lastCtx }
Вот и всё, мы реализовали core Koog. Можем запускаться и смотреть на результат.
Детальная реализация Агента на основе графов
Решение выше — всего лишь набросок, но уже функциональный. Внутри Node можно реализовать что угодно — например, построение другого графа или даже нескольких графов, которые можно запустить параллельно через обычный async. Имея callback в функции Graph.run, мы можем повесить метрики. Sequence abstraction (map, filter, reduce) легко реализуется через Node и Transition.
Чего не хватает:
Где-то нужно хранить тулы, system prompt, текущую модель и т.п..
Нет обработки ошибок и retry.
Нельзя использовать граф как сабграф (Node).
Отсутствие полиморфизма (параметрического, т.е. нет дженериков). Не переводить же String input в Json и обратно в String на каждом шагу.
Реализация графа не иммутабельна, а детали реализации торчат наружу (нет инкапсуляции).
Как будем решать?
Все настройки можно положить в AgentContext.
Вынесем абстракцию GraphRunner, которая будет думать о retry. Контекст запуска будем хранить в отдельной сущности GraphRuntime.
Граф реализует интерфейс Node.
Сделаем input в AgentContext дженериком. А историю можем хранить в DTO моделях Гигачата. Понадобится другая LLM-модель — напишем конвертер из Гигачат-моделей в целевые.
Вынесем создание графа в билдер, а граф оставим иммутабельным. Описание графа будет доступно в GraphRuntime.
Прежде чем начнем, реализации функций (тулов) можно найти в предыдущей статье или на гитхабе. Там же есть DTO и Ktor клиенты для Гигачата. Продублирую здесь:
Клиент и DTO для Гигачат
import com.fasterxml.jackson.annotation.JsonProperty import java.util.* object GigaResponse { data class Token( @JsonProperty("access_token") val accessToken: String, @JsonProperty("expires_at") val expiresAt: Date ) sealed interface Chat { data class Ok(val choices: List<Choice>, val created: Long, val model: String) : Chat data class Error(val status: Int, val message: String) : Chat } data class Choice( val message: Message, val index: Int, @JsonProperty("finish_reason") val finishReason: String ) data class Message( val content: String, val role: GigaMessageRole, @JsonProperty("function_call") val functionCall: FunctionCall? = null, @JsonProperty("functions_state_id") val functionsStateId: String? ) data class FunctionCall( val name: String, val arguments: Map<String, Any> ) } object GigaRequest { data class Chat( val model: String = "GigaChat-Max", val messages: List<Message>, @JsonProperty("function_call") val functionCall: String = "auto", val functions: List<Function>? = null, ) data class Message( val role: GigaMessageRole, val content: String, // Could be String or FunctionCall object @JsonProperty("functions_state_id") val functionsStateId: String? = null ) data class Function( val name: String, val description: String, val parameters: Parameters ) data class Parameters( val type: String, val properties: Map<String, Property> ) data class Property( val type: String, val description: String? = null ) } @Suppress("EnumEntryName") enum class GigaMessageRole { system, user, assistant, function } const val MAX_TOKENS = 8192 enum class GigaModel(val alias: String, val maxTokens: Int) { Lite("GigaChat-2", MAX_TOKENS), Pro("GigaChat-Pro", MAX_TOKENS), Max("GigaChat-Max", MAX_TOKENS), } fun String.toSystemPromptMessage() = GigaRequest.Message( role = GigaMessageRole.system, content = this )
import io.ktor.client.* import io.ktor.client.call.* import io.ktor.client.engine.cio.* import io.ktor.client.request.* import io.ktor.client.request.forms.* import io.ktor.http.* object GigaAuth { suspend fun requestToken(apiKey: String): String { val client = HttpClient(CIO) { gigaDefaults() } val response = client.submitForm( url = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth", formParameters = Parameters.build { append("scope", "GIGACHAT_API_PERS") } ) { header("Content-Type", "application/x-www-form-urlencoded") header("Authorization", "Basic $apiKey") }.body<GigaResponse.Token>() client.close() return response.accessToken } }
import io.ktor.client.* import io.ktor.client.call.* import io.ktor.client.engine.cio.* import io.ktor.client.plugins.auth.* import io.ktor.client.plugins.auth.providers.* import io.ktor.client.plugins.logging.LogLevel import io.ktor.client.plugins.logging.Logging import io.ktor.client.request.* import io.ktor.http.* class GigaChatAPI(private val auth: GigaAuth) { private val client = HttpClient(CIO) { var token = "" // get form env, or cache, or db val gigaKey = System.getenv("GIGA_KEY") gigaDefaults() install(Auth) { bearer { loadTokens { BearerTokens(token, "") } refreshTokens { token = auth.requestToken(gigaKey) BearerTokens(token, "") } } } install(Logging) { val envLevel = LogLevel.INFO level = envLevel sanitizeHeader { it.equals(HttpHeaders.Authorization, true) } } } suspend fun message(body: GigaRequest.Chat): GigaResponse.Chat { val response = client.post("https://gigachat.devices.sberbank.ru/api/v1/chat/completions") { setBody(body) } return when { response.status.isSuccess() -> response.body<GigaResponse.Chat.Ok>() else -> response.body<GigaResponse.Chat.Error>() } } fun clear() = client.close() }
import com.fasterxml.jackson.databind.DeserializationFeature import io.ktor.client.* import io.ktor.client.engine.cio.* import io.ktor.client.plugins.* import io.ktor.client.plugins.contentnegotiation.* import io.ktor.client.request.* import io.ktor.http.* import io.ktor.serialization.jackson.* import java.security.cert.X509Certificate import java.util.* import javax.net.ssl.X509TrustManager fun HttpClientConfig<CIOEngineConfig>.gigaDefaults() { this.defaultRequest { header(HttpHeaders.ContentType, "application/json") header(HttpHeaders.Accept, "application/json") header("RqUID", UUID.randomUUID().toString()) } install(HttpTimeout) { requestTimeoutMillis = 40000 } install(ContentNegotiation) { jackson { this.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) } } engine { https { trustManager = object : X509TrustManager { override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {} override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {} override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf() } } } }
Описание функций (тулов)
@Target(AnnotationTarget.PROPERTY) @Retention(AnnotationRetention.RUNTIME) annotation class InputParamDescription(val value: String) interface ToolSetup<Input> { val name: String val description: String operator fun invoke(input: Input): String } class BadInputException(msg: String) : Exception(msg)
Пример реализации тула:
object ToolRunBashCommand : ToolSetup<ToolRunBashCommand.Input> { override val name = "RunBashCommand" override val description = "Executes a bash command and returns its output" override fun invoke(input: Input): String { val process = ProcessBuilder("bash", "-c", input.command) .redirectErrorStream(true) .start() val output = process.inputStream.bufferedReader().use(BufferedReader::readText) val exitCode = process.waitFor() if (exitCode != 0) throw RuntimeException("Command failed with exit code $exitCode") return output.trim() } data class Input( @InputParamDescription("The bash command to run, e.g., 'ls', 'echo Hello', './gradlew tasks'") val command: String ) }
Маппинг на модели гигачата:
import com.dumch.tool.InputParamDescription import com.dumch.tool.ToolSetup import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import kotlin.reflect.KCallable import kotlin.reflect.full.declaredMembers import kotlin.reflect.full.findAnnotation interface GigaToolSetup { val fn: GigaRequest.Function operator fun invoke(functionCall: GigaResponse.FunctionCall): GigaRequest.Message } val gigaJsonMapper = jacksonObjectMapper() inline fun <reified Input> ToolSetup<Input>.toGiga(): GigaToolSetup { val toolSetup = this return object : GigaToolSetup { override val fn: GigaRequest.Function = GigaRequest.Function( name = toolSetup.name, description = toolSetup.description, parameters = GigaRequest.Parameters( "object", properties = HashMap<String, GigaRequest.Property>().apply { val clazz = Input::class for (kProperty: KCallable<*> in clazz.declaredMembers) { val annotation = kProperty.findAnnotation<InputParamDescription>() ?: continue val description = annotation.value val type = kProperty.returnType.toString().substringAfterLast(".").lowercase() val gigaProperty = GigaRequest.Property(type, description) put(kProperty.name, gigaProperty) } } ) ) override fun invoke( functionCall: GigaResponse.FunctionCall, ): GigaRequest.Message { return try { val input: Input = gigaJsonMapper.convertValue(functionCall.arguments, Input::class.java) val toolResult = toolSetup.invoke(input) val gigaResult = gigaJsonMapper.writeValueAsString( mapOf("result" to toolResult) ) GigaRequest.Message( role = GigaMessageRole.function, content = gigaResult, ) } catch (e: Exception) { e.toGigaToolMessage() } } } } fun Exception.toGigaToolMessage(): GigaRequest.Message { return GigaRequest.Message( role = GigaMessageRole.function, content = """{"result": "${message ?: toString()}"}""", ) }
Приступаем к агенту.
Node с дженериками
interface Node<IN, OUT> { val name: String suspend fun execute(ctx: AgentContext<IN>, runtime: GraphRuntime): AgentContext<OUT> } /** * Create new [Node] implementation based on [op] */ fun <IN, OUT> Node( name: String, op: suspend (AgentContext<IN>) -> AgentContext<OUT>, ): Node<IN, OUT> = object : Node<IN, OUT> { override val name: String = "Node $name; ${Integer.toHexString(hashCode())}" override suspend fun execute(ctx: AgentContext<IN>, runtime: GraphRuntime) = op(ctx) }
Реализация AgentContext
data class AgentContext<I>( val input: I, val settings: AgentSettings, val history: List<GigaRequest.Message>, val tools: List<GigaRequest.Function>, val systemPrompt: String, ) { inline fun <reified O> map( settings: AgentSettings = this.settings, history: List<GigaRequest.Message> = this.history, activeTools: List<GigaRequest.Function> = this.tools, systemPrompt: String = this.systemPrompt, transform: (I) -> O = { it as O }, ): AgentContext<O> = AgentContext(input = transform(input), settings, history, activeTools, systemPrompt) } data class AgentSettings( val model: String, val temperature: Float, val tools: Map<String, GigaToolSetup> )
Реализация Graph с Builder
class Graph<IN, OUT> internal constructor( label: String, private val enter: Node<IN, *>, private val exit: Node<OUT, OUT>, private val retryPolicy: RetryPolicy, private val definition: GraphDefinition, ) : Node<IN, OUT> { private val runner = GraphRunner() override val name: String = "$label::graph" @Suppress("UNCHECKED_CAST") override suspend fun execute(ctx: AgentContext<IN>, runtime: GraphRuntime): AgentContext<OUT> { val result = runner.run( start = enter as Node<Any?, Any?>, seed = ctx as AgentContext<Any?>, runtime = runtime, definition = definition, // ребра передадим в Runner stopPredicate = { node, _ -> node === exit } ) return result as AgentContext<OUT> } suspend fun start( seed: AgentContext<IN>, maxSteps: Int = 1000, onStep: ((step: StepInfo, node: Node<Any?, Any?>, ctx: AgentContext<Any?>) -> Unit)? = null, ): AgentContext<OUT> { val runtime = GraphRuntime( retryPolicy = retryPolicy, maxSteps = maxSteps, onStep = onStep, ) return execute(seed, runtime) } } class GraphBuilder<IN, OUT> internal constructor( private val graphName: String, private val retryPolicy: RetryPolicy, ) { val nodeInput: Node<IN, IN> = Node("$graphName::enter") { it } val nodeFinish: Node<OUT, OUT> = Node("$graphName::exit") { it } private val transitions: MutableMap<Node<*, *>, MutableList<Transition<*>>> = mutableMapOf() fun <IN, OUT, OUT2> Node<IN, OUT>.edgeTo(target: Node<OUT, OUT2>): Node<OUT, OUT2> { registerTransition(this, Transition.Static(target)) return target } fun <IN, OUT> Node<IN, OUT>.edgeTo(router: suspend (AgentContext<OUT>) -> Node<OUT, *>): Unit { registerTransition(this, Transition.Dynamic(router)) } private fun <OUT> registerTransition(from: Node<*, OUT>, transition: Transition<OUT>) { val bucket = transitions.getOrPut(from) { mutableListOf() } bucket += transition } internal fun build(): Graph<IN, OUT> = Graph( graphName, nodeInput, nodeFinish, retryPolicy, GraphDefinition(transitions.mapValues { it.value.toList() }), ) } // Вынесем еще и абстракцию для хранения ребер internal class GraphDefinition( private val transitions: Map<Node<*, *>, List<Transition<*>>>, ) { @Suppress("UNCHECKED_CAST") suspend fun nextNodes(node: Node<Any?, Any?>, ctx: AgentContext<Any?>): List<Node<Any?, *>> { val registered = transitions[node] as? List<Transition<Any?>> ?: emptyList() if (registered.isEmpty()) return emptyList() val next = ArrayList<Node<Any?, *>>(registered.size) for (transition in registered) { when (transition) { is Transition.Static -> next.addOrWarn(transition.target as Node<Any?, *>) is Transition.Dynamic -> next.addOrWarn(transition.router(ctx) as Node<Any?, *>) } } return next } private fun MutableCollection<Node<Any?, *>>.addOrWarn(node: Node<Any?, *>) { if (contains(node)) { add(node) } } } internal sealed interface Transition<OUT> { class Static<OUT>(val target: Node<OUT, *>) : Transition<OUT> class Dynamic<OUT>(val router: suspend (AgentContext<OUT>) -> Node<OUT, *>) : Transition<OUT> } // Ниже всего лишь бойлерплейт для делегатов (by). fun <I, O> buildGraph( name: String = "Graph", retryPolicy: RetryPolicy = RetryPolicy(), configure: GraphBuilder<I, O>.() -> Unit ): Graph<I, O> { val builder = GraphBuilder<I, O>(name, retryPolicy) builder.configure() return builder.build() } fun <I, O> graph( name: String? = null, retryPolicy: RetryPolicy = RetryPolicy(), configure: GraphBuilder<I, O>.() -> Unit ): ReadOnlyProperty<Any?, Graph<I, O>> = GraphDelegate(name, retryPolicy, configure) private class GraphDelegate<I, O>( private val nameHint: String?, private val retryPolicy: RetryPolicy, private val configure: GraphBuilder<I, O>.() -> Unit, ) : ReadOnlyProperty<Any?, Graph<I, O>> { private var cached: Graph<I, O>? = null override fun getValue(thisRef: Any?, property: KProperty<*>): Graph<I, O> { return cached ?: build(property.name).also { cached = it } } private fun build(propertyName: String): Graph<I, O> { val name = nameHint ?: propertyName val builder = GraphBuilder<I, O>(name, retryPolicy) builder.configure() return builder.build() } }
Реализация GraphRunner и GraphRuntime
internal class GraphRunner { suspend fun run( start: Node<Any?, Any?>, seed: AgentContext<Any?>, runtime: GraphRuntime, definition: GraphDefinition, stopPredicate: ((Node<Any?, Any?>, AgentContext<Any?>) -> Boolean)? = null, ): AgentContext<Any?> { val queue = ArrayDeque<Frame>().apply { add(Frame(start, seed, 0)) } val leaves = mutableListOf<AgentContext<*>>() var lastCtx: AgentContext<Any?> = seed try { while (queue.isNotEmpty() && currentCoroutineContext().isActive) { if (runtime.counter.get() >= runtime.maxSteps) { error("Graph maxSteps (${runtime.maxSteps}) reached — potential loop") } val frame = queue.removeFirst() val outCtx = executeWithRetry(frame.node, frame.ctx, runtime) val stepInfo = StepInfo(currentGraphIndex = frame.depth, index = runtime.counter.get()) runtime.onStep?.invoke(stepInfo, frame.node, outCtx) lastCtx = outCtx if (stopPredicate?.invoke(frame.node, outCtx) == true) return outCtx val nextNodes = definition.nextNodes(frame.node, outCtx) if (nextNodes.isEmpty()) { leaves += outCtx } else { for (child in nextNodes) { @Suppress("UNCHECKED_CAST") queue.add(Frame(child as Node<Any?, Any?>, outCtx, frame.depth + 1)) } } runtime.counter.incrementAndGet() } } catch (cancel: CancellationException) { throw GraphCancellation(lastCtx, cancel) } @Suppress("UNCHECKED_CAST") return leaves.lastOrNull() as? AgentContext<Any?> ?: lastCtx } private suspend fun executeWithRetry( node: Node<Any?, Any?>, inCtx: AgentContext<Any?>, runtime: GraphRuntime, ): AgentContext<Any?> { val policy = runtime.retryPolicy var attempt = 0 var lastError: Throwable? = null while (attempt < policy.maxAttempts) { attempt++ try { return node.execute(inCtx, runtime) } catch (t: Throwable) { if (t is CancellationException) throw t lastError = t val shouldRetry = policy.shouldRetry(t, inCtx, node, attempt) val attemptsLeft = policy.maxAttempts - attempt if (!shouldRetry || attemptsLeft <= 0) break } } throw lastError ?: IllegalStateException("Unknown failure in node ${node.name}") } private data class Frame( val node: Node<Any?, Any?>, val ctx: AgentContext<Any?>, val depth: Int, ) } class GraphCancellation( val lastContext: AgentContext<*>, cause: CancellationException? = null ) : CancellationException(cause?.message) { init { initCause(cause) } } data class StepInfo( /** * Sequential index of the executed node within the run (starting from 0). */ val index: Int, /** * Sequential index of the executed node within the current graph run (starting from 0). */ val currentGraphIndex: Int, ) class GraphRuntime private constructor( val retryPolicy: RetryPolicy, val maxSteps: Int, val onStep: ((step: StepInfo, node: Node<Any?, Any?>, ctx: AgentContext<Any?>) -> Unit)? = null, val counter: AtomicInteger ) { constructor( retryPolicy: RetryPolicy, maxSteps: Int, onStep: ((step: StepInfo, node: Node<Any?, Any?>, ctx: AgentContext<Any?>) -> Unit)? = null, ): this(retryPolicy, maxSteps, onStep, counter = AtomicInteger()) } data class RetryPolicy( val maxAttempts: Int = 2, val shouldRetry: suspend ( error: Throwable, ctx: AgentContext<*>, node: Node<*, *>?, attempt: Int ) -> Boolean = { _, _, _, _ -> true } )
Дефолтные реализации Node
object NodesCommon { val stringToReq: Node<String, GigaRequest.Chat> = Node("String->Request") { ctx -> val usrMsg = GigaRequest.Message(GigaMessageRole.user, ctx.input) val history = ArrayList(ctx.history).apply { if (isEmpty()) add(ctx.systemPrompt.toSystemPromptMessage()) add(usrMsg) } ctx.map(history = history) { ctx.toGigaRequest(history) } } val respToString: Node<GigaResponse.Chat, String> = Node("Response->String") { ctx -> when (val input = ctx.input) { is GigaResponse.Chat.Error -> ctx.map { input.message } is GigaResponse.Chat.Ok -> ctx.map { input.choices.last().message.content } } } val toolUse: Node<GigaResponse.Chat, GigaRequest.Chat> = Node("toolUse") { ctx -> val fnCallMessages = fnCallMessages(ctx) val history = ArrayList(ctx.history).apply { addAll(fnCallMessages) } ctx.map(history = history) { ctx.toGigaRequest(history) } } private suspend fun fnCallMessages(ctx: AgentContext<GigaResponse.Chat>): List<GigaRequest.Message> { val fnCallMessages = (ctx.input as GigaResponse.Chat.Ok).choices.mapNotNull { choice -> val msg = choice.message if (msg.functionCall != null && msg.functionsStateId != null) { executeTool(ctx.settings, msg.functionCall) } else null } return fnCallMessages } private suspend fun executeTool( settings: AgentSettings, functionCall: GigaResponse.FunctionCall, ): GigaRequest.Message { val tools = settings.tools val fn: GigaToolSetup = tools[functionCall.name] ?: return GigaRequest.Message( GigaMessageRole.function, """{"result":"no such function ${functionCall.name}"}""" ) return fn.invoke(functionCall) } } fun <T> AgentContext<T>.toGigaRequest(history: List<GigaRequest.Message>): GigaRequest.Chat { val ctx = this return GigaRequest.Chat( model = ctx.settings.model, messages = history, functions = ctx.tools, ) }
class NodesLLM(llmApi: GigaChatAPI) { val chat: Node<GigaRequest.Chat, GigaResponse.Chat> = Node("llmCall") { ctx -> val response = withContext(Dispatchers.IO) { llmApi.message(ctx.input) } val history = ArrayList(ctx.history).apply { if (response is GigaResponse.Chat.Ok) { addAll(response.choices.mapNotNull { it.toMessage() }) } } ctx.map(history = history) { response } } /** * Restores the last message, and a system prompt. Other messages are transformed into TLDR */ val summarize: Node<GigaResponse.Chat, GigaResponse.Chat> = Node("llmSummarize") { ctx -> val conversation = ArrayList(ctx.history) val summaryResponse: GigaResponse.Chat = withContext(Dispatchers.IO) { conversation.add(GigaRequest.Message( role = GigaMessageRole.user, content = "Резюмируй разговор", )) val request = ctx.toGigaRequest(conversation) .copy(functions = emptyList()) llmApi.message(request) } val msg: GigaRequest.Message = when (summaryResponse) { is GigaResponse.Chat.Error -> { throw IOException(summaryResponse.message) } is GigaResponse.Chat.Ok -> summaryResponse.choices.mapNotNull { it.toMessage() }.last() } val newHistory = listOf(ctx.systemPrompt.toSystemPromptMessage(), ctx.history.last(), msg) ctx.map(history = newHistory) { summaryResponse } } private fun GigaResponse.Choice.toMessage(): GigaRequest.Message? { val msg = this.message val content: String = when { msg.content.isNotBlank() -> msg.content msg.functionCall != null -> gigaJsonMapper.writeValueAsString( mapOf("name" to msg.functionCall.name, "arguments" to msg.functionCall.arguments) ) else -> return null } return GigaRequest.Message( role = msg.role, content = content, functionsStateId = msg.functionsStateId ) } }
И, наконец, агент с вызовом тулов, суммаризацией истории и хранением контекста (истории) будет выглядеть так:
import com.dumch.agent.engine.* import com.dumch.agent.node.NodesCommon import com.dumch.agent.node.NodesLLM import com.dumch.giga.* import kotlinx.coroutines.Deferred import kotlinx.coroutines.async import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import java.util.concurrent.atomic.AtomicReference import kotlin.coroutines.cancellation.CancellationException import kotlin.math.ceil class GraphBasedAgent( private val model: String, private val llmApi: GigaChatAPI, private val tools: Map<String, GigaToolSetup> = GigaAgent.tools ) { private val nodesLLM = NodesLLM(llmApi) // Make sure summarization only happens after all tool requests from LLM are answered private val nodeSummarize: Node<GigaResponse.Chat, String> by graph(name = "Go to user") { nodeInput.edgeTo { ctx -> if (ctx.historyIsTooBig()) nodesLLM.summarize else NodesCommon.respToString } nodesLLM.summarize.edgeTo(NodesCommon.respToString) NodesCommon.respToString.edgeTo(nodeFinish) } private val settings = AgentSettings( model = model, temperature = 0.7f, tools = tools ) private val allFunctions: List<GigaRequest.Function> = settings.tools.values.map { it.fn } private val initialCtx = AgentContext( input = "", settings = settings, history = emptyList(), tools = allFunctions, systemPrompt = SYSTEM_PROMPT ) private val _ctx: MutableStateFlow<AgentContext<String>> = MutableStateFlow(initialCtx) val currentContext: StateFlow<AgentContext<String>> = _ctx private val runningJob = AtomicReference<Deferred<*>>() fun cancelActiveJob() { runningJob.get()?.cancel(CancellationException("Cleared by force")) } /** Execute one job at a time */ suspend fun execute(input: String): String { cancelActiveJob() val ctx = currentContext.value.copy(input = input) val result: Deferred<AgentContext<String>> = coroutineScope { async { buildGraph().start(ctx) { _, _, _ -> } } } runningJob.set(result) val newContext = result.await() _ctx.emit(newContext) return newContext.input } private fun buildGraph(): Graph<String, String> = buildGraph(name = "Agent") { nodeInput.edgeTo(NodesCommon.stringToReq) NodesCommon.stringToReq.edgeTo(nodesLLM.chat) nodesLLM.chat.edgeTo { ctx -> when (val output = ctx.input) { is GigaResponse.Chat.Error -> nodeSummarize is GigaResponse.Chat.Ok -> if (isToolUse(output)) NodesCommon.toolUse else nodeSummarize } } NodesCommon.toolUse.edgeTo(nodesLLM.chat) nodeSummarize.edgeTo(nodeFinish) } private fun isToolUse(input: GigaResponse.Chat.Ok): Boolean = input.choices.any { it.message.functionCall != null } private fun AgentContext<GigaResponse.Chat>.historyIsTooBig( threshold: Double = HISTORY_SUMMARIZE_THRESHOLD, ): Boolean { val model = GigaModel.entries.firstOrNull { it.alias == settings.model } val contextWindow = model?.maxTokens ?: MAX_TOKENS val estimatedTokens = systemPrompt.estimateTokenCount() + history.sumOf { it.content.estimateTokenCount() } return estimatedTokens >= contextWindow * threshold } private fun String.estimateTokenCount(): Int = ceil(length / APPROX_CHARS_PER_TOKEN).toInt() } private const val HISTORY_SUMMARIZE_THRESHOLD = 0.8 private const val APPROX_CHARS_PER_TOKEN = 4.0 private val SYSTEM_PROMPT = """ Ты программист-помощник, 10 лет пишешь код на Kotlin, Android и Backend. Стараешься писать простой и поддерживаемый код. """.trimIndent()
Использование:
private const val AGENT_ALIAS = "🪐" suspend fun main() { val agent = GraphBasedAgent( model = GigaModel.Max.alias, llmApi = GigaChatAPI(GigaAuth), ) userInputFlow().collect { text -> val result = agent.execute(text) println(AGENT_ALIAS + result) } } private fun userInputFlow(): Flow = flow { println("Type `exit` to quit") while (true) { print("> ") val input = readlnOrNull() ?: break if (input.lowercase() == "exit") break emit(input) println("\n") } }
Изменения по сравнению с версией агента из предыдущей статьи — то есть реализация на основе графа и базовые Nodes — собраны в PR.
Если нужен openai вместо gigacode, можно взять openai-kotlin. В предыдущей статье писал, как легко адаптировать anthropic через их sdk. Библиотеки «композируются», в отличие от фреймворков.
Frameworks do not compose. — Tomas Petricek, article.
Добавление RAG
Обычно под RAG имеется в виду следующий алгоритм:
Запрос к API, чтобы перевести текст в вектор (например, запрос пользователя) .
Поиск по векторной базе данных похожих текстов (например, по документам компании).
Прикрепление похожих текстов к промпту.
На хабре есть статья с формальными определениями и примерами.
RAG можно найти в документации Koog в подкатегории Advanced Usage. И я уверен, что с Koog задача действительно потребует advanced-усилий, ведь не понятно, что они используют под капотом, есть ли там кеши, retry, какая база данных будет использоваться, можно ли не тащить ненужные зависимости в проект, будут ли они добавлять промпты, чтобы захачить свои проблемы. На всё есть ответы, но с этим надо разбираться (об этих проблемах с примерами и ссылками писал в этой же статье выше).
Давайте реализуем RAG в рамках имеющегося решения. Абстракции, относящиеся к агенту, трогать не будем — всё решим на уровне Node.
Нам понадобится ручка с модельками для перевода текстов в вектора. Реализуем на доступном всем Гигачат.
Ручка и модельки
object GigaResponse { // ... предыдущий код data class Embeddings( val data: List<Embedding>, val model: String, @JsonProperty("object") val objectType: String, ) data class Embedding( val embedding: List<Double>, val index: Int, @JsonProperty("object") val objectType: String? = null, ) } object GigaRequest { // ... предыдущий код data class Embeddings( val model: String = "Embeddings", val input: List<String>, ) } class GigaChatAPI(private val auth: GigaAuth) { // ... предыдущий код suspend fun embeddings(body: GigaRequest.Embeddings): GigaResponse.Embeddings { val response = client.post("https://gigachat.devices.sberbank.ru/api/v1/embeddings") { setBody(body) } return when { response.status.isSuccess() -> response.body<GigaResponse.Embeddings>() response.status == HttpStatusCode.Unauthorized || response.status == HttpStatusCode.Forbidden -> TODO("Auth exception") else -> TODO("unexpected error") } } }
Добавляем векторную базу и наивную реализацию:
implementation("org.apache.lucene:lucene-core:9.9.2")
Можно было бы решить и плагином для SQL, но для целей статьи так быстрее:
Обертка над векторной базой
object VectorDB { private const val INDEX_PATH = "build/rag_index" init { val isInitialized = File(INDEX_PATH).exists() // naive way to check initialization if (!isInitialized) { val dir = FSDirectory.open(Paths.get(INDEX_PATH)) IndexWriter(dir, IndexWriterConfig()).use { } } } fun insert(data: List<String>, embeddings: List<List<Double>>) { val dir = FSDirectory.open(Paths.get(INDEX_PATH)) IndexWriter(dir, IndexWriterConfig()).use { writer -> data.indices.forEach { idx -> val doc = Document() doc.add(StoredField("text", data[idx])) doc.add(KnnFloatVectorField("embedding", toFloatArray(embeddings[idx]))) writer.addDocument(doc) } } } fun getAllTexts(): List<String> { val dir = FSDirectory.open(Paths.get(INDEX_PATH)) DirectoryReader.open(dir).use { reader -> val list = mutableListOf<String>() for (i in 0 until reader.maxDoc()) { val doc = reader.document(i) doc.get("text")?.let { list.add(it) } } return list } } fun searchSimilar(embedding: List<Double>, limit: Int = 5): List<String> { val dir = FSDirectory.open(Paths.get(INDEX_PATH)) DirectoryReader.open(dir).use { reader -> val searcher = IndexSearcher(reader) val query = KnnFloatVectorQuery("embedding", toFloatArray(embedding), limit) val topDocs = searcher.search(query, limit) val texts = mutableListOf<String>() topDocs.scoreDocs.forEach { sd -> searcher.doc(sd.doc).get("text")?.let { texts.add(it) } } return texts } } fun clearAllData() { val dir = FSDirectory.open(Paths.get(INDEX_PATH)) IndexWriter(dir, IndexWriterConfig()).use { writer -> writer.deleteAll() } } private fun toFloatArray(list: List<Double>): FloatArray { val size = min(list.size, MAX_DIM) val arr = FloatArray(size) for (i in 0 until size) { arr[i] = list[i].toFloat() } return arr } private const val MAX_DIM = 1024 }
Имеющихся реализаций хватит, чтобы пощупать RAG руками:
suspend fun main() { val vectorDb = VectorDB // Настройка базы vectorDb.clearAllData() // осторожнее с последующими запусками, тут — чистка val api = GigaChatAPI(GigaAuth) val knownFacts = listOf( "RAG is an AI technique that combines a search engine with a large language model (LLM) — Google AI overview", "Perhaps the biggest and the most obvious problem with frameworks is that they cannot be composed. — Tomas Petricek", "Use frameworks only for applications with a short development lifespan, and avoid frameworks for systems you intend to keep for multiple years. — Mathias Verraes", "Inversion of control is a common feature of frameworks, but it's something that comes at a price. " + "It tends to be hard to understand and leads to problems when you are trying to debug. " + "So on the whole I prefer to avoid it unless I need it. — Martin Fowler" ) val factsEmbeddings = api.embeddings(GigaRequest.Embeddings(input = knownFacts)) vectorDb.insert(knownFacts, factsEmbeddings.data.map { it.embedding }) // Использование базы для поиска схожих строк val input = "Фреймворк — хорошо или плохо? Есть ли причины не использовать фреймворки?" val embedding = api.embeddings(GigaRequest.Embeddings(input = listOf(input))) val result = vectorDb.searchSimilar(embedding.data.first().embedding, limit = 3) // Ожидаю, что напечатаются 3 цитаты про фреймворки. println(result.joinToString(prefix = "Found:\n", separator = "\n")) }
Теперь в классе, где мы описываем граф, можно добавить Node:
class GraphBasedAgent(...) { private val nodeAppendAdditionalData: Node = Node("appendActualInformation") { ctx -> val additionalMessage = appendActualInformation(ctx.input) if (additionalMessage == null) { ctx } else { val history = ArrayList(ctx.history).apply { add(additionalMessage) } ctx.map(history = history) } } private fun buildGraph(): Graph = buildGraph(name = "Agent") { nodeInput.edgeTo(nodeAppendAdditionalData) nodeAppendAdditionalData.edgeTo(NodesCommon.stringToReq) NodesCommon.stringToReq.edgeTo(nodesLLM.chat) ... } private suspend fun appendActualInformation(userText: String): GigaRequest.Message? { if (userText.isBlank()) return null val embedding = llmApi.embeddings(GigaRequest.Embeddings(input = listOf(userText))) val result = VectorDB.searchSimilar(embedding.data.first().embedding, limit = 2) .joinToString( prefix = "Найденные в локальном хранилище данные:\n", separator = "\n", ) return GigaRequest.Message( role = GigaMessageRole.user, content = result, ) } }
Вот и весь RAG. Для удобства вынес в PR на гитхабе. В production-ready приложении добавятся обработка ошибок, ретраи, и, может быть, слой с репозиторием.
Когда использовать фреймворк, а не самописное решение?
Когда иначе невозможно. Примеры: разработка мобильных приложений.
Когда фреймворк становится де-факто стандартом. Пример: Spring для бэкенда на Java.
Когда приложение короткоживущее. Пример: MVP, POC, ~разовый аутсорс.
Use a framework for applications with a short expected development lifespan. — Mathias Verraes, X.
В заключение
Весь код, необходимый для того, чтобы переписать агента с циклов (первая статья) на графы — в PR на гитхабе. Примерно такой же код я использую в двух других проектах. Отличия — в деталях, которые опустил в статье для упрощения материала. Читателю не составит труда адаптировать код или даже написать свою реализацию.
Надеюсь, кому-то статья окажется полезной. Обратная связь и критика приветствуются.
