Как стать автором
Обновить

Использование Kotlin и WebFlux для выполнения задач ML в Apache Spark на GPU

Уровень сложностиСредний
Время на прочтение32 мин
Количество просмотров3.5K

Это третья статья по теме реализации масштабируемой системы для выполнения задач распределенного машинного обучения на GPU с использованием Java, Kotlin, Spring и Spark. Список всех статей:

  1. Варианты использования Java ML библиотек совместно со Spring, Docker, Spark, Rapids, CUDA

  2. Масштабируемая Big Data система в Kubernetes с использованием Spark и Cassandra

  3. Использование Kotlin и WebFlux для выполнения задач ML в Apache Spark на GPU

О чем данная статья

В предыдущей статье для создания Spark Driver приложения использовался сервлетный стек Spring (Boot 2.7.11) и JDK 8.

На дворе вторая половина 2023 года, у многих в проде уже используется Boot 3+ (а то и 3.1+), совсем скоро должна выйти новая LTS версия Java, и, мягко говоря, Boot 2+ и JDK8 устарели. Использовались они намеренно, так как для задач тренировки моделей машинного обучения на GPU в среде Spark частью системы является ускоритель вычислений на GPU NVidia Rapids. Поддержка JDK 17 появилась только в релизе v23.06.0 от 27.06.23, с ее выходом появилась возможность перейти на актуальную LTS версию Java, а с ней - на Spring Boot 3+.

В данной статье описывается миграция с Boot 2 и JDK 8 До Boot 3 и JDK 17, со Spring Web на Spring WebFlux, в конце сравниваются Web и WebFlux версии по потреблению аппаратных ресурсов и скорости выполнения.

JDK8, Spring boot 2.7.11 → JDK17, Spring Boot 3.1.1

Для миграции достаточно поднять версии Rapids до 23.06.0, JDK до 17, Spring Boot до 3.1.1. Нюансов не так уж и много:

  1. Конфликт логеров Slf4j и Log4j при использовании Spark: из зависимости spring boot starter web исключаем spring boot starter logging:

pom.xml
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
    <version>${spring.boot.version}</version>
    <exclusions>
        <exclusion>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-logging</artifactId>
        </exclusion>
        <exclusion>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-tomcat</artifactId>
        </exclusion>
    </exclusions>
</dependency>

  1. Запускать Spark Driver на JDK 17 необходимо со следующими параметрами (приведено для Dockerfile):

Application Dockerfile
ENV JAVA_OPTS='--add-opens=java.base/java.lang=ALL-UNNAMED \
               --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \
               --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \
               --add-opens=java.base/java.io=ALL-UNNAMED \
               --add-opens=java.base/java.net=ALL-UNNAMED \
               --add-opens=java.base/java.nio=ALL-UNNAMED \
               --add-opens=java.base/java.util=ALL-UNNAMED \
               --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \
               --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED \
               --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \
               --add-opens=java.base/sun.nio.cs=ALL-UNNAMED \
               --add-opens=java.base/sun.security.action=ALL-UNNAMED \
               --add-opens=java.base/sun.util.calendar=ALL-UNNAMED \
               --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED'

  1. В связи с переходом на Hibernate 6, при использовании JSOB и BYTEA полей в сущностях Postgres придется немного отрефакторить Entity:

При этом, использовавшийся ранее CustomPostgresDialect оказывается не нужным и его можно удалить, заменив на org.hibernate.dialect.PostgreSQLDialect:

application.yml
spring:
  ...
  jpa:
    database-platform: com.mlwebservice.config.CustomPostgresDialect  # <== delete
    database-platform: org.hibernate.dialect.PostgreSQLDialect        # <== add

Ранее использовавшийся CustomPostgresDialect
package com.mlwebservice.config

import com.vladmihalcea.hibernate.type.array.IntArrayType
import com.vladmihalcea.hibernate.type.array.StringArrayType
import com.vladmihalcea.hibernate.type.json.JsonBinaryType
import com.vladmihalcea.hibernate.type.json.JsonNodeBinaryType
import com.vladmihalcea.hibernate.type.json.JsonNodeStringType
import com.vladmihalcea.hibernate.type.json.JsonStringType
import org.hibernate.dialect.PostgreSQL10Dialect
import java.sql.Types

class CustomPostgresDialect : PostgreSQL10Dialect() {
    init {
        registerHibernateType(Types.OTHER, StringArrayType::class.qualifiedName)
        registerHibernateType(Types.OTHER, IntArrayType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonStringType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonBinaryType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonNodeBinaryType::class.qualifiedName)
        registerHibernateType(Types.OTHER, JsonNodeStringType::class.qualifiedName)
    }
}

Не считая докерфайлов и действий по добавлению новой версии jar'ника Rapids в директорию с jar-файлами для отправки в Spark executors и в образ executor’а, это все, что необходимо выполнить. Актуальную версию можно взять в соответствующей ветке репозитория.

На этом можно было бы и закончить, но любопытство ведь берет свое, и появился вопрос - а заработает ли на реактивном стеке, и будет ли эффект?

Сделаем ML реактивным: Spring Web → Spring WebFlux

Зависимости

Изменений при таком переходе изначально должно быть больше, но так же есть нюансы в виде управления зависимостями. Так, Netty, необходимый для Project Reactor (WebFlux) используется самим Spark и драйвером Cassandra, поэтому изначально конфликтовали. Решается путем задания трех зависимостей в самом начале списка зависимостей:

pom.xml: Зависимости Netty
<dependencies>
    <!-- Netty -->
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-all</artifactId>
        <version>4.1.74.Final</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-codec-http</artifactId>
        <version>4.1.74.Final</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-resolver-dns</artifactId>
        <version>4.1.74.Final</version>
    </dependency>

    <!-- Spring -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-webflux</artifactId>
        <version>${spring.boot.version}</version>
        <exclusions>
            <exclusion>
                <artifactId>log4j-to-slf4j</artifactId>
                <groupId>org.apache.logging.log4j</groupId>
            </exclusion>
        </exclusions>
    </dependency>
    ...
</dependencies>

Spring Data тоже заменяется на реактивную версию:

pom.xml: R2DBC и Spring Data Cassandra Reactive
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-cassandra-reactive</artifactId>
    <version>${spring.boot.version}</version>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-r2dbc</artifactId>
    <version>${spring.boot.version}</version>
</dependency>
<dependency>
    <groupId>io.r2dbc</groupId>
    <artifactId>r2dbc-postgresql</artifactId>
    <version>0.8.13.RELEASE</version>
</dependency>

И добавляются несколько библиотек для работы Kotlin в среде WebFlux:

pom.xml: Kotlin dependencies
<dependency>
    <groupId>org.jetbrains.kotlin</groupId>
    <artifactId>kotlin-stdlib</artifactId>
    <version>${kotlin.version}</version>
</dependency>
<dependency>
    <groupId>org.jetbrains.kotlin</groupId>
    <artifactId>kotlin-reflect</artifactId>
    <version>${kotlin.version}</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>org.jetbrains.kotlinx</groupId>
    <artifactId>kotlinx-coroutines-reactor</artifactId>
    <version>1.7.2</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>io.projectreactor.kotlin</groupId>
    <artifactId>reactor-kotlin-extensions</artifactId>
    <version>1.2.2</version>
    <scope>runtime</scope>
</dependency>

Кстати, сам Kotlin тоже поднял с версии 1.8.21 до 1.9.0.

Для логирования HTTP запросов-ответов добавляем Zalando Logbook:

pom.xml: Zalando Logbook
<dependency>
    <groupId>org.zalando</groupId>
    <artifactId>logbook-spring-boot-autoconfigure</artifactId>
    <version>3.2.0</version>
</dependency>
<dependency>
    <groupId>org.zalando</groupId>
    <artifactId>logbook-netty</artifactId>
    <version>3.2.0</version>
</dependency>

pom.xml (полная версия для WebFlux)
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.mlwebservice</groupId>
    <artifactId>MLWebService</artifactId>
    <version>1.0.0-SNAPSHOT</version>

    <properties>
        <java.version>17</java.version>
        <spring.boot.version>3.1.1</spring.boot.version>
        <scala.version>2.12</scala.version>
        <spark.version>3.3.2</spark.version>
        <lombok.version>1.18.24</lombok.version>
        <org.mapstruct.version>1.4.2.Final</org.mapstruct.version>
        <kotlin.version>1.9.0</kotlin.version>
        <jackson.version>2.13.5</jackson.version>
    </properties>

    <distributionManagement>
        <repository>
            <id>XGBoost4J Snapshot Repo</id>
            <name>XGBoost4J Snapshot Repo</name>
            <url>https://s3-us-west-2.amazonaws.com/xgboost-maven-repo/snapshot/</url>
        </repository>
    </distributionManagement>

    <dependencies>
        <!-- Netty -->
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.74.Final</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-codec-http</artifactId>
            <version>4.1.74.Final</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-resolver-dns</artifactId>
            <version>4.1.74.Final</version>
        </dependency>

        <!-- Spring -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-webflux</artifactId>
            <version>${spring.boot.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>org.springframework.boot</groupId>
                    <artifactId>spring-boot-starter-logging</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-core</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.module</groupId>
            <artifactId>jackson-module-kotlin</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-annotations</artifactId>
            <version>${jackson.version}</version>
        </dependency>
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>${jackson.version}</version>
        </dependency>

        <!-- Spring Data -->
        <dependency>
            <groupId>org.springframework.data</groupId>
            <artifactId>spring-data-commons</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-cassandra-reactive</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-r2dbc</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-jpa</artifactId>
            <version>${spring.boot.version}</version>
        </dependency>
        <dependency>
            <groupId>org.postgresql</groupId>
            <artifactId>postgresql</artifactId>
            <scope>runtime</scope>
            <version>42.6.0</version>
        </dependency>
        <dependency>
            <groupId>io.r2dbc</groupId>
            <artifactId>r2dbc-postgresql</artifactId>
            <version>0.8.13.RELEASE</version>
        </dependency>
        <dependency>
            <groupId>com.vladmihalcea</groupId>
            <artifactId>hibernate-types-60</artifactId>
            <version>2.21.1</version>
        </dependency>

        <!-- Cassandra -->
        <dependency>
            <groupId>com.datastax.oss</groupId>
            <artifactId>java-driver-core</artifactId>
            <version>4.13.0</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.12.15</version>
        </dependency>
        <dependency>
            <groupId>com.datastax.spark</groupId>
            <artifactId>spark-cassandra-connector_2.12</artifactId>
            <version>3.3.0</version>
        </dependency>

        <dependency>
            <groupId>com.typesafe</groupId>
            <artifactId>config</artifactId>
            <version>1.4.2</version>
        </dependency>

        <!-- Spark -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.antlr</groupId>
            <artifactId>antlr4-runtime</artifactId>
            <version>4.8</version>
            <scope>runtime</scope>
        </dependency>

        <!-- GXBoost -->
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j-spark-gpu_${scala.version}</artifactId>
            <version>1.7.5</version>
        </dependency>
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j-gpu_${scala.version}</artifactId>
            <version>1.7.5</version>
        </dependency>

        <!-- Kubernetes -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-kubernetes_${scala.version}</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.codehaus.janino</groupId>
            <artifactId>commons-compiler</artifactId>
            <version>3.0.16</version>
        </dependency>
        <dependency>
            <groupId>org.codehaus.janino</groupId>
            <artifactId>janino</artifactId>
            <version>3.0.16</version>
        </dependency>

        <!-- Rapids -->
        <dependency>
            <groupId>com.nvidia</groupId>
            <artifactId>rapids-4-spark_${scala.version}</artifactId>
            <version>23.06.0</version>
        </dependency>

        <!-- Lombok -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>${lombok.version}</version>
        </dependency>

        <!-- Logging -->
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-spring-webflux</artifactId>
            <version>3.1.0</version>
        </dependency>
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-spring-boot-autoconfigure</artifactId>
            <version>3.2.0</version>
        </dependency>
        <dependency>
            <groupId>org.zalando</groupId>
            <artifactId>logbook-netty</artifactId>
            <version>3.2.0</version>
        </dependency>

        <!-- Utils -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.12.0</version>
        </dependency>

        <!-- Kotlin -->
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-stdlib</artifactId>
            <version>${kotlin.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-reflect</artifactId>
            <version>${kotlin.version}</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlinx</groupId>
            <artifactId>kotlinx-coroutines-reactor</artifactId>
            <version>1.7.2</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>io.projectreactor.kotlin</groupId>
            <artifactId>reactor-kotlin-extensions</artifactId>
            <version>1.2.2</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlinx.spark</groupId>
            <artifactId>kotlin-spark-api_3.3.1_${scala.version}</artifactId>
            <version>1.2.3</version>
        </dependency>
        <dependency>
            <groupId>org.jetbrains.kotlin</groupId>
            <artifactId>kotlin-test</artifactId>
            <version>${kotlin.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <finalName>service</finalName>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <version>3.0.6</version>
                <configuration>
                    <mainClass>com.mlwebservice.MLWebServiceApplication</mainClass>
                </configuration>
                <executions>
                    <execution>
                        <goals>
                            <goal>repackage</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>

            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.11.0</version>
                <executions>
                    <execution>
                        <id>compile</id>
                        <phase>compile</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                    </execution>
                    <execution>
                        <id>testCompile</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>testCompile</goal>
                        </goals>
                    </execution>
                </executions>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                    <annotationProcessorPaths>
                        <path>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                            <version>${lombok.version}</version>
                        </path>
                    </annotationProcessorPaths>
                </configuration>
            </plugin>

            <plugin>
                <groupId>org.jetbrains.kotlin</groupId>
                <artifactId>kotlin-maven-plugin</artifactId>
                <version>${kotlin.version}</version>
                <executions>
                    <execution>
                        <id>compile</id>
                        <phase>process-sources</phase>
                        <goals>
                            <goal>compile</goal>
                        </goals>
                        <configuration>
                            <jvmTarget>${java.version}</jvmTarget>
                            <sourceDirs>
                                <source>src/main/java</source>
                                <source>src/main/kotlin</source>
                                <source>target/generated-sources/annotations</source>
                            </sourceDirs>
                        </configuration>
                    </execution>
                    <execution>
                        <id>test-compile</id>
                        <phase>test-compile</phase>
                        <goals>
                            <goal>test-compile</goal>
                        </goals>
                        <configuration>
                            <jvmTarget>${java.version}</jvmTarget>
                            <sourceDirs>
                                <source>src/main/java</source>
                                <source>src/main/kotlin</source>
                                <source>target/generated-sources/annotations</source>
                            </sourceDirs>
                        </configuration>
                    </execution>
                </executions>
                <configuration>
                    <jvmTarget>${java.version}</jvmTarget>
                    <sourceDirs>
                        <source>src/main/java</source>
                        <source>src/main/kotlin</source>
                        <source>target/generated-sources/annotations</source>
                    </sourceDirs>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>

Main класс

Модифицируем Main класс приложения, необходимо добавить аннотации @EnableWebFlux и @EnableR2dbcRepositories, указать тип приложения REACTIVE

Main class
package com.mlwebservice;

import org.springframework.boot.WebApplicationType;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
import org.springframework.boot.autoconfigure.gson.GsonAutoConfiguration;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories;
import org.springframework.web.reactive.config.EnableWebFlux;

import java.net.InetAddress;
import java.net.UnknownHostException;

@EnableWebFlux
@EnableR2dbcRepositories
@SpringBootApplication(exclude = {
        GsonAutoConfiguration.class,
        CassandraAutoConfiguration.class
})
public class MLWebServiceApplication {
    public static void main(String[] args) {
        new SpringApplicationBuilder(MLWebServiceApplication.class)
                .web(WebApplicationType.REACTIVE)
                .run(args);
        );
    }
}

Spring Data → R2DBC

Так как в сущности БД используется JSONB поле (с его отображением в приложении в виде JsonNode), необходима конфигурация R2DBC с кастомными конвертерами:

Jsonb converters
package com.mlwebservice.config

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import io.r2dbc.postgresql.codec.Json
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.core.convert.converter.Converter
import org.springframework.data.convert.ReadingConverter
import org.springframework.data.convert.WritingConverter
import org.springframework.data.r2dbc.convert.R2dbcCustomConversions
import org.springframework.data.r2dbc.dialect.PostgresDialect

@Configuration
open class R2dbcConfiguration(private val objectMapper: ObjectMapper) {

    @Bean
    open fun customConversions() : R2dbcCustomConversions {
        val converters = listOf<Converter<*, *>>(
            JsonNodeWritingConverter(objectMapper),
            JsonNodeReadingConverter(objectMapper)
        )
        return R2dbcCustomConversions.of(PostgresDialect.INSTANCE, converters);
    }
}

@WritingConverter
class JsonNodeWritingConverter(private val objectMapper: ObjectMapper) : Converter<JsonNode, Json> {
    override fun convert(source: JsonNode): Json {
        return Json.of(objectMapper.writeValueAsString(source));
    }
}

@ReadingConverter
class JsonNodeReadingConverter(private val objectMapper: ObjectMapper) : Converter<Json, JsonNode> {
    override fun convert(source: Json): JsonNode? {
        return objectMapper.readTree(source.asString());
    }
}

Далее следует удалить из упомянутой выше сущности ModelEntity лишние аннотации, в итоге должно получиться:

ModelEntity
package com.mlwebservice.persist.entity

import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.ObjectNode
import org.springframework.data.annotation.CreatedDate
import org.springframework.data.annotation.Id
import org.springframework.data.annotation.LastModifiedDate
import org.springframework.data.relational.core.mapping.Column
import org.springframework.data.relational.core.mapping.Table
import java.time.LocalDateTime
import java.util.*

@Table(name = "models", schema = "instrument_data")
data class ModelEntity constructor(
    @Id
    val id: Long? = null,
  
    @Column("model")
    val model: ByteArray,

    @Column("created_at")
    val createdAt: LocalDateTime,

    @Column("last_trained_at")
    val lastTrainedAt: LocalDateTime,

    @Column("task_id")
    val taskId: UUID,

    @Column("parameters")
    val parameters: JsonNode
)
// конструкторы и прочее необходимое

Сам репозиторий сущностей теперь наследуется от R2dbcRepository:

@Repository
interface ModelRepository : R2dbcRepository<ModelEntity, Long>

Методы сохранения и загрузки модели трансформируются для работы в WebFlux:

методы для работы с моделями данных

метод загрузки модели из БД

internal inline fun <reified T> loadModel(modelId: Long): T {
    val optional = modelRepository.findById(modelId)

    val entity = optional.get()
    val modelByteArray = entity.model

    val byteArrayInputStream = ByteArrayInputStream(modelByteArray)
    val modelObject = ObjectInputStream(byteArrayInputStream).use { it.readObject() }

    if (modelObject is T) {
        return modelObject
    } else {
        throw ServiceException.withMessage("Model id $modelId has incorrect format")
    }
}

модифицируется до:

internal inline fun <reified T> loadModel(modelId: Long): Mono<T> =
        modelRepository.findById(modelId)
            .map { modelEntity: ModelEntity ->
                ByteArrayInputStream(modelEntity.model)
            }
            .publishOn(Schedulers.boundedElastic())
            .map { byteArrayInputStream: ByteArrayInputStream ->
                ObjectInputStream(byteArrayInputStream).use { it.readObject() }
            }
            .flatMap { modelObject ->
                if (modelObject is T) {
                    Mono.just(modelObject)
                } else {
                    Mono.error(ServiceException.withMessage("Model id $modelId has incorrect format"))
                }
            }

а метод сохранения:

fun saveModel(
    model : PredictionModel<Vector, XGBoostRegressionModel>,
    taskId : UUID,
    modelParameters : AnalyticsRequest.ModelParameters
) {
    val byteArrayOutputStream = ByteArrayOutputStream()
    ObjectOutputStream(byteArrayOutputStream).use { it.writeObject(model) }
    val modelByteArray: ByteArray = byteArrayOutputStream.toByteArray()
    val jsonParams : JsonNode = objectMapper.convertValue(modelParameters, JsonNode::class.java)

    val entity = ModelEntity(modelByteArray, taskId, jsonParams)
    modelRepository.save(entity)
    log.info("Model for task id {} saved. Parameters map: {}, jsonNode: {}",
        taskId, modelParameters, jsonParams)
}

модифицируется до:

fun saveModel(
        model: PredictionModel<Vector, XGBoostRegressionModel>,
        taskId: UUID,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<Void> =
        Mono.fromCallable {
            val jsonParams: JsonNode = objectMapper.convertValue(modelParameters, JsonNode::class.java)

            val byteArrayOutputStream = ByteArrayOutputStream()
            ObjectOutputStream(byteArrayOutputStream).use { objectOutputStream ->
                objectOutputStream.writeObject(model)
            }
            val modelByteArray: ByteArray = byteArrayOutputStream.toByteArray()

            ModelEntity(modelByteArray, taskId, jsonParams)
        }
            .subscribeOn(Schedulers.boundedElastic())
            .flatMap { entity ->
                modelRepository.save(entity)
                    .doOnSuccess {
                        log.info(
                            "Model for task id {} saved. Parameters map: {}, jsonNode: {}",
                            taskId, modelParameters, entity.parameters.toString()
                        )
                    }
                    .then()
            }

Cassandra

Репозитории Cassandra строились на основе взаимодействия со спарковой сессией. Переработать методы довольно просто. Так, метод получения датасета в базовом абстрактном репозитории:

cassandraDataset web
fun cassandraDataset(keyspace: String, table: String): Dataset<Row> {
    val cassandraDataset: Dataset<Row> = sparkSession.read()
        .format("org.apache.spark.sql.cassandra")
        .option("keyspace", keyspace)
        .option("table", table)
        .load()

    cassandraDataset.createOrReplaceTempView(table)
    return cassandraDataset
}

модифицируется до:

cassandraDataset webflux
fun cassandraDataset(keyspace: String, table: String): Mono<Dataset<Row>> =
    Mono.fromCallable {
        val cassandraDataset: Dataset<Row> = sparkSession.read()
            .format("org.apache.spark.sql.cassandra")
            .option("keyspace", keyspace)
            .option("table", table)
            .load()

        cassandraDataset.createOrReplaceTempView(table)
        cassandraDataset
    }

метод сохранения датасета:

saveDataSet web
open fun saveDataSet(dataset: Dataset<Row>) {
    dataset.write()
        .format("org.apache.spark.sql.cassandra")
        .mode("append")
        .option("confirm.truncate", "false")
        .option("keyspace", keyspace)
        .option("table", table)
        .save();
}

модифицируется до:

saveDataSet webflux
open fun saveDataSet(dataset: Dataset<Row>): Mono<Void> =
    Mono.fromRunnable {
        dataset.write()
            .format("org.apache.spark.sql.cassandra")
            .mode("append")
            .option("confirm.truncate", "false")
            .option("keyspace", keyspace)
            .option("table", table)
            .save()
    }

метод получения базового датасета с определенными оффсетами:

getBaseDataSet web
fun getBaseDataSet(
    ticker: String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    currentOffset : Int,
    batchSize : Int
): Dataset<Row> {
    val filteredDataset = cassandraDataset(table)
        .filter(
            functions.col("ticker").equalTo(ticker)
                .and(functions.col("task_number").equalTo(taskNumber.toString()))
                .and(functions.col("datetime").between(dateStart, dateEnd))
        )

    val offsetDataset = filteredDataset.withColumn(
        "row_number",
        functions.row_number().over(orderBy("datetime"))
    )

    return offsetDataset
        .filter(functions.col("row_number")
            .between(currentOffset + 1, currentOffset + batchSize))
        .drop("row_number")
}

модифицируется до:

getBaseDataSet webflux
fun getBaseDataSet(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate,
    currentOffset: Int,
    batchSize: Int
): Mono<Dataset<Row>> =
    cassandraDataset(table)
        .map { dataset ->
            dataset
                .filter(
                    functions.col("ticker").equalTo(ticker)
                        .and(functions.col("task_number").equalTo(taskNumber.toString()))
                        .and(functions.col("datetime").between(dateStart, dateEnd))
                ).withColumn(
                    "row_number",
                    functions.row_number().over(orderBy("datetime"))
                )
                .filter(
                    functions.col("row_number")
                        .between(currentOffset + 1, currentOffset + batchSize)
                )
                .drop("row_number")
        }

Остальные репозитории конкретных таблиц переписываются по такому же принципу.

В сервисе работы с данными следует упомянуть метод объединения датасетов (теперь же репозитории возвращают реактивные Mono<Dataset<Row>>):

getMainDataset web
fun getMainDataset(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate
) : Dataset<Row> {
    val timeSeries = timeSeriesRepository.getDataset(ticker, taskNumber, dateStart, dateEnd).`as`("ts")
    val emaDataSet = emaRepository.getEmaDataSet(ticker, dateStart, dateEnd).`as`("ema")
    val stochasticDataset = stochasticRepository.getStochasticDataSet(ticker, dateStart, dateEnd).`as`("stoch")
    val bBandsDataset = bBandIndicatorRepository.getBBandsDataSet(ticker, dateStart, dateEnd).`as`("bb")
    val macdDataset = macdRepository.getMacdDataSet(ticker, dateStart, dateEnd).`as`("macd")
    val rsiDataset = rsiRepository.getRsiDataSet(ticker, dateStart, dateEnd).`as`("rsi")
    val smaDataset = smaRepository.getSmaDataSet(ticker, dateStart, dateEnd).`as`("sma")
    val willrDataset = willrRepository.getWillrDataSet(ticker, dateStart, dateEnd).`as`("willr")

    return combineDatasets(
        timeSeries, emaDataSet, stochasticDataset, bBandsDataset, macdDataset, rsiDataset, smaDataset, willrDataset
    )
}

модифицируется до:

getMainDataset webflux
fun getMainDataset(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate
): Mono<Dataset<Row>> {
    val timeSeriesMono = timeSeriesRepository.getDataset(ticker, taskNumber, dateStart, dateEnd)
        .map { dataset -> dataset.alias("ts") }
    val emaDataSetMono = emaRepository.getEmaDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("ema") }
    val stochasticDatasetMono = stochasticRepository.getStochasticDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("stoch") }
    val bBandsDatasetMono = bBandIndicatorRepository.getBBandsDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("bb") }
    val macdDatasetMono = macdRepository.getMacdDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("macd") }
    val rsiDatasetMono = rsiRepository.getRsiDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("rsi") }
    val smaDatasetMono = smaRepository.getSmaDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("sma") }
    val willrDatasetMono = willrRepository.getWillrDataSet(ticker, dateStart, dateEnd)
        .map { dataset -> dataset.alias("willr") }

    return Mono.zip(
        timeSeriesMono, emaDataSetMono, stochasticDatasetMono, bBandsDatasetMono,
        macdDatasetMono, rsiDatasetMono, smaDatasetMono, willrDatasetMono
    ).map { tuple ->
        combineDatasets(tuple.t1, tuple.t2, tuple.t3, tuple.t4, tuple.t5, tuple.t6, tuple.t7, tuple.t8)
    }
}

здесь получаются 8 датасетов в Mono-обертках, обертки объединяются в один Mono посредством .zip() и передаются на исполнение в метод комбинации датасетов, который не менялся.

Сервис StockAnalyticsService

Метод выполнения предикта с помощью сохраненной модели:

predictWithExistingModel web
fun predictWithExistingModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    modelId : Long
): StockPredictDto {
    val model: PredictionModel<Vector, XGBoostRegressionModel> = modelService.loadModel(modelId)
    val data = dataReaderService.getMainDataset(ticker, taskNumber, dateStart, dateEnd)

    var predictions = model.transform(data)
    predictions = predictions.select("dateTime", "prediction")
    return StockPredictDto.fromDataset(predictions)
}

модифицируется до:

predictWithExistingModel webflux
fun predictWithExistingModel(
    ticker: String,
    taskNumber: UUID,
    dateStart: LocalDate,
    dateEnd: LocalDate,
    modelId: Long
): Mono<StockPredictDto> =
    modelService.loadModel<PredictionModel<Vector, XGBoostRegressionModel>>(modelId)
        .flatMap { model ->
            dataReaderService.getMainDataset(ticker, taskNumber, dateStart, dateEnd)
                .map { data ->
                    val predictions = model.transform(data)
                        .select("dateTime", "prediction")
                    StockPredictDto.fromDataset(predictions)
                }
        }

Метод обучения модели:

trainModel web
fun trainModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    evalPivotPoint : Long,
    offset : Long,
    modelParameters : AnalyticsRequest.ModelParameters
) : ModelTrainResultResponse {
    val pivot = dateEnd.minusDays(evalPivotPoint)

    val tdf = dataReaderService.getDatasetWithLabel(ticker, taskNumber, dateStart, pivot, offset)
    val edf = dataReaderService.getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset)
        .selectExpr(*allColumns)

    val modelParams = createModelParams(modelParameters)
    val regressor = xgBoostRegressor(modelParams)

    val model: PredictionModel<Vector, XGBoostRegressionModel> = regressor.fit(tdf)
    val predictions = model.transform(edf)

    combinedDataRepository.saveData(tdf.selectExpr(*allColumns).unionAll(edf), ticker, taskNumber)
    modelService.saveModel(model, taskNumber, modelParameters)

    val result = predictions.withColumn("error", col("prediction").minus(col(labelName)))
    return ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
}

модифицируется до:

trainModel webflux
fun trainModel(
        ticker: String,
        taskNumber: UUID,
        dateStart: LocalDate,
        dateEnd: LocalDate,
        evalPivotPoint: Long,
        offset: Long,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<ModelTrainResultResponse> =
        Mono.just(dateEnd.minusDays(evalPivotPoint))
            .flatMap { pivot: LocalDate ->
                dataReaderService.getDatasetWithLabel(ticker, taskNumber, dateStart, pivot, offset)
                    .zipWith(dataReaderService.getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset))
            }
            .flatMap { tuple: Tuple2<Dataset<Row>, Dataset<Row>> ->
                val tdf = tuple.t1
                val edf = tuple.t2

                val modelParams = createModelParams(modelParameters)
                val regressor = xgBoostRegressor(modelParams)

                Mono.fromCallable { regressor.fit(tdf) }
                    .flatMap { model: XGBoostRegressionModel ->
                        val predictions = model.transform(edf)

                        val saveDataMono = combinedDataRepository.saveData(
                            tdf.selectExpr(*allColumns).unionAll(edf),
                            ticker,
                            taskNumber
                        )

                        modelService.saveModel(model, taskNumber, modelParameters)
                            .then(saveDataMono)
                            .thenReturn(predictions)
                    }
            }
            .map { predictions: Dataset<Row> ->
                val result = predictions.withColumn("error", col("prediction").minus(col(labelName)))
                ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
            }

здесь tdf и edf обернуты в Mono, поэтому объединяются в кортеж из двух элементов Mono<Tuple2>, далее в оборачиваем в Callable функцию regressor.fit(tdf), которая будет выполнена асинхронно и вернет результат в виде model: XGBoostRegressionModel. В функции flatMap она используется с эвалюирующим датасетом для получения предиктов, затем сохраняется в БД с помощью описанного выше метода saveModel. Остальная логика очевидна.

Наибольшую сложность вызывает метод инкрементального обучения (да, на данной модели инкремент не работает и требуется замена XGBoost на другую модель, но цель была трансформировать логику под реактивную среду и получить работающий пример, который далее можно использовать для инкрементального обучения модели).

Исходный метод:

incrementTrainModel web
fun incrementTrainModel(
    ticker : String,
    taskNumber : UUID,
    dateStart : LocalDate,
    dateEnd : LocalDate,
    evalPivotPoint : Long,
    offset : Long,
    batchSize : Int,
    modelParameters : AnalyticsRequest.ModelParameters
) : ModelTrainResultResponse {
    val pivot = dateEnd.minusDays(evalPivotPoint)
    var currentBatchOffset = 0
    var i = 0

    val modelParams = createModelParams(modelParameters)
    val regressor = xgBoostRegressor(modelParams)

    var model: PredictionModel<Vector, XGBoostRegressionModel>? = null
    var predictions: Dataset<Row>? = null

    var tdf: Dataset<Row>?
    do {
        log.info("Iteration {}: currentOffset {}", i, currentBatchOffset)
        tdf = dataReaderService.getDatasetWithLabel(
            ticker, taskNumber, dateStart, pivot, offset, currentBatchOffset, batchSize
        )
        if (tdf.isEmpty) break

        model = regressor.fit(tdf)
        combinedDataRepository.saveData(tdf.selectExpr(*allColumns), ticker, taskNumber)

        currentBatchOffset += batchSize
        i++
    } while (tdf?.isEmpty == false)

    val edf = dataReaderService.getDatasetWithLabel(
        ticker, taskNumber, pivot, dateEnd, offset, 0, 100).selectExpr(*allColumns)
    if (model != null) {
        predictions = model.transform(edf)
    }
    combinedDataRepository.saveData(edf.selectExpr(*allColumns), ticker, taskNumber)
    modelService.saveModel(model!!, taskNumber, modelParameters)

    val result = predictions!!.withColumn("error", col("prediction").minus(col(labelName)))
    return ModelTrainResultResponse(ModelTrainResult.listFromDataset(result.selectExpr(*resultExp)))
}

модифицируется до:

incrementTrainModel webflux
fun incrementTrainModel(
        ticker: String,
        taskNumber: UUID,
        dateStart: LocalDate,
        dateEnd: LocalDate,
        evalPivotPoint: Long,
        offset: Long,
        batchSize: Int,
        modelParameters: AnalyticsRequest.ModelParameters
    ): Mono<ModelTrainResultResponse> {
        val pivot = dateEnd.minusDays(evalPivotPoint)
        var currentBatchOffset = 0
        var i = 0

        val modelParams = createModelParams(modelParameters)
        val regressor = xgBoostRegressor(modelParams)

        var model: PredictionModel<Vector, XGBoostRegressionModel>? = null
        var tdf: Dataset<Row>? = null

        return Mono.defer {
            dataReaderService.getDatasetWithLabel(
                ticker, taskNumber, dateStart, pivot, offset, currentBatchOffset, batchSize
            )
        }
            .map { dataset ->
                tdf = dataset
                log.info("Iteration {}: currentOffset {}", i, currentBatchOffset)
                if (tdf?.isEmpty == true) {
                    log.warn(
                        "tdf is empty, no more data for learning, Iteration {}: currentOffset {}",
                        i, currentBatchOffset
                    )
                    Mono.empty()
                } else {
                    model = regressor.fit(tdf)
                    log.info("model trained, Iteration {}: currentOffset {}", i, currentBatchOffset)
                    currentBatchOffset += batchSize
                    i++
                    combinedDataRepository.saveData(tdf!!.selectExpr(*allColumns), ticker, taskNumber)
                        .thenReturn(currentBatchOffset + batchSize)
                }
            }
            .repeat { tdf?.isEmpty == false }
            .then(dataReaderService
                .getDatasetWithLabel(ticker, taskNumber, pivot, dateEnd, offset, 0, 100)
                .flatMap { edf ->
                    log.info("Got edf")
                    combinedDataRepository.saveData(edf.selectExpr(*allColumns), ticker, taskNumber)
                        .then(modelService.saveModel(model!!, taskNumber, modelParameters))
                        .thenReturn(model!!.transform(edf))
                }.map { predictions ->
                    log.info("Predictions stage")
                    val result = predictions?.withColumn(
                        "error", col("prediction")
                            .minus(col(labelName))
                    )
                    ModelTrainResultResponse(ModelTrainResult.listFromDataset(result!!.selectExpr(*resultExp)))
                })
            .doOnError { exception ->
                log.error("Error while increment learning; taskNumber = {}", taskNumber, exception)
                ModelTrainResultResponse()
            }
    }

В отличие от Java, лямбда-выражения в Kotlin не требуют от переменных, чтобы они были effectively final, поэтому переменные currentBatchOffset, i, model и tdf могут изменяться в ходе выполнения основного стрима.

Здесь функция получения датасета обертывается в Mono.defer(). Особенность данного подхода в том, что выполнение функции откладывается до момента подписки на данный Mono. А подписка будет повторяться методом .repeat() до тех пор, пока не выполнится условие tdf?.isEmpty == false.

Когда очередной tdf будет пустым, выполнится логика в then: из кассандры будет получен датасет edf, который сохранится в таблице скомбинированных данных, так же будут получены предикты модели и сохранена сама модель. Затем из предиктов подготовится результат метода. В случае ошибки вернется пустой результат метода.

Не сказать, что это идеальное исполнение метода, но как пример сойдет.

Подробно реализацию можно посмотреть в отдельной ветке репозитория.

Сравнение двух реализаций

Как известно, реактивный стек отличается от сервлетного тем, что для выполнения одной и той же логики зачастую требуется меньше ресурсов. В некоторых случаях может возрасти скорость выполнения алгоритма.

Тестирование происходило по следующей методике:

  1. Сервис поднимается в Docker-контейнере с 4 CPU и 4 Gb памяти, использует Spark Executor (v. 3.3.2, JDK 17), так же в Docker контейнере, который подключается к Standalone-мастеру Spark в виртуальной машине. Все работает на одной машине под управлением Windows 10 Pro, задачи тренировки моделей выполняются на GPU NVidia 4090.

  2. В течении 10 минут производятся запросы методов: обучения новой модели (POST /analytics - для сокращения “1 запрос”), получения предиктов с помощью сохраненной модели (GET /analytics - для сокращения “2 запрос”) и инкрементального обучения (POST /analytics/increment - для сокращения “3 запрос”) с batch_size = 50 записей, во время которого делается 12 итераций над 6 сотнями записей в таблицах Cassandra. Первый цикл на “не прогретом” драйвере (первые запросы всегда выполняются дольше), далее два одинаковых цикла по одному запросу каждого метода на “прогретом драйвере” и в четвертом цикле запускаются одновременно 1, 2, 3 методы.

  3. Driver работает в режиме Spark Cluster, используется одна Spark Session на все время работы приложения;

  4. Изначальные параметры запуска JVM одинаковые: первоначальный размер кучи 512 Мб, максимальный размер не указан, GC по умолчанию (G1).

Результаты потребления ресурсов:

Максимальное потребление CPU

Среднее потребление CPU

Максимальное потребление памяти, Gb

Среднее потребление памяти, Gb

Количество Stop the world за 10 минут

Spring Web

3,4

1,5

4

2

4

Spring Webflux

3,4

1,1

1

0,5

0

С указанными выше параметрами для сервлетного стека наблюдалось 4 stop the world от G1 GC, при этом один раз результатом выполнения предиктов из сохраненной модели стала ошибка сервера.

На графике видно, что потребление памяти растет линейно до момента, когда свободного места для кучи уже нет и ее необходимо чистить.

График потребления ресурсов Web приложения с параметрами JVM -Xms512m
График потребления ресурсов Web приложения с параметрами JVM -Xms512m

У реактивного стека другая картина: после первых запросов стабильные ~0.5 Гб памяти. По потреблению CPU разница не настолько большая.

График потребления ресурсов WebFlux приложения с параметрами JVM -Xms512m
График потребления ресурсов WebFlux приложения с параметрами JVM -Xms512m

Скорость выполнения запросов:

Сравнительная таблица Web и WebFlux версия приложения с параметрами JVM -Xms512m
Сравнительная таблица Web и WebFlux версия приложения с параметрами JVM -Xms512m

Топ-5 классов по потреблению памяти:

Учитывая, что весь результирующий датасет занимает около 55 Мб, такой объем аллоцированной памяти вызывает вопросы. Анализ стектрейсов показал, что в большинстве случаев источником и причиной является Spark и Rapids, которые строят план запросов, обмениваются данными между БД, экзекуторами и драйвером, подготавливают массивы данных для загрузки в GPU и вычитывают результат из него. Потратив некоторое время на изучение вопроса оптимизации использования памяти, могу сделать вывод, что это штатное поведение системы в такой конфигурации, и надо научиться с этим жить при использовании сервлетного стека.

Первые попытки жить с этим в привели к изменению параметров запуска JVM для сервлетного стека на следующие: -Xms512m -Xmx3g -XX:GCTimeRatio=19 (жесткое указание того, что система может потратить до 5% времени на сборку мусора - (1 / (1+19))) -XX:+UseZGC. Учитывая, что реактивному стеку достаточно в среднем 512 Мб памяти, и что Z GC потребляет несколько больше памяти, чем G1 GC, планка максимального размера кучи снизилась до 3 Гб.

График потребления ресурсов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC
График потребления ресурсов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC

Потребление CPU незначительно снизилось, наблюдается схожее потребление памяти, но stop the world уже не зафиксировано. Судя по графику, после завершения работы методов POST /analytics и GET /analytics куча очищается, но при работе POST /analytics/increment куча очищается только к моменту приближения к своему максимальному размеру. Логика, которая могла бы вести к утечке памяти, отсутствует, причина такого высокого потребления памяти остается не выясненной.

Результаты переключения GC в таблице потребления ресурсов:


Максимальное потребление CPU

Среднее потребление CPU

Максимальное потребление памяти, Gb

Среднее потребление памяти, Gb

Количество Stop the world

Spring Web

3,4

1,5

3

1,5

0

Spring Webflux

3,4

1,1

1

0,5

0

и скорости выполнения запросов:

Сравнительная таблица скорости выполнения запросов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m
Сравнительная таблица скорости выполнения запросов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m

Стало интересно, что будет, если для G1 GC установить максимальный размер кучи как для Z GC и установить жесткий предел времени выполнения на сборку мусора. В этом случае оказалось, что память заполняется как и раньше, но stop the world стало больше, так как доступной памяти меньше, и, соответственно, заполняется она быстрее. Потребление ресурсов осталось примерно на том же уровне:

График потребления ресурсов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC
График потребления ресурсов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC


Максимальное потребление CPU

Среднее потребление CPU

Максимальное потребление памяти, Gb

Среднее потребление памяти, Gb

Количество Stop the world

Spring Web

3,4

1,5

3

1,5

6

Spring Webflux

3,4

1,1

1

0,5

0

Скорость выполнения запросов возросла, но не существенно.

Сравнительная таблица скорости выполнения запросов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m
Сравнительная таблица скорости выполнения запросов Web приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseG1GC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m

Так же попробовал в сервлетном стеке использовать Parallel GC с параметрами -Xms512m -Xmx4g -XX:GCTimeRatio=19 -XX:+UseParallelGC. Результаты самые худшие, за 10 минут удалось прогнать только 2 цикла. Если первые два метода выполнялись примерно за то же время без отклонений, то метод инкрементального обучения выполнялся в первый раз 3мин 32с, что хуже примерно на 1,5 минуты среднего результата сервлетного стека, а второй запрос подвис и выполнялся 8мин 10с. Результаты в таблицах не фиксировал.

График потребления ресурсов Web версии приложения с Pasrallel GC
График потребления ресурсов Web версии приложения с Pasrallel GC

Напоследок применил настройки JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC к версии WebFlux, которая оказалась на данный момент самой оптимальной по потреблению ресурсов и скорости обработки запросов. Сравнительные таблицы версии с дефолтными параметрами и G1 GC и версии с кастомными параметрами JVM с Z GC ниже.

График потребления ресурсов:

График потребления ресурсов WebFlux приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC
График потребления ресурсов WebFlux приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC

Таблица потребления ресурсов:


Максимальное потребление CPU

Среднее потребление CPU

Максимальное потребление памяти, Gb

Среднее потребление памяти, Gb

Количество Stop the world

Spring Webflux G1 GC

3,4

1,1

1

0,5

0

Spring Webflux Z GC

3,4

1,4

2,84

1,5

0

Таблица скорости выполнения запросов:

Сравнительная таблица скорости выполнения запросов WebFlux приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m
Сравнительная таблица скорости выполнения запросов WebFlux приложения с параметрами JVM -Xms512m -Xmx3g -XX:GCTimeRatio=19 -XX:+UseZGC относительно WebFlux приложения с G1 GC и параметрами JVM -Xms512m

Итоговые таблицы со всеми версиями и различными значениями параметров конфигурации JVM представлены ниже. За основу взяты результаты для WebFlux на G1 GC с одним параметром JVM минимального размера хипа 512m.

Сводные таблицы по потреблению ресурсов
Сводные таблицы по потреблению ресурсов
Сводные таблицы времени выполнения запросов
Сводные таблицы времени выполнения запросов

Вывод

Подводя черту после написания третьей статьи на тему построения системы распределенного машинного обучения на Java и Kotlin, самый большой вывод, который напрашивается - построить подобную систему сложно, много неизвестных, необходимо выполнить много исследований, но добиться работающего решения вполне реально, было бы желание.

Если так случилось, что нужно выполнять задачи ML на JVM стеке технологий, учите Python и не занимайтесь фигней, а руководству продайте альтернативную систему отличным выбором в качестве основы будет Kotlin и Spring Webflux (как альтернатива - Web с Z GC), и, естественно, Apache Spark. По окончанию работ над любым приложением стоит провести проверку профилировщиком, так как с очень высокой вероятностью при дефолтных параметрах JVM работа приложения не будет оптимальной.

Другой вопрос - является ли данная система эффективной с точки зрения производительности и потребления ресурсов? Без тестов на альтернативной системе (например, Python + Dask) объективно ответить на данный вопрос я затрудняюсь. Возможно, в будущем попробую поднять такую систему и написать альтернативную логику на питоне, тогда будет с чем сравнить и о чем написать очередную статью.

Теги:
Хабы:
Всего голосов 6: ↑5 и ↓1+8
Комментарии4

Публикации

Истории

Работа

Java разработчик
388 вакансий

Ближайшие события

27 августа – 7 октября
Премия digital-кейсов «Проксима»
МоскваОнлайн
28 – 29 сентября
Конференция E-CODE
МоскваОнлайн
28 сентября – 5 октября
О! Хакатон
Онлайн
30 сентября – 1 октября
Конференция фронтенд-разработчиков FrontendConf 2024
МоскваОнлайн
3 – 18 октября
Kokoc Hackathon 2024
Онлайн
7 – 8 ноября
Конференция byteoilgas_conf 2024
МоскваОнлайн