Elasticsearch 与机器学习结合:实现高效模型推理的方案(上)

在大数据时代,搜索引擎与机器学习的融合已成为数据处理领域的重要技术方向。Elasticsearch 不仅是高性能的搜索与分析引擎,还提供了完整的机器学习推理框架,使我们能在分布式环境中高效部署和运行 ML 模型。

1. Elasticsearch 机器学习功能概述

Elasticsearch 的机器学习能力随版本不断发展,从早期的异常检测功能,到如今成熟的模型推理框架,为数据分析提供了强大支持。

核心功能包括:

  • 异常检测:自动识别数据中的异常模式和离群值
  • 预测分析:基于历史数据进行时序预测和趋势分析
  • 模型推理:在 ES 中部署外部训练的机器学习模型并执行实时或批量推理

ES 机器学习框架的主要优势在于能够利用其分布式架构实现高可用、高扩展的模型部署,无需额外维护专门的 ML 服务基础设施。

版本功能对照

特性 ES 7.x ES 8.0-8.3 ES 8.4+ ES 8.8+
异常检测
时序预测 有限支持
ONNX 模型支持 部分支持 完整支持 完整支持
PyTorch 模型原生支持
分布式推理 有限支持 完整支持 增强支持
NLP 预训练模型库 有限支持 丰富支持 全面支持
推理内存限制控制 增强控制 精细控制
模型部署隔离性 有限支持 增强支持

2. 推理功能原理与支持的模型类型

2.1 支持的模型类型

Elasticsearch 当前支持以下类型的机器学习模型:

  1. PyTorch 模型(通过 ONNX 格式转换或 8.8+版本原生支持)
  2. scikit-learn 模型(支持分类器、回归器等)
  3. XGBoost 模型(分类与回归)
  4. LightGBM 模型(分类与回归)
  5. 预训练的 NLP 模型(BERT、sentence-transformers 等)

2.2 ONNX 操作支持说明

对于 ONNX 模型,Elasticsearch 支持的操作集包括:

  • 基础操作:Add, Sub, Mul, Div, Pow
  • 神经网络层:Conv, MaxPool, BatchNormalization, Dropout
  • 激活函数:Relu, Sigmoid, Tanh, LeakyRelu
  • 数据操作:Reshape, Transpose, Concat, Split
  • 序列处理:LSTM, GRU (从 8.4 版本开始完整支持)

在导入模型前应确认模型使用的操作在 ES 支持列表中。可以通过 ONNX Runtime 的检查工具验证模型兼容性:

java 复制代码
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

public class ModelCompatibilityChecker {
    public static boolean isModelCompatible(String modelPath) {
        try {
            OrtEnvironment env = OrtEnvironment.getEnvironment();
            OrtSession.SessionOptions options = new OrtSession.SessionOptions();

            // 加载模型进行验证
            OrtSession session = env.createSession(modelPath, options);

            // 检查模型输入输出
            session.getInputNames().forEach(name ->
                System.out.println("输入: " + name));
            session.getOutputNames().forEach(name ->
                System.out.println("输出: " + name));

            session.close();
            env.close();
            return true;
        } catch (Exception e) {
            System.err.println("模型兼容性检查失败: " + e.getMessage());
            return false;
        }
    }
}

2.3 推理处理流程

推理过程通过 Elasticsearch 的 Ingest Pipeline(摄入管道)实现:

  1. 数据文档通过管道时,推理处理器被触发
  2. 处理器从文档中提取特征数据
  3. 数据传入模型进行推理
  4. 推理结果存储在目标字段
  5. 处理后的文档继续完成索引流程

3. 实现步骤与环境准备

3.1 环境准备

  1. Elasticsearch 8.x(建议 8.4+版本)
  2. Java 11+
  3. 适配 ES 版本的 Java 客户端

Maven 依赖配置:

java 复制代码
<dependencies>
    <!-- Elasticsearch Java客户端 -->
    <dependency>
        <groupId>co.elastic.clients</groupId>
        <artifactId>elasticsearch-java</artifactId>
        <version>8.8.0</version>
    </dependency>

    <!-- JSON处理 -->
    <dependency>
        <groupId>com.fasterxml.jackson.core</groupId>
        <artifactId>jackson-databind</artifactId>
        <version>2.14.2</version>
    </dependency>

    <!-- 日志框架 -->
    <dependency>
        <groupId>org.slf4j</groupId>
        <artifactId>slf4j-api</artifactId>
        <version>1.7.36</version>
    </dependency>
    <dependency>
        <groupId>ch.qos.logback</groupId>
        <artifactId>logback-classic</artifactId>
        <version>1.2.11</version>
    </dependency>

    <!-- ONNX Runtime (用于模型验证和优化) -->
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime</artifactId>
        <version>1.13.1</version>
    </dependency>
</dependencies>

3.2 日志配置示例

为确保生产环境中的日志管理规范,添加以下 logback.xml 配置:

xml 复制代码
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
    <appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>

    <appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
        <file>logs/es-ml-app.log</file>
        <rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
            <fileNamePattern>logs/es-ml-app.%d{yyyy-MM-dd}.log</fileNamePattern>
            <maxHistory>30</maxHistory>
        </rollingPolicy>
        <encoder>
            <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>

    <logger name="co.elastic" level="INFO"/>
    <logger name="org.elasticsearch" level="WARN"/>
    <logger name="com.yourcompany.esml" level="DEBUG"/>

    <root level="INFO">
        <appender-ref ref="CONSOLE"/>
        <appender-ref ref="FILE"/>
    </root>
</configuration>

3.3 Elasticsearch 配置参数

高效运行 ML 推理需要合理配置以下关键参数:

yaml 复制代码
# elasticsearch.yml配置示例
xpack.ml.max_model_memory_limit: 1gb      # 单个模型最大内存
xpack.ml.max_machine_memory_percent: 30   # ML可使用的最大机器内存百分比
xpack.ml.max_inference_processors: 4      # 每节点最大推理处理器数量
thread_pool.ingest.queue_size: 200        # 推理请求队列大小
thread_pool.ingest.size: 8                # 推理线程池大小

3.4 索引映射定义

为推理结果创建合适的索引映射,确保字段类型正确并支持高效查询:

java 复制代码
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;

public class IndexManager {
    private static final Logger logger = LoggerFactory.getLogger(IndexManager.class);

    /**
     * 创建推理结果索引模板
     */
    public void createInferenceIndexTemplate(ElasticsearchClient client) throws IOException {
        try {
            client.indices().putIndexTemplate(t -> t
                .name("ml-inference-template")
                .indexPatterns("inference-*", "ml-results-*")
                .template(it -> it
                    .settings(s -> s
                        .numberOfShards(3)
                        .numberOfReplicas(1)
                        .refreshInterval(r -> r.time("5s"))
                    )
                    .mappings(m -> m
                        .properties("text", p -> p.text(t -> t
                            .analyzer("standard")
                            .fields("keyword", f -> f.keyword(k -> k))
                        ))
                        .properties("prediction", p -> p.object(o -> o
                            .properties("score", s -> s.float32(f -> f))
                            .properties("label", l -> l.keyword(k -> k))
                        ))
                        .properties("timestamp", p -> p.date(d -> d))
                        .properties("processing_time_ms", p -> p.long_(l -> l))
                    )
                )
            );

            logger.info("推理结果索引模板创建成功");
        } catch (Exception e) {
            logger.error("创建索引模板失败: {}", e.getMessage(), e);
            throw e;
        }
    }

    /**
     * 创建特定的推理索引
     */
    public boolean createInferenceIndex(ElasticsearchClient client, String indexName) {
        try {
            CreateIndexResponse response = client.indices().create(c -> c
                .index(indexName)
                .aliases(indexName + "_alias", a -> a)
            );

            boolean acknowledged = response.acknowledged();
            logger.info("索引 {} 创建{}", indexName, acknowledged ? "成功" : "失败");
            return acknowledged;
        } catch (Exception e) {
            if (e.getMessage().contains("resource_already_exists_exception")) {
                logger.info("索引 {} 已存在", indexName);
                return true;
            }
            logger.error("创建索引失败: {}", e.getMessage(), e);
            return false;
        }
    }
}

4. 代码实现:模型导入与推理

4.1 异常类型定义

首先,定义清晰的异常层次结构:

java 复制代码
/**
 * ES机器学习操作基础异常
 */
public class ESMLException extends RuntimeException {
    public ESMLException(String message) {
        super(message);
    }

    public ESMLException(String message, Throwable cause) {
        super(message, cause);
    }
}

/**
 * 模型操作异常
 */
public class ModelOperationException extends ESMLException {
    private final String modelId;
    private final int statusCode;

    public ModelOperationException(String message, String modelId, int statusCode, Throwable cause) {
        super(message, cause);
        this.modelId = modelId;
        this.statusCode = statusCode;
    }

    public ModelOperationException(String message, String modelId, int statusCode) {
        this(message, modelId, statusCode, null);
    }

    public String getModelId() {
        return modelId;
    }

    public int getStatusCode() {
        return statusCode;
    }
}

/**
 * 推理异常
 */
public class InferenceException extends ESMLException {
    private final String pipelineId;

    public InferenceException(String message, String pipelineId, Throwable cause) {
        super(message, cause);
        this.pipelineId = pipelineId;
    }

    public InferenceException(String message, String pipelineId) {
        this(message, pipelineId, null);
    }

    public String getPipelineId() {
        return pipelineId;
    }
}

/**
 * 索引操作异常
 */
public class IndexOperationException extends ESMLException {
    private final String indexName;

    public IndexOperationException(String message, String indexName, Throwable cause) {
        super(message, cause);
        this.indexName = indexName;
    }

    public String getIndexName() {
        return indexName;
    }
}

4.2 ES 客户端工具类

java 复制代码
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
import co.elastic.clients.transport.ElasticsearchTransport;
import co.elastic.clients.transport.rest_client.RestClientTransport;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.SSLContexts;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.util.Objects;
import java.util.function.Supplier;

public class ESClientUtil {
    private static final Logger logger = LoggerFactory.getLogger(ESClientUtil.class);

    // 从配置文件或环境变量读取连接信息
    private static final String ES_HOST = System.getProperty("es.host", "localhost");
    private static final int ES_PORT = Integer.parseInt(System.getProperty("es.port", "9200"));
    private static final String ES_PROTOCOL = System.getProperty("es.protocol", "https");
    private static final String ES_USERNAME = System.getProperty("es.username", "elastic");
    private static final String ES_PASSWORD = System.getProperty("es.password", "changeme");
    private static final String ES_CERT_PATH = System.getProperty("es.cert.path");

    // 连接池配置
    private static final int MAX_CONN_TOTAL = 100;
    private static final int MAX_CONN_PER_ROUTE = 30;
    private static final int CONNECTION_TIMEOUT_MS = 5000;
    private static final int SOCKET_TIMEOUT_MS = 60000;

    // 重试配置
    private static final int MAX_RETRIES = 3;
    private static final long RETRY_BACKOFF_MS = 1000;

    /**
     * 创建带连接池的ES客户端
     * @return 配置好的Elasticsearch客户端
     * @throws IOException 如果创建客户端失败
     */
    public static ElasticsearchClient createClient() throws IOException {
        Objects.requireNonNull(ES_HOST, "ES主机地址不能为空");

        RestClientBuilder builder = RestClient.builder(
                new HttpHost(ES_HOST, ES_PORT, ES_PROTOCOL));

        // 配置连接池
        builder.setHttpClientConfigCallback(httpClientBuilder -> {
            httpClientBuilder.setMaxConnTotal(MAX_CONN_TOTAL);
            httpClientBuilder.setMaxConnPerRoute(MAX_CONN_PER_ROUTE);

            // 配置认证
            if (ES_USERNAME != null && !ES_USERNAME.isEmpty()) {
                final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
                credentialsProvider.setCredentials(AuthScope.ANY,
                        new UsernamePasswordCredentials(ES_USERNAME, ES_PASSWORD));
                httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
            }

            // 配置SSL证书
            if (ES_CERT_PATH != null && !ES_CERT_PATH.isEmpty()) {
                try {
                    SSLContext sslContext = buildSSLContext();
                    httpClientBuilder.setSSLContext(sslContext);
                } catch (Exception e) {
                    logger.error("SSL证书配置失败", e);
                }
            }

            return configureHttpClient(httpClientBuilder);
        });

        // 配置请求超时
        builder.setRequestConfigCallback(requestConfigBuilder ->
            requestConfigBuilder
                .setConnectTimeout(CONNECTION_TIMEOUT_MS)
                .setSocketTimeout(SOCKET_TIMEOUT_MS)
        );

        // 创建客户端
        try {
            ElasticsearchTransport transport = new RestClientTransport(
                    builder.build(), new JacksonJsonpMapper());
            return new ElasticsearchClient(transport);
        } catch (Exception e) {
            logger.error("创建ES客户端失败", e);
            throw new IOException("创建ES客户端失败: " + e.getMessage(), e);
        }
    }

    /**
     * 配置HTTP客户端
     */
    private static HttpAsyncClientBuilder configureHttpClient(HttpAsyncClientBuilder httpClientBuilder) {
        return httpClientBuilder
            .setDefaultRequestConfig(
                RequestConfig.custom()
                    .setConnectTimeout(CONNECTION_TIMEOUT_MS)
                    .setSocketTimeout(SOCKET_TIMEOUT_MS)
                    .build()
            );
    }

    /**
     * 构建SSL上下文
     */
    private static SSLContext buildSSLContext() throws Exception {
        Path certPath = Path.of(ES_CERT_PATH);

        if (!Files.exists(certPath)) {
            throw new IllegalArgumentException("证书文件不存在: " + certPath);
        }

        CertificateFactory factory = CertificateFactory.getInstance("X.509");
        Certificate trustedCa;

        try (InputStream is = Files.newInputStream(certPath)) {
            trustedCa = factory.generateCertificate(is);
        }

        KeyStore trustStore = KeyStore.getInstance("pkcs12");
        trustStore.load(null, null);
        trustStore.setCertificateEntry("ca", trustedCa);

        return SSLContexts.custom()
                .loadTrustMaterial(trustStore, new TrustSelfSignedStrategy())
                .build();
    }

    /**
     * 关闭ES传输层
     */
    public static void closeTransport(ElasticsearchTransport transport) {
        if (transport != null) {
            try {
                transport.close();
            } catch (IOException e) {
                logger.warn("关闭ES传输层失败", e);
            }
        }
    }

    /**
     * 执行带重试的ES操作
     * @param operation 要执行的操作
     * @return 操作结果
     * @throws IOException 如果操作最终失败
     */
    public static <T> T executeWithRetry(Supplier<T> operation) throws IOException {
        int attempts = 0;
        IOException lastException = null;

        while (attempts < MAX_RETRIES) {
            try {
                return operation.get();
            } catch (IOException e) {
                if (isRetryableException(e)) {
                    lastException = e;
                    attempts++;

                    if (attempts < MAX_RETRIES) {
                        long backoffTime = RETRY_BACKOFF_MS * attempts;
                        logger.warn("ES操作失败,将在{}ms后重试(尝试{}/{}): {}",
                                backoffTime, attempts, MAX_RETRIES, e.getMessage());
                        try {
                            Thread.sleep(backoffTime);
                        } catch (InterruptedException ie) {
                            Thread.currentThread().interrupt();
                            throw new IOException("重试等待被中断", ie);
                        }
                    }
                } else {
                    // 不可重试的异常直接抛出
                    logger.error("遇到不可重试的ES异常", e);
                    throw e;
                }
            }
        }

        logger.error("ES操作在{}次尝试后仍然失败", MAX_RETRIES);
        throw lastException;
    }

    /**
     * 判断异常是否可以重试
     */
    private static boolean isRetryableException(IOException e) {
        // 网络相关异常通常可以重试
        if (e instanceof java.net.SocketTimeoutException ||
            e instanceof java.net.ConnectException) {
            return true;
        }

        // 特定HTTP状态码也可以重试
        String message = e.getMessage();
        if (message != null && (
                message.contains("429") || // Too Many Requests
                message.contains("503") || // Service Unavailable
                message.contains("507")    // Insufficient Storage
            )) {
            return true;
        }

        return false;
    }

    /**
     * 检查ES集群健康状态
     */
    public static boolean isClusterHealthy() {
        try (var client = createClient()) {
            var response = client.cluster().health();
            String status = response.status().toString();
            return "green".equals(status) || "yellow".equals(status);
        } catch (Exception e) {
            logger.error("检查集群健康状态失败: {}", e.getMessage(), e);
            return false;
        }
    }
}

4.3 模型上传接口与实现

首先定义模型上传的接口,遵循依赖倒置原则:

java 复制代码
/**
 * 模型上传器接口
 */
public interface ModelUploader {
    /**
     * 上传模型到Elasticsearch
     * @param modelId 模型ID
     * @param modelPath 模型文件路径
     * @param description 模型描述
     * @param tags 模型标签
     * @throws ModelOperationException 如果上传失败
     */
    void uploadModel(String modelId, Path modelPath, String description, String... tags);

    /**
     * 检查模型是否存在
     * @param modelId 模型ID
     * @return 模型是否存在
     */
    boolean modelExists(String modelId);

    /**
     * 创建模型别名
     * @param modelId 模型ID
     * @param aliasName 别名
     */
    void createModelAlias(String modelId, String aliasName);
}

然后是接口实现:

java 复制代码
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.ml.PutTrainedModelRequest;
import co.elastic.clients.elasticsearch.ml.TrainedModelConfig;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Base64;
import java.util.Objects;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

public class ModelUploadService implements ModelUploader {
    private static final Logger logger = LoggerFactory.getLogger(ModelUploadService.class);
    private static final int CHUNK_SIZE = 1024 * 1024;  // 1MB
    private static final int MAX_WAIT_SECONDS = 120;    // 等待模型加载的最大时间

    // 并发控制锁
    private final ReadWriteLock modelUpdateLock = new ReentrantReadWriteLock();

    /**
     * 上传ONNX模型到Elasticsearch
     */
    @Override
    public void uploadModel(String modelId, Path modelPath, String description, String... tags) {
        // 参数验证
        Objects.requireNonNull(modelId, "模型ID不能为空");
        Objects.requireNonNull(modelPath, "模型路径不能为空");
        if (!Files.exists(modelPath)) {
            throw new IllegalArgumentException("模型文件不存在: " + modelPath);
        }

        // 获取写锁,确保模型更新时的独占访问
        modelUpdateLock.writeLock().lock();
        try (var client = ESClientUtil.createClient()) {
            // 检查模型文件大小
            long fileSize = Files.size(modelPath);
            logger.info("开始上传模型 {}, 文件大小: {} 字节", modelId, fileSize);

            // 根据大小选择上传方式
            if (fileSize > 10 * 1024 * 1024) {  // 大于10MB使用分块上传
                uploadLargeModel(client, modelId, modelPath, description, tags);
            } else {
                uploadSmallModel(client, modelId, modelPath, description, tags);
            }

            // 验证模型上传状态
            verifyModelStatus(client, modelId);

            logger.info("模型 {} 上传成功", modelId);
        } catch (ElasticsearchException e) {
            logger.error("ES操作异常: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new ModelOperationException("上传模型失败: " + e.getMessage(), modelId, e.status(), e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new ModelOperationException("读取或上传模型失败", modelId, 500, e);
        } finally {
            modelUpdateLock.writeLock().unlock();
        }
    }

    /**
     * 上传小型模型(一次性上传)
     */
    private void uploadSmallModel(
            ElasticsearchClient client,
            String modelId,
            Path modelPath,
            String description,
            String[] tags) throws IOException {

        // 读取模型文件并编码
        String modelBase64;
        try (InputStream is = Files.newInputStream(modelPath)) {
            byte[] modelBytes = is.readAllBytes();
            modelBase64 = Base64.getEncoder().encodeToString(modelBytes);
        }

        try {
            // 创建模型配置并上传
            client.ml().putTrainedModel(PutTrainedModelRequest.of(builder ->
                builder
                    .modelId(modelId)
                    .inferenceConfig(ic -> ic.onnx(onnx -> onnx))
                    .modelType("onnx")
                    .tags(Arrays.asList(tags))
                    .description(description)
                    .definition(def -> def.modelBytes(modelBase64))
            ));
        } catch (ElasticsearchException e) {
            if (e.status() == 413) {  // 请求实体太大
                logger.warn("模型过大,尝试分块上传");
                uploadLargeModel(client, modelId, modelPath, description, tags);
            } else {
                throw e;
            }
        }
    }

    /**
     * 上传大型模型(分块上传)
     */
    private void uploadLargeModel(
            ElasticsearchClient client,
            String modelId,
            Path modelPath,
            String description,
            String[] tags) throws IOException {

        // 创建模型配置(不包含模型内容)
        createModelConfig(client, modelId, description, tags);

        // 分块上传模型内容
        uploadModelChunks(client, modelId, modelPath);
    }

    /**
     * 创建模型配置
     */
    private void createModelConfig(
            ElasticsearchClient client,
            String modelId,
            String description,
            String[] tags) throws IOException {

        client.ml().putTrainedModel(builder ->
            builder
                .modelId(modelId)
                .inferenceConfig(ic -> ic.onnx(onnx -> onnx))
                .modelType("onnx")
                .tags(Arrays.asList(tags))
                .description(description)
        );

        logger.info("创建模型 {} 配置成功", modelId);
    }

    /**
     * 分块上传模型内容
     */
    private void uploadModelChunks(
            ElasticsearchClient client,
            String modelId,
            Path modelPath) throws IOException {

        // 获取文件大小
        long fileSize = Files.size(modelPath);
        int totalParts = (int) Math.ceil((double) fileSize / CHUNK_SIZE);

        try (InputStream is = Files.newInputStream(modelPath)) {
            byte[] buffer = new byte[CHUNK_SIZE];
            int partNum = 0;
            int bytesRead;
            long totalBytesRead = 0;

            // 读取并上传每个分块
            while ((bytesRead = is.read(buffer)) != -1) {
                // 如果读取的字节数小于缓冲区大小,创建一个刚好大小的新数组
                byte[] chunk = bytesRead == buffer.length ? buffer : Arrays.copyOf(buffer, bytesRead);
                String base64Chunk = Base64.getEncoder().encodeToString(chunk);

                ESClientUtil.executeWithRetry(() -> {
                    client.ml().putTrainedModelDefinitionPart(d -> d
                        .modelId(modelId)
                        .part(partNum)
                        .definitionLength(fileSize)
                        .totalParts(totalParts)
                        .definition(base64Chunk)
                    );
                    return null;
                });

                totalBytesRead += bytesRead;
                int progressPercent = (int)((totalBytesRead * 100) / fileSize);
                logger.info("模型 {} 上传进度: {}/{} 块 ({}%)",
                        modelId, partNum + 1, totalParts, progressPercent);

                partNum++;
            }
        }
    }

    /**
     * 验证模型状态
     */
    private void verifyModelStatus(ElasticsearchClient client, String modelId) {
        try {
            // 等待模型加载完成
            int attempts = 0;
            boolean modelReady = false;
            int maxAttempts = MAX_WAIT_SECONDS; // 最多等待2分钟

            while (attempts < maxAttempts && !modelReady) {
                var response = client.ml().getTrainedModels(m -> m.modelId(modelId));
                if (response.trainedModelConfigs().isEmpty()) {
                    throw new ModelOperationException("模型未找到", modelId, 404);
                }

                var modelInfo = response.trainedModelConfigs().get(0);
                String modelState = modelInfo.modelState();

                if ("started".equals(modelState)) {
                    modelReady = true;
                    logger.info("模型 {} 已加载就绪", modelId);
                } else if ("starting".equals(modelState)) {
                    logger.info("模型 {} 当前状态: 启动中, 等待加载... (尝试 {}/{})",
                            modelId, attempts + 1, maxAttempts);
                    Thread.sleep(1000);  // 等待1秒再检查
                    attempts++;
                } else if ("failed".equals(modelState)) {
                    throw new ModelOperationException("模型加载失败: " + modelInfo.failure_reason(),
                            modelId, 500);
                } else {
                    logger.info("模型 {} 当前状态: {}, 等待加载... (尝试 {}/{})",
                            modelId, modelState, attempts + 1, maxAttempts);
                    Thread.sleep(1000);
                    attempts++;
                }
            }

            if (!modelReady) {
                logger.warn("模型 {} 未能在{}秒内加载完成", modelId, MAX_WAIT_SECONDS);
                throw new ModelOperationException(
                        "模型加载超时,请稍后检查状态", modelId, 408);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.warn("等待模型加载过程被中断", e);
            throw new ModelOperationException("模型验证被中断", modelId, 500, e);
        } catch (IOException e) {
            logger.warn("验证模型状态失败: {}", e.getMessage(), e);
            throw new ModelOperationException("验证模型状态失败", modelId, 500, e);
        } catch (ModelOperationException e) {
            throw e;
        } catch (Exception e) {
            logger.warn("验证模型状态失败: {}", e.getMessage(), e);
            throw new ModelOperationException("验证模型状态失败", modelId, 500, e);
        }
    }

    /**
     * 检查模型是否存在
     */
    @Override
    public boolean modelExists(String modelId) {
        Objects.requireNonNull(modelId, "模型ID不能为空");

        // 获取读锁,允许并发读取
        modelUpdateLock.readLock().lock();
        try (var client = ESClientUtil.createClient()) {
            var response = client.ml().getTrainedModels(m -> m.modelId(modelId));
            return !response.trainedModelConfigs().isEmpty();
        } catch (ElasticsearchException e) {
            if (e.status() == 404) {
                return false;
            }
            logger.error("检查模型存在性失败: {}", e.getMessage(), e);
            throw new ModelOperationException("检查模型失败", modelId, e.status(), e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new ModelOperationException("检查模型IO异常", modelId, 500, e);
        } finally {
            modelUpdateLock.readLock().unlock();
        }
    }

    /**
     * 创建模型别名
     */
    @Override
    public void createModelAlias(String modelId, String aliasName) {
        Objects.requireNonNull(modelId, "模型ID不能为空");
        Objects.requireNonNull(aliasName, "别名不能为空");

        modelUpdateLock.writeLock().lock();
        try (var client = ESClientUtil.createClient()) {
            client.ml().putTrainedModelAlias(a -> a
                .modelId(modelId)
                .modelAlias(aliasName)
                .reassign(true)
            );
            logger.info("为模型 {} 创建别名 {}", modelId, aliasName);
        } catch (ElasticsearchException e) {
            logger.error("创建模型别名失败: {}", e.getMessage(), e);
            throw new ModelOperationException("创建模型别名失败", modelId, e.status(), e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new ModelOperationException("创建模型别名IO异常", modelId, 500, e);
        } finally {
            modelUpdateLock.writeLock().unlock();
        }
    }
}

4.4 推理管道接口与实现

同样,先定义管道处理的接口:

java 复制代码
/**
 * 推理管道服务接口
 */
public interface InferencePipelineManager {
    /**
     * 创建推理管道
     * @param pipelineId 管道ID
     * @param modelId 模型ID
     * @param sourceField 源字段名
     * @param targetField 目标字段名
     * @param description 管道描述
     */
    void createInferencePipeline(String pipelineId, String modelId,
                                String sourceField, String targetField,
                                String description);

    /**
     * 确保管道存在,不存在则创建
     */
    void ensurePipelineExists(String pipelineId, String modelId,
                             String sourceField, String targetField,
                             String description);

    /**
     * 更新现有管道使用新模型
     */
    void updatePipelineModel(String pipelineId, String newModelId);
}

接口实现:

java 复制代码
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.ingest.Processor;
import co.elastic.clients.elasticsearch.ingest.ProcessorsBuilder;
import co.elastic.clients.elasticsearch.ingest.PutPipelineRequest;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Scheduled;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

public class InferencePipelineService implements InferencePipelineManager {
    private static final Logger logger = LoggerFactory.getLogger(InferencePipelineService.class);

    // 管道缓存,避免重复创建
    private final ConcurrentHashMap<String, Long> pipelineCache = new ConcurrentHashMap<>();
    private final ReadWriteLock pipelineLock = new ReentrantReadWriteLock();

    // 缓存过期时间
    private static final long CACHE_TTL_MS = 3600000; // 1小时

    // 内存控制参数
    private static final int DEFAULT_INFERENCE_THREADS = 1;
    private static final int DEFAULT_NUM_TOP_CLASSES = 2;

    /**
     * 创建推理管道
     */
    @Override
    public void createInferencePipeline(
            String pipelineId,
            String modelId,
            String sourceField,
            String targetField,
            String description) {

        // 参数验证
        Objects.requireNonNull(pipelineId, "管道ID不能为空");
        Objects.requireNonNull(modelId, "模型ID不能为空");
        Objects.requireNonNull(sourceField, "源字段不能为空");
        Objects.requireNonNull(targetField, "目标字段不能为空");

        pipelineLock.writeLock().lock();
        try (var client = ESClientUtil.createClient()) {
            // 创建推理处理器
            Processor inferenceProcessor = new ProcessorsBuilder()
                .inference(ip -> ip
                    .modelId(modelId)
                    .targetField(targetField)
                    .fieldMap(Map.of(sourceField, "input_text"))
                    .inferenceConfig(ic -> ic
                        .classification(c -> c
                            .numTopClasses(DEFAULT_NUM_TOP_CLASSES)
                        )
                    )
                    .numberOfInferenceThreads(DEFAULT_INFERENCE_THREADS)
                )
                .build();

            // 创建包含推理处理器的管道
            client.ingest().putPipeline(PutPipelineRequest.of(builder ->
                builder
                    .id(pipelineId)
                    .description(description != null ? description : "推理管道 " + pipelineId)
                    .processors(inferenceProcessor)
            ));

            logger.info("推理管道 {} 创建成功, 使用模型 {}", pipelineId, modelId);

            // 验证管道
            verifyPipeline(client, pipelineId);

            // 更新缓存
            pipelineCache.put(pipelineId, System.currentTimeMillis());
        } catch (ElasticsearchException e) {
            logger.error("创建推理管道失败: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new InferenceException("创建推理管道失败: " + e.getMessage(), pipelineId, e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new InferenceException("创建推理管道IO异常", pipelineId, e);
        } finally {
            pipelineLock.writeLock().unlock();
        }
    }

    /**
     * 验证管道是否正常工作
     */
    private void verifyPipeline(ElasticsearchClient client, String pipelineId) {
        try {
            // 简单测试管道
            Map<String, Object> testDoc = Map.of(
                "text", "这是一个测试文本,用于验证推理管道是否正常工作。"
            );

            var response = client.ingest().simulate(s -> s
                .pipeline(p -> p.id(pipelineId))
                .docs(d -> d.doc(testDoc))
            );

            if (response.docs().isEmpty()) {
                throw new InferenceException("管道测试返回空结果", pipelineId);
            }

            var processedDoc = response.docs().get(0).doc().source();
            logger.info("管道 {} 测试成功", pipelineId);
        } catch (Exception e) {
            logger.warn("管道验证失败: {}", e.getMessage(), e);
            throw new InferenceException("管道验证失败", pipelineId, e);
        }
    }

    /**
     * 检查管道是否存在,不存在则创建
     * 用于确保推理前管道可用
     */
    @Override
    public void ensurePipelineExists(
            String pipelineId,
            String modelId,
            String sourceField,
            String targetField,
            String description) {

        // 先检查缓存
        if (pipelineCache.containsKey(pipelineId)) {
            return;
        }

        // 获取读锁检查
        pipelineLock.readLock().lock();
        try (var client = ESClientUtil.createClient()) {
            var response = client.ingest().getPipeline(g -> g.id(pipelineId));
            if (!response.result().isEmpty()) {
                // 管道已存在,更新缓存
                pipelineCache.put(pipelineId, System.currentTimeMillis());
                return;
            }
        } catch (Exception e) {
            // 忽略检查错误,尝试创建
            logger.debug("检查管道存在性出错,尝试创建: {}", e.getMessage());
        } finally {
            pipelineLock.readLock().unlock();
        }

        // 管道不存在,创建新管道
        createInferencePipeline(pipelineId, modelId, sourceField, targetField, description);
    }

    /**
     * 更新现有管道使用新模型
     */
    @Override
    public void updatePipelineModel(String pipelineId, String newModelId) {
        Objects.requireNonNull(pipelineId, "管道ID不能为空");
        Objects.requireNonNull(newModelId, "新模型ID不能为空");

        pipelineLock.writeLock().lock();
        try (var client = ESClientUtil.createClient()) {
            // 获取现有管道
            var response = client.ingest().getPipeline(g -> g.id(pipelineId));
            if (response.result().isEmpty()) {
                throw new InferenceException("管道不存在,无法更新", pipelineId);
            }

            // 获取当前管道配置
            var pipeline = response.result().get(pipelineId);
            var description = pipeline.description();
            var processors = pipeline.processors();

            // 查找并替换推理处理器中的模型ID
            boolean modelUpdated = false;
            for (var processor : processors) {
                if (processor.inference() != null) {
                    processor.inference().modelId(newModelId);
                    modelUpdated = true;
                }
            }

            if (!modelUpdated) {
                throw new InferenceException("管道中未找到推理处理器", pipelineId);
            }

            // 更新管道
            client.ingest().putPipeline(p -> p
                .id(pipelineId)
                .description(description)
                .processors(processors)
            );

            logger.info("管道 {} 已更新使用新模型 {}", pipelineId, newModelId);

            // 验证更新后的管道
            verifyPipeline(client, pipelineId);

            // 更新缓存
            pipelineCache.put(pipelineId, System.currentTimeMillis());
        } catch (ElasticsearchException e) {
            logger.error("更新管道模型失败: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new InferenceException("更新管道模型失败", pipelineId, e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new InferenceException("更新管道IO异常", pipelineId, e);
        } finally {
            pipelineLock.writeLock().unlock();
        }
    }

    /**
     * 定期清理过期缓存
     */
    @Scheduled(fixedRate = 3600000) // 每小时执行一次
    public void cleanupCache() {
        long now = System.currentTimeMillis();
        int removedCount = 0;

        for (Map.Entry<String, Long> entry : pipelineCache.entrySet()) {
            if (now - entry.getValue() > CACHE_TTL_MS) {
                pipelineCache.remove(entry.getKey());
                removedCount++;
            }
        }

        if (removedCount > 0) {
            logger.info("缓存清理完成,移除了{}个过期管道记录", removedCount);
        }
    }
}

4.5 系统初始化与配置

添加一个系统初始化组件,确保系统启动时完成必要的设置:

java 复制代码
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;

@Component
public class ESMLSystemInitializer {
    private static final Logger logger = LoggerFactory.getLogger(ESMLSystemInitializer.class);

    private final IndexManager indexManager;
    private final ElasticsearchProperties esProperties;

    @Autowired
    public ESMLSystemInitializer(
            IndexManager indexManager,
            ElasticsearchProperties esProperties) {
        this.indexManager = indexManager;
        this.esProperties = esProperties;
    }

    @PostConstruct
    public void initialize() {
        logger.info("初始化ES ML系统...");

        // 异步初始化,避免阻塞应用启动
        CompletableFuture.runAsync(() -> {
            try {
                // 1. 验证集群连接
                if (!ESClientUtil.isClusterHealthy()) {
                    logger.warn("Elasticsearch集群状态不健康,初始化可能不完整");
                }

                // 2. 创建索引模板
                try (var client = ESClientUtil.createClient()) {
                    indexManager.createInferenceIndexTemplate(client);

                    // 3. 创建必要的索引
                    for (String indexName : esProperties.getRequiredIndices()) {
                        indexManager.createInferenceIndex(client, indexName);
                    }
                }

                logger.info("ES ML系统初始化完成");
            } catch (IOException e) {
                logger.error("初始化ES ML系统失败: {}", e.getMessage(), e);
            }
        });
    }
}

5. 推理服务设计与实现

接下来,我们设计更具体的推理接口和实现:

java 复制代码
/**
 * 单文本推理接口
 */
public interface SingleInference {
    /**
     * 执行单文本推理
     * @param indexName 索引名称
     * @param pipelineId 管道ID
     * @param modelId 模型ID
     * @param text 输入文本
     * @return 推理结果
     */
    Map<String, Object> inferSingle(String indexName, String pipelineId,
                                   String modelId, String text);
}

/**
 * 批量推理接口
 */
public interface BatchInference {
    /**
     * 执行批量文本推理
     * @param indexName 索引名称
     * @param pipelineId 管道ID
     * @param modelId 模型ID
     * @param texts 输入文本列表
     * @return 推理结果列表
     */
    List<Map<String, Object>> inferBatch(String indexName, String pipelineId,
                                        String modelId, List<String> texts);
}

/**
 * 链式推理接口
 */
public interface ChainedInference {
    /**
     * 执行多模型链式推理
     * @param modelChain 模型链配置
     * @param text 输入文本
     * @return 推理结果
     */
    Map<String, Object> chainedInference(List<ModelConfig> modelChain, String text);

    /**
     * 模型配置类
     */
    class ModelConfig {
        private final String modelId;
        private final String pipelineId;
        private final String indexName;
        private final String sourceField;
        private final String targetField;

        public ModelConfig(String modelId, String pipelineId, String indexName,
                          String sourceField, String targetField) {
            this.modelId = modelId;
            this.pipelineId = pipelineId;
            this.indexName = indexName;
            this.sourceField = sourceField;
            this.targetField = targetField;
        }

        // Getters
        public String getModelId() { return modelId; }
        public String getPipelineId() { return pipelineId; }
        public String getIndexName() { return indexName; }
        public String getSourceField() { return sourceField; }
        public String getTargetField() { return targetField; }
    }
}

/**
 * 完整推理服务接口
 */
public interface InferenceService extends SingleInference, BatchInference, ChainedInference {
    // 组合上述三个接口
}

实现上述接口的推理服务:

java 复制代码
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.core.*;
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

@Service
public class ESInferenceService implements InferenceService {
    private static final Logger logger = LoggerFactory.getLogger(ESInferenceService.class);

    private final ModelUploader modelUploader;
    private final InferencePipelineManager pipelineManager;

    // 批处理控制参数
    private static final int MAX_BATCH_SIZE = 100;
    private static final int DEFAULT_BATCH_SIZE = 20;

    @Autowired
    public ESInferenceService(ModelUploader modelUploader,
                             InferencePipelineManager pipelineManager) {
        this.modelUploader = modelUploader;
        this.pipelineManager = pipelineManager;
    }

    /**
     * 单条文本推理
     */
    @Override
    public Map<String, Object> inferSingle(
            String indexName,
            String pipelineId,
            String modelId,
            String text) {

        // 参数验证
        Objects.requireNonNull(indexName, "索引名称不能为空");
        Objects.requireNonNull(pipelineId, "管道ID不能为空");
        Objects.requireNonNull(modelId, "模型ID不能为空");
        Objects.requireNonNull(text, "输入文本不能为空");

        // 确保管道存在
        pipelineManager.ensurePipelineExists(
                pipelineId, modelId, "text", "prediction", "推理管道");

        try (var client = ESClientUtil.createClient()) {
            // 准备文档
            String documentId = generateDocumentId();
            Map<String, Object> document = new HashMap<>();
            document.put("text", text);
            document.put("timestamp", System.currentTimeMillis());

            // 记录处理开始时间
            long startTime = System.currentTimeMillis();

            // 使用推理管道索引文档
            IndexResponse response = ESClientUtil.executeWithRetry(() ->
                client.index(IndexRequest.of(builder ->
                    builder
                        .index(indexName)
                        .id(documentId)
                        .pipeline(pipelineId)
                        .document(document)
                ))
            );

            if (!response.result().name().contains("CREATED") &&
                !response.result().name().contains("UPDATED")) {
                throw new InferenceException(
                        "索引文档失败: " + response.result().name(), pipelineId);
            }

            // 查询推理结果
            GetResponse<Map> getResponse = ESClientUtil.executeWithRetry(() ->
                client.get(g -> g
                        .index(indexName)
                        .id(documentId),
                        Map.class)
            );

            if (!getResponse.found()) {
                throw new InferenceException("推理后文档未找到", pipelineId);
            }

            // 计算处理时间
            long processingTime = System.currentTimeMillis() - startTime;
            Map<String, Object> result = getResponse.source();
            result.put("processing_time_ms", processingTime);

            logger.debug("单文本推理完成,处理时间: {}ms", processingTime);
            return result;
        } catch (ElasticsearchException e) {
            logger.error("ES推理异常: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new InferenceException("推理失败: " + e.getMessage(), pipelineId, e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new InferenceException("推理IO异常", pipelineId, e);
        }
    }

    /**
     * 批量文本推理
     */
    @Override
    public List<Map<String, Object>> inferBatch(
            String indexName,
            String pipelineId,
            String modelId,
            List<String> texts) {

        // 检查输入参数
        Objects.requireNonNull(indexName, "索引名称不能为空");
        Objects.requireNonNull(pipelineId, "管道ID不能为空");
        Objects.requireNonNull(modelId, "模型ID不能为空");

        if (texts == null || texts.isEmpty()) {
            return Collections.emptyList();
        }

        // 确保管道存在
        pipelineManager.ensurePipelineExists(
                pipelineId, modelId, "text", "prediction", "推理管道");

        // 大批量数据分批处理
        if (texts.size() > MAX_BATCH_SIZE) {
            return processBatchesInChunks(indexName, pipelineId, texts);
        }

        try (var client = ESClientUtil.createClient()) {
            // 准备批量请求
            List<String> documentIds = new ArrayList<>(texts.size());
            var bulkRequest = client.bulk().builder();

            // 记录处理开始时间
            long startTime = System.currentTimeMillis();

            // 添加所有文档到批量请求
            for (String text : texts) {
                String documentId = generateDocumentId();
                documentIds.add(documentId);

                Map<String, Object> document = new HashMap<>();
                document.put("text", text);
                document.put("timestamp", System.currentTimeMillis());

                bulkRequest.operations(op -> op
                    .index(idx -> idx
                        .index(indexName)
                        .id(documentId)
                        .document(document)
                    )
                );
            }

            // 执行批量请求,指定推理管道
            BulkResponse bulkResponse = ESClientUtil.executeWithRetry(() ->
                bulkRequest.pipeline(pipelineId).build().send()
            );

            // 检查批量操作结果
            if (bulkResponse.errors()) {
                logger.warn("批量推理部分失败");
                for (BulkResponseItem item : bulkResponse.items()) {
                    if (item.error() != null) {
                        logger.error("文档 {} 处理失败: {}",
                                item.id(), item.error().reason());
                    }
                }
            }

            // 批量获取结果文档
            MgetResponse<Map> response = ESClientUtil.executeWithRetry(() ->
                client.mget(m -> m
                    .index(indexName)
                    .ids(documentIds),
                    Map.class)
            );

            // 计算总处理时间
            long totalProcessingTime = System.currentTimeMillis() - startTime;
            long avgProcessingTime = totalProcessingTime / texts.size();

            // 提取结果
            List<Map<String, Object>> results = new ArrayList<>(texts.size());
            for (var doc : response.docs()) {
                if (doc.found()) {
                    Map<String, Object> result = doc.source();
                    result.put("processing_time_ms", avgProcessingTime);
                    results.add(result);
                } else {
                    Map<String, Object> errorDoc = new HashMap<>();
                    errorDoc.put("error", "推理后文档未找到");
                    errorDoc.put("id", doc.id());
                    results.add(errorDoc);
                }
            }

            logger.debug("批量推理完成,{}条文本,总耗时: {}ms,平均: {}ms/条",
                    texts.size(), totalProcessingTime, avgProcessingTime);
            return results;
        } catch (ElasticsearchException e) {
            logger.error("ES批量推理异常: {}, 状态码: {}", e.getMessage(), e.status(), e);
            throw new InferenceException("批量推理失败: " + e.getMessage(), pipelineId, e);
        } catch (IOException e) {
            logger.error("IO异常: {}", e.getMessage(), e);
            throw new InferenceException("批量推理IO异常", pipelineId, e);
        }
    }

    /**
     * 大批量数据分块处理
     */
    private List<Map<String, Object>> processBatchesInChunks(
            String indexName, String pipelineId, List<String> texts) {

        List<Map<String, Object>> allResults = new ArrayList<>(texts.size());
        List<List<String>> batches = new ArrayList<>();

        // 分割成小批次
        for (int i = 0; i < texts.size(); i += DEFAULT_BATCH_SIZE) {
            int end = Math.min(i + DEFAULT_BATCH_SIZE, texts.size());
            batches.add(texts.subList(i, end));
        }

        // 处理每个批次
        AtomicInteger processedCount = new AtomicInteger(0);
        for (List<String> batch : batches) {
            try {
                List<Map<String, Object>> batchResults =
                    inferBatchInternal(indexName, pipelineId, batch);
                allResults.addAll(batchResults);

                int completed = processedCount.addAndGet(batch.size());
                logger.info("批量推理进度: {}/{} ({}%)",
                        completed, texts.size(), (completed * 100 / texts.size()));
            } catch (Exception e) {
                logger.error("处理批次失败: {}", e.getMessage(), e);
                // 添加错误结果
                for (int i = 0; i < batch.size(); i++) {
                    Map<String, Object> errorDoc = new HashMap<>();
                    errorDoc.put("error", "批次处理失败: " + e.getMessage());
                    errorDoc.put("text", batch.get(i));
                    allResults.add(errorDoc);
                }
            }
        }

        return allResults;
    }

    /**
     * 内部批量推理实现
     * 处理单个批次,确保批次大小合理
     */
    private List<Map<String, Object>> inferBatchInternal(
            String indexName, String pipelineId, List<String> batch) {

        if (batch.size() > MAX_BATCH_SIZE) {
            throw new IllegalArgumentException(
                    "批次大小超过限制: " + batch.size() + " > " + MAX_BATCH_SIZE);
        }

        try (var client = ESClientUtil.createClient()) {
            // 准备批量请求
            List<String> documentIds = new ArrayList<>(batch.size());
            var bulkRequest = client.bulk().builder();

            // 添加所有文档到批量请求
            for (String text : batch) {
                String documentId = generateDocumentId();
                documentIds.add(documentId);

                Map<String, Object> document = new HashMap<>();
                document.put("text", text);
                document.put("timestamp", System.currentTimeMillis());

                bulkRequest.operations(op -> op
                    .index(idx -> idx
                        .index(indexName)
                        .id(documentId)
                        .document(document)
                    )
                );
            }

            // 执行批量请求,指定推理管道
            BulkResponse bulkResponse = ESClientUtil.executeWithRetry(() ->
                bulkRequest.pipeline(pipelineId).build().send()
            );

            // 检查批量操作结果
            checkBulkResponse(bulkResponse);

            // 批量获取结果文档
            MgetResponse<Map> response = ESClientUtil.executeWithRetry(() ->
                client.mget(m -> m
                    .index(indexName)
                    .ids(documentIds),
                    Map.class)
            );

            // 提取结果
            List<Map<String, Object>> results = new ArrayList<>(batch.size());
            for (var doc : response.docs()) {
                if (doc.found()) {
                    results.add(doc.source());
                } else {
                    Map<String, Object> errorDoc = new HashMap<>();
                    errorDoc.put("error", "推理后文档未找到");
                    errorDoc.put("id", doc.id());
                    results.add(errorDoc);
                }
            }

            return results;
        } catch (Exception e) {
            logger.error("批次处理异常: {}", e.getMessage(), e);
            throw new InferenceException("批次处理失败", pipelineId, e);
        }
    }

    /**
     * 检查批量响应错误
     */
    private void checkBulkResponse(BulkResponse response) {
        if (response.errors()) {
            StringBuilder errorMsg = new StringBuilder("批量操作部分失败: ");
            for (BulkResponseItem item : response.items()) {
                if (item.error() != null) {
                    errorMsg.append(item.id())
                           .append("(")
                           .append(item.error().reason())
                           .append("), ");
                }
            }
            logger.warn(errorMsg.toString());
        }
    }

    /**
     * 多模型链式推理
     */
    @Override
    public Map<String, Object> chainedInference(
            List<ModelConfig> modelChain, String text) {

        // 参数验证
        Objects.requireNonNull(modelChain, "模型链不能为空");
        if (modelChain.isEmpty()) {
            throw new IllegalArgumentException("模型链不能为空");
        }
        Objects.requireNonNull(text, "输入文本不能为空");

        Map<String, Object> document = new HashMap<>();
        document.put("original_text", text);
        document.put("text", text);
        document.put("timestamp", System.currentTimeMillis());

        // 按顺序执行每个模型
        for (int i = 0; i < modelChain.size(); i++) {
            ModelConfig config = modelChain.get(i);
            String stageName = "stage_" + (i + 1);

            try {
                // 确保管道存在
                String pipelineId = config.getPipelineId();
                pipelineManager.ensurePipelineExists(
                        pipelineId, config.getModelId(),
                        config.getSourceField(), config.getTargetField(),
                        "链式推理管道" + stageName);

                // 执行当前阶段推理
                document = runPipelineStage(
                        config.getIndexName(), pipelineId, document, stageName);

                // 如果有下一个模型,将当前结果作为下一阶段输入
                if (i < modelChain.size() - 1) {
                    ModelConfig nextConfig = modelChain.get(i + 1);
                    // 从当前输出中提取数据作为下一阶段输入
                    Object nextInput = extractFieldFromPath(
                            document, config.getTargetField());
                    // 转换为字符串或保持其结构,取决于下一个模型需要
                    document.put(nextConfig.getSourceField(), nextInput);
                }
            } catch (Exception e) {
                logger.error("链式推理阶段 {} 失败: {}", stageName, e.getMessage(), e);
                document.put("error_" + stageName, e.getMessage());
                // 链式失败,中断后续处理
                break;
            }
        }

        return document;
    }

    /**
     * 执行单个管道阶段
     */
    private Map<String, Object> runPipelineStage(
            String indexName, String pipelineId,
            Map<String, Object> document, String stageName) throws IOException {

        try (var client = ESClientUtil.createClient()) {
            String documentId = generateDocumentId();

            // 索引文档,应用管道
            IndexResponse response = ESClientUtil.executeWithRetry(() ->
                client.index(IndexRequest.of(builder ->
                    builder
                        .index(indexName)
                        .id(documentId)
                        .pipeline(pipelineId)
                        .document(document)
                ))
            );

            // 获取处理后的文档
            GetResponse<Map> getResponse = ESClientUtil.executeWithRetry(() ->
                client.get(g -> g
                        .index(indexName)
                        .id(documentId),
                        Map.class)
            );

            if (!getResponse.found()) {
                throw new InferenceException("阶段 " + stageName + " 处理后文档未找到", pipelineId);
            }

            // 保留处理阶段标记
            Map<String, Object> result = getResponse.source();
            result.put("_stage", stageName);

            return result;
        }
    }

    /**
     * 从嵌套字段路径中提取数据
     */
    private Object extractFieldFromPath(Map<String, Object> document, String fieldPath) {
        if (fieldPath == null || fieldPath.isEmpty()) {
            return null;
        }

        String[] parts = fieldPath.split("\\.");
        Object current = document;

        for (String part : parts) {
            if (current instanceof Map) {
                current = ((Map<?, ?>) current).get(part);
                if (current == null) {
                    return null;
                }
            } else {
                return null;
            }
        }

        return current;
    }

    /**
     * 生成唯一文档ID
     */
    private String generateDocumentId() {
        return UUID.randomUUID().toString();
    }
}

这部分代码提供了以下内容:

  1. 基本的异常类型定义,确保错误处理的明确性
  2. ES 客户端工具类,带有连接池、重试机制和 SSL 配置
  3. 模型上传服务,支持大文件流式处理
  4. 推理管道管理,包括缓存和过期清理
  5. 系统初始化组件,确保启动时完成必要设置
  6. 多接口推理服务设计,分离单文本、批量和链式推理

在下一部分,我们将继续深入探讨高级功能,包括:

  • 熔断器与故障降级策略
  • 模型优化技术
  • 性能监控
  • Spring Boot 集成
  • 实际应用案例
相关推荐
黎䪽圓18 分钟前
【Java多线程从青铜到王者】阻塞队列(十)
java·开发语言
Roam-G32 分钟前
List ToMap优化优化再优化到极致
java·list
"匠"人37 分钟前
讲解负载均衡
java·运维·负载均衡
C雨后彩虹43 分钟前
行为模式-责任链模式
java·设计模式·责任链模式
写bug写bug1 小时前
搞懂Spring Cloud Config配置信息自动更新原理
java·后端·架构
异常君1 小时前
Dubbo 高可用性核心机制详解与实战(上)
java·dubbo·设计
trow1 小时前
Web 认证技术的演进:解决无状态与有状态需求的矛盾
java
快乐肚皮1 小时前
快速排序优化技巧详解:提升性能的关键策略
java·算法·性能优化·排序算法
摘取一颗天上星️1 小时前
外部记忆的组织艺术:集合、树、栈与队列的深度解析
深度学习·机器学习·外部记忆