在大数据时代,搜索引擎与机器学习的融合已成为数据处理领域的重要技术方向。Elasticsearch 不仅是高性能的搜索与分析引擎,还提供了完整的机器学习推理框架,使我们能在分布式环境中高效部署和运行 ML 模型。
1. Elasticsearch 机器学习功能概述
Elasticsearch 的机器学习能力随版本不断发展,从早期的异常检测功能,到如今成熟的模型推理框架,为数据分析提供了强大支持。
核心功能包括:
- 异常检测:自动识别数据中的异常模式和离群值
- 预测分析:基于历史数据进行时序预测和趋势分析
- 模型推理:在 ES 中部署外部训练的机器学习模型并执行实时或批量推理
ES 机器学习框架的主要优势在于能够利用其分布式架构实现高可用、高扩展的模型部署,无需额外维护专门的 ML 服务基础设施。
版本功能对照
| 特性 | ES 7.x | ES 8.0-8.3 | ES 8.4+ | ES 8.8+ |
| --- | --- | --- | --- | --- |
| 异常检测 | ✓ | ✓ | ✓ | ✓ |
| 时序预测 | 有限支持 | ✓ | ✓ | ✓ |
| ONNX 模型支持 | ✗ | 部分支持 | 完整支持 | 完整支持 |
| PyTorch 模型原生支持 | ✗ | ✗ | ✗ | ✓ |
| 分布式推理 | ✗ | 有限支持 | 完整支持 | 增强支持 |
| NLP 预训练模型库 | ✗ | 有限支持 | 丰富支持 | 全面支持 |
| 推理内存限制控制 | ✗ | ✓ | 增强控制 | 精细控制 |
| 模型部署隔离性 | ✗ | 有限支持 | ✓ | 增强支持 |
2. 推理功能原理与支持的模型类型

2.1 支持的模型类型
Elasticsearch 当前支持以下类型的机器学习模型:
- PyTorch 模型(通过 ONNX 格式转换或 8.8+版本原生支持)
- scikit-learn 模型(支持分类器、回归器等)
- XGBoost 模型(分类与回归)
- LightGBM 模型(分类与回归)
- 预训练的 NLP 模型(BERT、sentence-transformers 等)
2.2 ONNX 操作支持说明
对于 ONNX 模型,Elasticsearch 支持的操作集包括:
- 基础操作:Add, Sub, Mul, Div, Pow
- 神经网络层:Conv, MaxPool, BatchNormalization, Dropout
- 激活函数:Relu, Sigmoid, Tanh, LeakyRelu
- 数据操作:Reshape, Transpose, Concat, Split
- 序列处理:LSTM, GRU (从 8.4 版本开始完整支持)
在导入模型前应确认模型使用的操作在 ES 支持列表中。可以通过 ONNX Runtime 的检查工具验证模型兼容性:
java
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
public class ModelCompatibilityChecker {
public static boolean isModelCompatible(String modelPath) {
try {
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
// 加载模型进行验证
OrtSession session = env.createSession(modelPath, options);
// 检查模型输入输出
session.getInputNames().forEach(name ->
System.out.println("输入: " + name));
session.getOutputNames().forEach(name ->
System.out.println("输出: " + name));
session.close();
env.close();
return true;
} catch (Exception e) {
System.err.println("模型兼容性检查失败: " + e.getMessage());
return false;
}
}
}
2.3 推理处理流程
推理过程通过 Elasticsearch 的 Ingest Pipeline(摄入管道)实现,整体流程如下(下文给出最小调用示意):
1. 数据文档通过管道时,推理处理器被触发
2. 处理器从文档中提取特征数据
3. 特征数据传入模型执行推理
4. 推理结果写入文档的目标字段
5. 处理后的文档继续完成索引流程
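下面用 Java 客户端给出上述流程的最小调用示意。其中管道 my-inference-pipeline 与索引 inference-demo 均为示例名称,假设已按后文步骤创建:
java
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import java.util.Map;
public class PipelineFlowDemo {
    /**
     * 演示文档经推理管道完成索引并取回推理结果(名称均为示例假设)
     */
    static void demo(ElasticsearchClient client) throws Exception {
        Map<String, Object> doc = Map.of("text", "待分类的示例文本");
        // 索引时指定pipeline参数,推理处理器在文档写入前执行
        client.index(i -> i
            .index("inference-demo")
            .id("doc-1")
            .pipeline("my-inference-pipeline")
            .document(doc));
        // 按ID读取是实时的,可直接取回带推理结果的文档
        var resp = client.get(g -> g.index("inference-demo").id("doc-1"), Map.class);
        System.out.println(resp.source()); // 推理输出已写入目标字段(如 prediction)
    }
}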
3. 实现步骤与环境准备

3.1 环境准备
- Elasticsearch 8.x(建议 8.4+版本)
- Java 11+
- 适配 ES 版本的 Java 客户端
Maven 依赖配置:
xml
<dependencies>
<!-- Elasticsearch Java客户端 -->
<dependency>
<groupId>co.elastic.clients</groupId>
<artifactId>elasticsearch-java</artifactId>
<version>8.8.0</version>
</dependency>
<!-- JSON处理 -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.14.2</version>
</dependency>
<!-- 日志框架 -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.36</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.2.11</version>
</dependency>
<!-- ONNX Runtime (用于模型验证和优化) -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.13.1</version>
</dependency>
</dependencies>
3.2 日志配置示例
为确保生产环境中的日志管理规范,添加以下 logback.xml 配置:
xml
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
<appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
<file>logs/es-ml-app.log</file>
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
<fileNamePattern>logs/es-ml-app.%d{yyyy-MM-dd}.log</fileNamePattern>
<maxHistory>30</maxHistory>
</rollingPolicy>
<encoder>
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<logger name="co.elastic" level="INFO"/>
<logger name="org.elasticsearch" level="WARN"/>
<logger name="com.yourcompany.esml" level="DEBUG"/>
<root level="INFO">
<appender-ref ref="CONSOLE"/>
<appender-ref ref="FILE"/>
</root>
</configuration>
3.3 Elasticsearch 配置参数
高效运行 ML 推理需要合理配置以下关键参数:
yaml
# elasticsearch.yml配置示例
xpack.ml.max_model_memory_limit: 1gb # 单个模型最大内存
xpack.ml.max_machine_memory_percent: 30 # ML可使用的最大机器内存百分比
xpack.ml.max_inference_processors: 4 # 每节点最大推理处理器数量
thread_pool.write.queue_size: 200 # 写入请求(含ingest管道)队列大小
thread_pool.write.size: 8 # 写入线程池大小(ingest处理器在write线程池上执行,不存在独立的ingest线程池)
3.4 索引映射定义
为推理结果创建合适的索引映射,确保字段类型正确并支持高效查询:
java
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
public class IndexManager {
private static final Logger logger = LoggerFactory.getLogger(IndexManager.class);
/**
* 创建推理结果索引模板
*/
public void createInferenceIndexTemplate(ElasticsearchClient client) throws IOException {
try {
client.indices().putIndexTemplate(t -> t
    .name("ml-inference-template")
    .indexPatterns("inference-*", "ml-results-*")
    .template(it -> it
        .settings(s -> s
            .numberOfShards("3") // Java客户端中分片/副本数为字符串类型
            .numberOfReplicas("1")
            .refreshInterval(r -> r.time("5s"))
        )
        .mappings(m -> m
            // 注意:嵌套lambda参数不能与外层同名,这里用tp避免与外层的t冲突
            .properties("text", p -> p.text(tp -> tp
                .analyzer("standard")
                .fields("keyword", f -> f.keyword(k -> k))
            ))
            .properties("prediction", p -> p.object(o -> o
                .properties("score", sp -> sp.float_(f -> f))
                .properties("label", lp -> lp.keyword(k -> k))
            ))
            .properties("timestamp", p -> p.date(d -> d))
            .properties("processing_time_ms", p -> p.long_(l -> l))
        )
    )
);
logger.info("推理结果索引模板创建成功");
} catch (Exception e) {
logger.error("创建索引模板失败: {}", e.getMessage(), e);
throw e;
}
}
/**
* 创建特定的推理索引
*/
public boolean createInferenceIndex(ElasticsearchClient client, String indexName) {
try {
CreateIndexResponse response = client.indices().create(c -> c
.index(indexName)
.aliases(indexName + "_alias", a -> a)
);
boolean acknowledged = response.acknowledged();
logger.info("索引 {} 创建{}", indexName, acknowledged ? "成功" : "失败");
return acknowledged;
} catch (Exception e) {
    // 注意getMessage()可能为null,先做空值保护
    if (e.getMessage() != null && e.getMessage().contains("resource_already_exists_exception")) {
        logger.info("索引 {} 已存在", indexName);
        return true;
    }
logger.error("创建索引失败: {}", e.getMessage(), e);
return false;
}
}
}
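下面是 IndexManager 的一个简单使用示意(客户端来自 4.2 节的 ESClientUtil,索引名为示例):
java
public class IndexManagerDemo {
    public static void main(String[] args) throws Exception {
        var client = ESClientUtil.createClient();
        IndexManager manager = new IndexManager();
        // 先创建模板,再创建匹配 inference-* 模式的索引
        manager.createInferenceIndexTemplate(client);
        manager.createInferenceIndex(client, "inference-demo-001");
    }
}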
4. 代码实现:模型导入与推理
4.1 异常类型定义
首先,定义清晰的异常层次结构:
java
/**
* ES机器学习操作基础异常
*/
public class ESMLException extends RuntimeException {
public ESMLException(String message) {
super(message);
}
public ESMLException(String message, Throwable cause) {
super(message, cause);
}
}
/**
* 模型操作异常
*/
public class ModelOperationException extends ESMLException {
private final String modelId;
private final int statusCode;
public ModelOperationException(String message, String modelId, int statusCode, Throwable cause) {
super(message, cause);
this.modelId = modelId;
this.statusCode = statusCode;
}
public ModelOperationException(String message, String modelId, int statusCode) {
this(message, modelId, statusCode, null);
}
public String getModelId() {
return modelId;
}
public int getStatusCode() {
return statusCode;
}
}
/**
* 推理异常
*/
public class InferenceException extends ESMLException {
private final String pipelineId;
public InferenceException(String message, String pipelineId, Throwable cause) {
super(message, cause);
this.pipelineId = pipelineId;
}
public InferenceException(String message, String pipelineId) {
this(message, pipelineId, null);
}
public String getPipelineId() {
return pipelineId;
}
}
/**
* 索引操作异常
*/
public class IndexOperationException extends ESMLException {
private final String indexName;
public IndexOperationException(String message, String indexName, Throwable cause) {
super(message, cause);
this.indexName = indexName;
}
public String getIndexName() {
return indexName;
}
}
4.2 ES 客户端工具类
java
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
import co.elastic.clients.transport.ElasticsearchTransport;
import co.elastic.clients.transport.rest_client.RestClientTransport;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.ssl.SSLContexts;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.util.Objects;
public class ESClientUtil {
private static final Logger logger = LoggerFactory.getLogger(ESClientUtil.class);
// 从配置文件或环境变量读取连接信息
private static final String ES_HOST = System.getProperty("es.host", "localhost");
private static final int ES_PORT = Integer.parseInt(System.getProperty("es.port", "9200"));
private static final String ES_PROTOCOL = System.getProperty("es.protocol", "https");
private static final String ES_USERNAME = System.getProperty("es.username", "elastic");
private static final String ES_PASSWORD = System.getProperty("es.password", "changeme");
private static final String ES_CERT_PATH = System.getProperty("es.cert.path");
// 连接池配置
private static final int MAX_CONN_TOTAL = 100;
private static final int MAX_CONN_PER_ROUTE = 30;
private static final int CONNECTION_TIMEOUT_MS = 5000;
private static final int SOCKET_TIMEOUT_MS = 60000;
// 重试配置
private static final int MAX_RETRIES = 3;
private static final long RETRY_BACKOFF_MS = 1000;
/**
* 创建带连接池的ES客户端
* @return 配置好的Elasticsearch客户端
* @throws IOException 如果创建客户端失败
*/
public static ElasticsearchClient createClient() throws IOException {
Objects.requireNonNull(ES_HOST, "ES主机地址不能为空");
RestClientBuilder builder = RestClient.builder(
new HttpHost(ES_HOST, ES_PORT, ES_PROTOCOL));
// 配置连接池
builder.setHttpClientConfigCallback(httpClientBuilder -> {
httpClientBuilder.setMaxConnTotal(MAX_CONN_TOTAL);
httpClientBuilder.setMaxConnPerRoute(MAX_CONN_PER_ROUTE);
// 配置认证
if (ES_USERNAME != null && !ES_USERNAME.isEmpty()) {
final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(AuthScope.ANY,
new UsernamePasswordCredentials(ES_USERNAME, ES_PASSWORD));
httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
}
// 配置SSL证书
if (ES_CERT_PATH != null && !ES_CERT_PATH.isEmpty()) {
try {
SSLContext sslContext = buildSSLContext();
httpClientBuilder.setSSLContext(sslContext);
} catch (Exception e) {
logger.error("SSL证书配置失败", e);
}
}
return configureHttpClient(httpClientBuilder);
});
// 配置请求超时
builder.setRequestConfigCallback(requestConfigBuilder ->
requestConfigBuilder
.setConnectTimeout(CONNECTION_TIMEOUT_MS)
.setSocketTimeout(SOCKET_TIMEOUT_MS)
);
// 创建客户端
try {
ElasticsearchTransport transport = new RestClientTransport(
builder.build(), new JacksonJsonpMapper());
return new ElasticsearchClient(transport);
} catch (Exception e) {
logger.error("创建ES客户端失败", e);
throw new IOException("创建ES客户端失败: " + e.getMessage(), e);
}
}
/**
* 配置HTTP客户端
*/
private static HttpAsyncClientBuilder configureHttpClient(HttpAsyncClientBuilder httpClientBuilder) {
return httpClientBuilder
.setDefaultRequestConfig(
RequestConfig.custom()
.setConnectTimeout(CONNECTION_TIMEOUT_MS)
.setSocketTimeout(SOCKET_TIMEOUT_MS)
.build()
);
}
/**
* 构建SSL上下文
*/
private static SSLContext buildSSLContext() throws Exception {
Path certPath = Path.of(ES_CERT_PATH);
if (!Files.exists(certPath)) {
throw new IllegalArgumentException("证书文件不存在: " + certPath);
}
CertificateFactory factory = CertificateFactory.getInstance("X.509");
Certificate trustedCa;
try (InputStream is = Files.newInputStream(certPath)) {
trustedCa = factory.generateCertificate(is);
}
KeyStore trustStore = KeyStore.getInstance("pkcs12");
trustStore.load(null, null);
trustStore.setCertificateEntry("ca", trustedCa);
return SSLContexts.custom()
.loadTrustMaterial(trustStore, new TrustSelfSignedStrategy())
.build();
}
/**
* 关闭ES传输层
*/
public static void closeTransport(ElasticsearchTransport transport) {
if (transport != null) {
try {
transport.close();
} catch (IOException e) {
logger.warn("关闭ES传输层失败", e);
}
}
}
/**
 * 可抛出IOException的操作接口。
 * Supplier.get()不允许抛出受检异常,包装ES客户端调用需要自定义函数式接口
 */
@FunctionalInterface
public interface IOSupplier<T> {
    T get() throws IOException;
}
/**
 * 执行带重试的ES操作
 * @param operation 要执行的操作
 * @return 操作结果
 * @throws IOException 如果操作最终失败
 */
public static <T> T executeWithRetry(IOSupplier<T> operation) throws IOException {
int attempts = 0;
IOException lastException = null;
while (attempts < MAX_RETRIES) {
try {
return operation.get();
} catch (IOException e) {
if (isRetryableException(e)) {
lastException = e;
attempts++;
if (attempts < MAX_RETRIES) {
long backoffTime = RETRY_BACKOFF_MS * attempts;
logger.warn("ES操作失败,将在{}ms后重试(尝试{}/{}): {}",
backoffTime, attempts, MAX_RETRIES, e.getMessage());
try {
Thread.sleep(backoffTime);
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
throw new IOException("重试等待被中断", ie);
}
}
} else {
// 不可重试的异常直接抛出
logger.error("遇到不可重试的ES异常", e);
throw e;
}
}
}
logger.error("ES操作在{}次尝试后仍然失败", MAX_RETRIES);
throw lastException;
}
/**
* 判断异常是否可以重试
*/
private static boolean isRetryableException(IOException e) {
// 网络相关异常通常可以重试
if (e instanceof java.net.SocketTimeoutException ||
e instanceof java.net.ConnectException) {
return true;
}
// 特定HTTP状态码也可以重试
String message = e.getMessage();
if (message != null && (
message.contains("429") || // Too Many Requests
message.contains("503") || // Service Unavailable
message.contains("507") // Insufficient Storage
)) {
return true;
}
return false;
}
/**
* 检查ES集群健康状态
*/
public static boolean isClusterHealthy() {
    try (var client = createClient()) {
        var response = client.cluster().health();
        // HealthStatus枚举的jsonValue()返回小写状态串(green/yellow/red),
        // 直接toString()得到的是枚举常量名,无法与"green"匹配
        String status = response.status().jsonValue();
        return "green".equals(status) || "yellow".equals(status);
    } catch (Exception e) {
        logger.error("检查集群健康状态失败: {}", e.getMessage(), e);
        return false;
    }
}
}
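工具类的典型用法示意如下(打印集群名称仅作连通性演示):
java
public class ESClientUtilDemo {
    public static void main(String[] args) throws Exception {
        // 启动前先确认集群可用
        if (!ESClientUtil.isClusterHealthy()) {
            System.err.println("Elasticsearch集群不可用,请检查连接配置");
            return;
        }
        var client = ESClientUtil.createClient();
        // 任何可能抛出IOException的调用都可以包装进重试逻辑
        var info = ESClientUtil.executeWithRetry(() -> client.info());
        System.out.println("已连接集群: " + info.clusterName());
    }
}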
4.3 模型上传接口与实现
首先定义模型上传的接口,遵循依赖倒置原则:
java
/**
* 模型上传器接口
*/
public interface ModelUploader {
/**
* 上传模型到Elasticsearch
* @param modelId 模型ID
* @param modelPath 模型文件路径
* @param description 模型描述
* @param tags 模型标签
* @throws ModelOperationException 如果上传失败
*/
void uploadModel(String modelId, Path modelPath, String description, String... tags);
/**
* 检查模型是否存在
* @param modelId 模型ID
* @return 模型是否存在
*/
boolean modelExists(String modelId);
/**
* 创建模型别名
* @param modelId 模型ID
* @param aliasName 别名
*/
void createModelAlias(String modelId, String aliasName);
}
然后是接口实现:
java
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.ml.PutTrainedModelRequest;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Base64;
import java.util.Objects;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
public class ModelUploadService implements ModelUploader {
private static final Logger logger = LoggerFactory.getLogger(ModelUploadService.class);
private static final int CHUNK_SIZE = 1024 * 1024; // 1MB
private static final int MAX_WAIT_SECONDS = 120; // 等待模型加载的最大时间
// 并发控制锁
private final ReadWriteLock modelUpdateLock = new ReentrantReadWriteLock();
/**
* 上传ONNX模型到Elasticsearch
*/
@Override
public void uploadModel(String modelId, Path modelPath, String description, String... tags) {
// 参数验证
Objects.requireNonNull(modelId, "模型ID不能为空");
Objects.requireNonNull(modelPath, "模型路径不能为空");
if (!Files.exists(modelPath)) {
throw new IllegalArgumentException("模型文件不存在: " + modelPath);
}
// 获取写锁,确保模型更新时的独占访问
modelUpdateLock.writeLock().lock();
try (var client = ESClientUtil.createClient()) {
// 检查模型文件大小
long fileSize = Files.size(modelPath);
logger.info("开始上传模型 {}, 文件大小: {} 字节", modelId, fileSize);
// 根据大小选择上传方式
if (fileSize > 10 * 1024 * 1024) { // 大于10MB使用分块上传
uploadLargeModel(client, modelId, modelPath, description, tags);
} else {
uploadSmallModel(client, modelId, modelPath, description, tags);
}
// 验证模型上传状态
verifyModelStatus(client, modelId);
logger.info("模型 {} 上传成功", modelId);
} catch (ElasticsearchException e) {
logger.error("ES操作异常: {}, 状态码: {}", e.getMessage(), e.status(), e);
throw new ModelOperationException("上传模型失败: " + e.getMessage(), modelId, e.status(), e);
} catch (IOException e) {
logger.error("IO异常: {}", e.getMessage(), e);
throw new ModelOperationException("读取或上传模型失败", modelId, 500, e);
} finally {
modelUpdateLock.writeLock().unlock();
}
}
/**
* 上传小型模型(一次性上传)
*/
private void uploadSmallModel(
ElasticsearchClient client,
String modelId,
Path modelPath,
String description,
String[] tags) throws IOException {
// 读取模型文件并编码
String modelBase64;
try (InputStream is = Files.newInputStream(modelPath)) {
byte[] modelBytes = is.readAllBytes();
modelBase64 = Base64.getEncoder().encodeToString(modelBytes);
}
try {
// 创建模型配置并上传
client.ml().putTrainedModel(PutTrainedModelRequest.of(builder ->
builder
.modelId(modelId)
.inferenceConfig(ic -> ic.onnx(onnx -> onnx))
.modelType("onnx")
.tags(Arrays.asList(tags))
.description(description)
// compressed_definition接收经Base64编码(通常先gzip压缩)的模型定义
.compressedDefinition(modelBase64)
));
} catch (ElasticsearchException e) {
if (e.status() == 413) { // 请求实体太大
logger.warn("模型过大,尝试分块上传");
uploadLargeModel(client, modelId, modelPath, description, tags);
} else {
throw e;
}
}
}
/**
* 上传大型模型(分块上传)
*/
private void uploadLargeModel(
ElasticsearchClient client,
String modelId,
Path modelPath,
String description,
String[] tags) throws IOException {
// 创建模型配置(不包含模型内容)
createModelConfig(client, modelId, description, tags);
// 分块上传模型内容
uploadModelChunks(client, modelId, modelPath);
}
/**
* 创建模型配置
*/
private void createModelConfig(
ElasticsearchClient client,
String modelId,
String description,
String[] tags) throws IOException {
client.ml().putTrainedModel(builder ->
builder
.modelId(modelId)
.inferenceConfig(ic -> ic.onnx(onnx -> onnx))
.modelType("onnx")
.tags(Arrays.asList(tags))
.description(description)
);
logger.info("创建模型 {} 配置成功", modelId);
}
/**
* 分块上传模型内容
*/
private void uploadModelChunks(
ElasticsearchClient client,
String modelId,
Path modelPath) throws IOException {
// 获取文件大小
long fileSize = Files.size(modelPath);
int totalParts = (int) Math.ceil((double) fileSize / CHUNK_SIZE);
try (InputStream is = Files.newInputStream(modelPath)) {
byte[] buffer = new byte[CHUNK_SIZE];
int partNum = 0;
int bytesRead;
long totalBytesRead = 0;
// 读取并上传每个分块
while ((bytesRead = is.read(buffer)) != -1) {
// 如果读取的字节数小于缓冲区大小,创建一个刚好大小的新数组
byte[] chunk = bytesRead == buffer.length ? buffer : Arrays.copyOf(buffer, bytesRead);
String base64Chunk = Base64.getEncoder().encodeToString(chunk);
// lambda只能捕获有效final的局部变量,复制当前分块编号
final int currentPart = partNum;
ESClientUtil.executeWithRetry(() -> {
    client.ml().putTrainedModelDefinitionPart(d -> d
        .modelId(modelId)
        .part(currentPart)
        .totalDefinitionLength(fileSize)
        .totalParts(totalParts)
        .definition(base64Chunk)
    );
    return null;
});
totalBytesRead += bytesRead;
int progressPercent = (int)((totalBytesRead * 100) / fileSize);
logger.info("模型 {} 上传进度: {}/{} 块 ({}%)",
modelId, partNum + 1, totalParts, progressPercent);
partNum++;
}
}
}
/**
* 验证模型状态
*/
private void verifyModelStatus(ElasticsearchClient client, String modelId) {
    try {
        // 轮询等待模型配置在集群中注册完成。
        // 注意:started/starting等部署状态由部署统计API(get trained models stats)报告,
        // 启动部署(start trained model deployment)是独立于上传的步骤,这里只确认上传注册成功
        int attempts = 0;
        int maxAttempts = MAX_WAIT_SECONDS; // 最多等待2分钟
        while (attempts < maxAttempts) {
            try {
                var response = client.ml().getTrainedModels(m -> m.modelId(modelId));
                if (!response.trainedModelConfigs().isEmpty()) {
                    logger.info("模型 {} 已注册就绪", modelId);
                    return;
                }
            } catch (ElasticsearchException e) {
                if (e.status() != 404) {
                    throw e;
                }
                // 404表示模型尚未注册完成,继续等待
            }
            logger.info("等待模型 {} 注册... (尝试 {}/{})", modelId, attempts + 1, maxAttempts);
            Thread.sleep(1000);
            attempts++;
        }
        logger.warn("模型 {} 未能在{}秒内完成注册", modelId, MAX_WAIT_SECONDS);
        throw new ModelOperationException("模型注册超时,请稍后检查状态", modelId, 408);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        logger.warn("等待模型注册过程被中断", e);
        throw new ModelOperationException("模型验证被中断", modelId, 500, e);
    } catch (ModelOperationException e) {
        throw e;
    } catch (Exception e) {
        logger.warn("验证模型状态失败: {}", e.getMessage(), e);
        throw new ModelOperationException("验证模型状态失败", modelId, 500, e);
    }
}
/**
* 检查模型是否存在
*/
@Override
public boolean modelExists(String modelId) {
Objects.requireNonNull(modelId, "模型ID不能为空");
// 获取读锁,允许并发读取
modelUpdateLock.readLock().lock();
try (var client = ESClientUtil.createClient()) {
var response = client.ml().getTrainedModels(m -> m.modelId(modelId));
return !response.trainedModelConfigs().isEmpty();
} catch (ElasticsearchException e) {
if (e.status() == 404) {
return false;
}
logger.error("检查模型存在性失败: {}", e.getMessage(), e);
throw new ModelOperationException("检查模型失败", modelId, e.status(), e);
} catch (IOException e) {
logger.error("IO异常: {}", e.getMessage(), e);
throw new ModelOperationException("检查模型IO异常", modelId, 500, e);
} finally {
modelUpdateLock.readLock().unlock();
}
}
/**
* 创建模型别名
*/
@Override
public void createModelAlias(String modelId, String aliasName) {
Objects.requireNonNull(modelId, "模型ID不能为空");
Objects.requireNonNull(aliasName, "别名不能为空");
modelUpdateLock.writeLock().lock();
try (var client = ESClientUtil.createClient()) {
client.ml().putTrainedModelAlias(a -> a
.modelId(modelId)
.modelAlias(aliasName)
.reassign(true)
);
logger.info("为模型 {} 创建别名 {}", modelId, aliasName);
} catch (ElasticsearchException e) {
logger.error("创建模型别名失败: {}", e.getMessage(), e);
throw new ModelOperationException("创建模型别名失败", modelId, e.status(), e);
} catch (IOException e) {
logger.error("IO异常: {}", e.getMessage(), e);
throw new ModelOperationException("创建模型别名IO异常", modelId, 500, e);
} finally {
modelUpdateLock.writeLock().unlock();
}
}
}
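模型上传服务的调用示意如下(模型ID、文件路径与标签均为示例):
java
import java.nio.file.Path;
public class ModelUploadDemo {
    public static void main(String[] args) {
        ModelUploader uploader = new ModelUploadService();
        Path modelPath = Path.of("models/text-classifier.onnx"); // 示例路径
        if (!uploader.modelExists("text-classifier-v1")) {
            uploader.uploadModel("text-classifier-v1", modelPath,
                "文本分类模型", "nlp", "classification");
        }
        // 通过别名解耦调用方与具体模型版本,升级时只需切换别名
        uploader.createModelAlias("text-classifier-v1", "text-classifier-current");
    }
}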
4.4 推理管道接口与实现
同样,先定义管道处理的接口:
java
/**
* 推理管道服务接口
*/
public interface InferencePipelineManager {
/**
* 创建推理管道
* @param pipelineId 管道ID
* @param modelId 模型ID
* @param sourceField 源字段名
* @param targetField 目标字段名
* @param description 管道描述
*/
void createInferencePipeline(String pipelineId, String modelId,
String sourceField, String targetField,
String description);
/**
* 确保管道存在,不存在则创建
*/
void ensurePipelineExists(String pipelineId, String modelId,
String sourceField, String targetField,
String description);
/**
* 更新现有管道使用新模型
*/
void updatePipelineModel(String pipelineId, String newModelId);
}
接口实现:
java
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.ingest.Processor;
import co.elastic.clients.elasticsearch.ingest.PutPipelineRequest;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import co.elastic.clients.json.JsonData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
@Service
public class InferencePipelineService implements InferencePipelineManager {
private static final Logger logger = LoggerFactory.getLogger(InferencePipelineService.class);
// 管道缓存,避免重复创建
private final ConcurrentHashMap<String, Long> pipelineCache = new ConcurrentHashMap<>();
private final ReadWriteLock pipelineLock = new ReentrantReadWriteLock();
// 缓存过期时间
private static final long CACHE_TTL_MS = 3600000; // 1小时
// 推理配置参数(推理线程数等资源参数在模型部署时配置,而非处理器上)
private static final int DEFAULT_NUM_TOP_CLASSES = 2;
/**
* 创建推理管道
*/
@Override
public void createInferencePipeline(
String pipelineId,
String modelId,
String sourceField,
String targetField,
String description) {
// 参数验证
Objects.requireNonNull(pipelineId, "管道ID不能为空");
Objects.requireNonNull(modelId, "模型ID不能为空");
Objects.requireNonNull(sourceField, "源字段不能为空");
Objects.requireNonNull(targetField, "目标字段不能为空");
pipelineLock.writeLock().lock();
try (var client = ESClientUtil.createClient()) {
// 创建推理处理器(Processor为不可变对象,通过of工厂方法构建)
Processor inferenceProcessor = Processor.of(p -> p
    .inference(ip -> ip
        .modelId(modelId)
        .targetField(targetField)
        // field_map将文档字段映射为模型期望的输入字段
        .fieldMap(Map.of(sourceField, JsonData.of("input_text")))
        .inferenceConfig(ic -> ic
            .classification(c -> c
                .numTopClasses(DEFAULT_NUM_TOP_CLASSES)
            )
        )
    )
);
// 创建包含推理处理器的管道
client.ingest().putPipeline(PutPipelineRequest.of(builder ->
builder
.id(pipelineId)
.description(description != null ? description : "推理管道 " + pipelineId)
.processors(inferenceProcessor)
));
logger.info("推理管道 {} 创建成功, 使用模型 {}", pipelineId, modelId);
// 验证管道
verifyPipeline(client, pipelineId);
// 更新缓存
pipelineCache.put(pipelineId, System.currentTimeMillis());
} catch (ElasticsearchException e) {
logger.error("创建推理管道失败: {}, 状态码: {}", e.getMessage(), e.status(), e);
throw new InferenceException("创建推理管道失败: " + e.getMessage(), pipelineId, e);
} catch (IOException e) {
logger.error("IO异常: {}", e.getMessage(), e);
throw new InferenceException("创建推理管道IO异常", pipelineId, e);
} finally {
pipelineLock.writeLock().unlock();
}
}
/**
* 验证管道是否正常工作
*/
private void verifyPipeline(ElasticsearchClient client, String pipelineId) {
    try {
        // 用simulate API对管道做一次冒烟测试
        Map<String, Object> testDoc = Map.of(
            "text", "这是一个测试文本,用于验证推理管道是否正常工作。"
        );
        var response = client.ingest().simulate(s -> s
            .id(pipelineId)
            .docs(d -> d.source(JsonData.of(testDoc)))
        );
        if (response.docs().isEmpty()) {
            throw new InferenceException("管道测试返回空结果", pipelineId);
        }
        logger.info("管道 {} 测试成功", pipelineId);
    } catch (Exception e) {
        logger.warn("管道验证失败: {}", e.getMessage(), e);
        throw new InferenceException("管道验证失败", pipelineId, e);
    }
}
/**
* 检查管道是否存在,不存在则创建
* 用于确保推理前管道可用
*/
@Override
public void ensurePipelineExists(
String pipelineId,
String modelId,
String sourceField,
String targetField,
String description) {
// 先检查缓存
if (pipelineCache.containsKey(pipelineId)) {
return;
}
// 获取读锁检查
pipelineLock.readLock().lock();
try (var client = ESClientUtil.createClient()) {
var response = client.ingest().getPipeline(g -> g.id(pipelineId));
if (!response.result().isEmpty()) {
// 管道已存在,更新缓存
pipelineCache.put(pipelineId, System.currentTimeMillis());
return;
}
} catch (Exception e) {
// 忽略检查错误,尝试创建
logger.debug("检查管道存在性出错,尝试创建: {}", e.getMessage());
} finally {
pipelineLock.readLock().unlock();
}
// 管道不存在,创建新管道
createInferencePipeline(pipelineId, modelId, sourceField, targetField, description);
}
/**
* 更新现有管道使用新模型
*/
@Override
public void updatePipelineModel(String pipelineId, String newModelId) {
Objects.requireNonNull(pipelineId, "管道ID不能为空");
Objects.requireNonNull(newModelId, "新模型ID不能为空");
pipelineLock.writeLock().lock();
try (var client = ESClientUtil.createClient()) {
// 获取现有管道
var response = client.ingest().getPipeline(g -> g.id(pipelineId));
if (response.result().isEmpty()) {
throw new InferenceException("管道不存在,无法更新", pipelineId);
}
// 获取当前管道配置
var pipeline = response.result().get(pipelineId);
var description = pipeline.description();
var processors = pipeline.processors();
// Processor对象不可变,需重建处理器列表并替换推理处理器的模型ID
List<Processor> updatedProcessors = new ArrayList<>(processors.size());
boolean modelUpdated = false;
for (Processor processor : processors) {
    if (processor.isInference()) {
        var old = processor.inference();
        updatedProcessors.add(Processor.of(p -> p.inference(ip -> ip
            .modelId(newModelId)
            .targetField(old.targetField())
            .fieldMap(old.fieldMap())
            .inferenceConfig(old.inferenceConfig())
        )));
        modelUpdated = true;
    } else {
        updatedProcessors.add(processor);
    }
}
if (!modelUpdated) {
    throw new InferenceException("管道中未找到推理处理器", pipelineId);
}
// 更新管道
client.ingest().putPipeline(p -> p
    .id(pipelineId)
    .description(description)
    .processors(updatedProcessors)
);
logger.info("管道 {} 已更新使用新模型 {}", pipelineId, newModelId);
// 验证更新后的管道
verifyPipeline(client, pipelineId);
// 更新缓存
pipelineCache.put(pipelineId, System.currentTimeMillis());
} catch (ElasticsearchException e) {
logger.error("更新管道模型失败: {}, 状态码: {}", e.getMessage(), e.status(), e);
throw new InferenceException("更新管道模型失败", pipelineId, e);
} catch (IOException e) {
logger.error("IO异常: {}", e.getMessage(), e);
throw new InferenceException("更新管道IO异常", pipelineId, e);
} finally {
pipelineLock.writeLock().unlock();
}
}
/**
* 定期清理过期缓存
*/
@Scheduled(fixedRate = 3600000) // 每小时执行一次
public void cleanupCache() {
long now = System.currentTimeMillis();
int removedCount = 0;
for (Map.Entry<String, Long> entry : pipelineCache.entrySet()) {
if (now - entry.getValue() > CACHE_TTL_MS) {
pipelineCache.remove(entry.getKey());
removedCount++;
}
}
if (removedCount > 0) {
logger.info("缓存清理完成,移除了{}个过期管道记录", removedCount);
}
}
}
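管道服务的调用示意如下(管道ID与模型ID均为示例名称):
java
public class PipelineServiceDemo {
    public static void main(String[] args) {
        InferencePipelineManager pipelines = new InferencePipelineService();
        // 幂等调用:管道已存在则复用缓存,否则创建
        pipelines.ensurePipelineExists(
            "text-classify-pipeline",
            "text-classifier-current",
            "text", "prediction",
            "文本分类推理管道");
        // 模型升级后原地切换管道指向的模型
        pipelines.updatePipelineModel("text-classify-pipeline", "text-classifier-v2");
    }
}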
4.5 系统初始化与配置
添加一个系统初始化组件,确保系统启动时完成必要的设置:
java
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
@Component
public class ESMLSystemInitializer {
private static final Logger logger = LoggerFactory.getLogger(ESMLSystemInitializer.class);
private final IndexManager indexManager;
private final ElasticsearchProperties esProperties;
@Autowired
public ESMLSystemInitializer(
IndexManager indexManager,
ElasticsearchProperties esProperties) {
this.indexManager = indexManager;
this.esProperties = esProperties;
}
@PostConstruct
public void initialize() {
logger.info("初始化ES ML系统...");
// 异步初始化,避免阻塞应用启动
CompletableFuture.runAsync(() -> {
try {
// 1. 验证集群连接
if (!ESClientUtil.isClusterHealthy()) {
logger.warn("Elasticsearch集群状态不健康,初始化可能不完整");
}
// 2. 创建索引模板
try (var client = ESClientUtil.createClient()) {
indexManager.createInferenceIndexTemplate(client);
// 3. 创建必要的索引
for (String indexName : esProperties.getRequiredIndices()) {
indexManager.createInferenceIndex(client, indexName);
}
}
logger.info("ES ML系统初始化完成");
} catch (IOException e) {
logger.error("初始化ES ML系统失败: {}", e.getMessage(), e);
}
});
}
}
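上面的初始化组件依赖的 ElasticsearchProperties 在正文中未给出定义,下面是一个基于 Spring Boot @ConfigurationProperties 的最小示意(配置前缀 esml 与属性名均为假设):
java
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;
import java.util.List;
@Component
@ConfigurationProperties(prefix = "esml")
public class ElasticsearchProperties {
    // 启动时需要确保存在的索引列表,对应配置项 esml.required-indices
    private List<String> requiredIndices = List.of();
    public List<String> getRequiredIndices() {
        return requiredIndices;
    }
    public void setRequiredIndices(List<String> requiredIndices) {
        this.requiredIndices = requiredIndices;
    }
}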
5. 推理服务设计与实现
接下来,我们设计更具体的推理接口和实现:
java
import java.util.List;
import java.util.Map;
/**
 * 单文本推理接口
 */
public interface SingleInference {
/**
* 执行单文本推理
* @param indexName 索引名称
* @param pipelineId 管道ID
* @param modelId 模型ID
* @param text 输入文本
* @return 推理结果
*/
Map<String, Object> inferSingle(String indexName, String pipelineId,
String modelId, String text);
}
/**
* 批量推理接口
*/
public interface BatchInference {
/**
* 执行批量文本推理
* @param indexName 索引名称
* @param pipelineId 管道ID
* @param modelId 模型ID
* @param texts 输入文本列表
* @return 推理结果列表
*/
List<Map<String, Object>> inferBatch(String indexName, String pipelineId,
String modelId, List<String> texts);
}
/**
* 链式推理接口
*/
public interface ChainedInference {
/**
* 执行多模型链式推理
* @param modelChain 模型链配置
* @param text 输入文本
* @return 推理结果
*/
Map<String, Object> chainedInference(List<ModelConfig> modelChain, String text);
/**
* 模型配置类
*/
class ModelConfig {
private final String modelId;
private final String pipelineId;
private final String indexName;
private final String sourceField;
private final String targetField;
public ModelConfig(String modelId, String pipelineId, String indexName,
String sourceField, String targetField) {
this.modelId = modelId;
this.pipelineId = pipelineId;
this.indexName = indexName;
this.sourceField = sourceField;
this.targetField = targetField;
}
// Getters
public String getModelId() { return modelId; }
public String getPipelineId() { return pipelineId; }
public String getIndexName() { return indexName; }
public String getSourceField() { return sourceField; }
public String getTargetField() { return targetField; }
}
}
/**
* 完整推理服务接口
*/
public interface InferenceService extends SingleInference, BatchInference, ChainedInference {
// 组合上述三个接口
}
实现上述接口的推理服务:
java
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.core.*;
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
import co.elastic.clients.elasticsearch._types.ElasticsearchException;
import co.elastic.clients.elasticsearch._types.Result;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
@Service
public class ESInferenceService implements InferenceService {
private static final Logger logger = LoggerFactory.getLogger(ESInferenceService.class);
private final ModelUploader modelUploader;
private final InferencePipelineManager pipelineManager;
// 批处理控制参数
private static final int MAX_BATCH_SIZE = 100;
private static final int DEFAULT_BATCH_SIZE = 20;
@Autowired
public ESInferenceService(ModelUploader modelUploader,
InferencePipelineManager pipelineManager) {
this.modelUploader = modelUploader;
this.pipelineManager = pipelineManager;
}
/**
* 单条文本推理
*/
@Override
public Map<String, Object> inferSingle(
String indexName,
String pipelineId,
String modelId,
String text) {
// 参数验证
Objects.requireNonNull(indexName, "索引名称不能为空");
Objects.requireNonNull(pipelineId, "管道ID不能为空");
Objects.requireNonNull(modelId, "模型ID不能为空");
Objects.requireNonNull(text, "输入文本不能为空");
// 确保管道存在
pipelineManager.ensurePipelineExists(
pipelineId, modelId, "text", "prediction", "推理管道");
try (var client = ESClientUtil.createClient()) {
// 准备文档
String documentId = generateDocumentId();
Map<String, Object> document = new HashMap<>();
document.put("text", text);
document.put("timestamp", System.currentTimeMillis());
// 记录处理开始时间
long startTime = System.currentTimeMillis();
// 使用推理管道索引文档
IndexResponse response = ESClientUtil.executeWithRetry(() ->
client.index(IndexRequest.of(builder ->
builder
.index(indexName)
.id(documentId)
.pipeline(pipelineId)
.document(document)
))
);
if (!response.result().name().contains("CREATED") &&
!response.result().name().contains("UPDATED")) {
throw new InferenceException(
"索引文档失败: " + response.result().name(), pipelineId);
}
// 查询推理结果
GetResponse<Map> getResponse = ESClientUtil.executeWithRetry(() ->
client.get(g -> g
.index(indexName)
.id(documentId),
Map.class)
);
if (!getResponse.found()) {
throw new InferenceException("推理后文档未找到", pipelineId);
}
// 计算处理时间
long processingTime = System.currentTimeMillis() - startTime;
Map<String, Object> result = getResponse.source();
result.put("processing_time_ms", processingTime);
logger.debug("单文本推理完成,处理时间: {}ms", processingTime);
return result;
} catch (ElasticsearchException e) {
logger.error("ES推理异常: {}, 状态码: {}", e.getMessage(), e.status(), e);
throw new InferenceException("推理失败: " + e.getMessage(), pipelineId, e);
} catch (IOException e) {
logger.error("IO异常: {}", e.getMessage(), e);
throw new InferenceException("推理IO异常", pipelineId, e);
}
}
/**
* 批量文本推理
*/
@Override
public List<Map<String, Object>> inferBatch(
String indexName,
String pipelineId,
String modelId,
List<String> texts) {
// 检查输入参数
Objects.requireNonNull(indexName, "索引名称不能为空");
Objects.requireNonNull(pipelineId, "管道ID不能为空");
Objects.requireNonNull(modelId, "模型ID不能为空");
if (texts == null || texts.isEmpty()) {
return Collections.emptyList();
}
// 确保管道存在
pipelineManager.ensurePipelineExists(
pipelineId, modelId, "text", "prediction", "推理管道");
// 大批量数据分批处理
if (texts.size() > MAX_BATCH_SIZE) {
return processBatchesInChunks(indexName, pipelineId, texts);
}
try (var client = ESClientUtil.createClient()) {
// 准备批量请求(直接构建BulkRequest.Builder,每次请求新建)
List<String> documentIds = new ArrayList<>(texts.size());
BulkRequest.Builder bulkRequest = new BulkRequest.Builder();
// 指定推理管道,批量文档统一经该管道处理
bulkRequest.pipeline(pipelineId);
// 记录处理开始时间
long startTime = System.currentTimeMillis();
// 添加所有文档到批量请求
for (String text : texts) {
    String documentId = generateDocumentId();
    documentIds.add(documentId);
    Map<String, Object> document = new HashMap<>();
    document.put("text", text);
    document.put("timestamp", System.currentTimeMillis());
    bulkRequest.operations(op -> op
        .index(idx -> idx
            .index(indexName)
            .id(documentId)
            .document(document)
        )
    );
}
// 执行批量请求
BulkResponse bulkResponse = ESClientUtil.executeWithRetry(() ->
    client.bulk(bulkRequest.build())
);
// 检查批量操作结果
if (bulkResponse.errors()) {
logger.warn("批量推理部分失败");
for (BulkResponseItem item : bulkResponse.items()) {
if (item.error() != null) {
logger.error("文档 {} 处理失败: {}",
item.id(), item.error().reason());
}
}
}
// 批量获取结果文档
MgetResponse<Map> response = ESClientUtil.executeWithRetry(() ->
client.mget(m -> m
.index(indexName)
.ids(documentIds),
Map.class)
);
// 计算总处理时间
long totalProcessingTime = System.currentTimeMillis() - startTime;
long avgProcessingTime = totalProcessingTime / texts.size();
// 提取结果(mget返回的每一项是result/failure联合类型)
List<Map<String, Object>> results = new ArrayList<>(texts.size());
for (var item : response.docs()) {
    if (item.isResult() && item.result().found()) {
        Map<String, Object> result = item.result().source();
        result.put("processing_time_ms", avgProcessingTime);
        results.add(result);
    } else {
        Map<String, Object> errorDoc = new HashMap<>();
        errorDoc.put("error", "推理后文档未找到");
        errorDoc.put("id", item.isResult() ? item.result().id() : item.failure().id());
        results.add(errorDoc);
    }
}
logger.debug("批量推理完成,{}条文本,总耗时: {}ms,平均: {}ms/条",
texts.size(), totalProcessingTime, avgProcessingTime);
return results;
} catch (ElasticsearchException e) {
logger.error("ES批量推理异常: {}, 状态码: {}", e.getMessage(), e.status(), e);
throw new InferenceException("批量推理失败: " + e.getMessage(), pipelineId, e);
} catch (IOException e) {
logger.error("IO异常: {}", e.getMessage(), e);
throw new InferenceException("批量推理IO异常", pipelineId, e);
}
}
/**
* 大批量数据分块处理
*/
private List<Map<String, Object>> processBatchesInChunks(
String indexName, String pipelineId, List<String> texts) {
List<Map<String, Object>> allResults = new ArrayList<>(texts.size());
List<List<String>> batches = new ArrayList<>();
// 分割成小批次
for (int i = 0; i < texts.size(); i += DEFAULT_BATCH_SIZE) {
int end = Math.min(i + DEFAULT_BATCH_SIZE, texts.size());
batches.add(texts.subList(i, end));
}
// 处理每个批次
AtomicInteger processedCount = new AtomicInteger(0);
for (List<String> batch : batches) {
try {
List<Map<String, Object>> batchResults =
inferBatchInternal(indexName, pipelineId, batch);
allResults.addAll(batchResults);
int completed = processedCount.addAndGet(batch.size());
logger.info("批量推理进度: {}/{} ({}%)",
completed, texts.size(), (completed * 100 / texts.size()));
} catch (Exception e) {
logger.error("处理批次失败: {}", e.getMessage(), e);
// 添加错误结果
for (int i = 0; i < batch.size(); i++) {
Map<String, Object> errorDoc = new HashMap<>();
errorDoc.put("error", "批次处理失败: " + e.getMessage());
errorDoc.put("text", batch.get(i));
allResults.add(errorDoc);
}
}
}
return allResults;
}
/**
* 内部批量推理实现
* 处理单个批次,确保批次大小合理
*/
private List<Map<String, Object>> inferBatchInternal(
String indexName, String pipelineId, List<String> batch) {
if (batch.size() > MAX_BATCH_SIZE) {
throw new IllegalArgumentException(
"批次大小超过限制: " + batch.size() + " > " + MAX_BATCH_SIZE);
}
try (var client = ESClientUtil.createClient()) {
// 准备批量请求
List<String> documentIds = new ArrayList<>(batch.size());
BulkRequest.Builder bulkRequest = new BulkRequest.Builder();
// 指定推理管道
bulkRequest.pipeline(pipelineId);
// 添加所有文档到批量请求
for (String text : batch) {
    String documentId = generateDocumentId();
    documentIds.add(documentId);
    Map<String, Object> document = new HashMap<>();
    document.put("text", text);
    document.put("timestamp", System.currentTimeMillis());
    bulkRequest.operations(op -> op
        .index(idx -> idx
            .index(indexName)
            .id(documentId)
            .document(document)
        )
    );
}
// 执行批量请求
BulkResponse bulkResponse = ESClientUtil.executeWithRetry(() ->
    client.bulk(bulkRequest.build())
);
// 检查批量操作结果
checkBulkResponse(bulkResponse);
// 批量获取结果文档
MgetResponse<Map> response = ESClientUtil.executeWithRetry(() ->
client.mget(m -> m
.index(indexName)
.ids(documentIds),
Map.class)
);
// 提取结果(mget返回的每一项是result/failure联合类型)
List<Map<String, Object>> results = new ArrayList<>(batch.size());
for (var item : response.docs()) {
    if (item.isResult() && item.result().found()) {
        results.add(item.result().source());
    } else {
        Map<String, Object> errorDoc = new HashMap<>();
        errorDoc.put("error", "推理后文档未找到");
        errorDoc.put("id", item.isResult() ? item.result().id() : item.failure().id());
        results.add(errorDoc);
    }
}
return results;
} catch (Exception e) {
logger.error("批次处理异常: {}", e.getMessage(), e);
throw new InferenceException("批次处理失败", pipelineId, e);
}
}
/**
* 检查批量响应错误
*/
private void checkBulkResponse(BulkResponse response) {
if (response.errors()) {
StringBuilder errorMsg = new StringBuilder("批量操作部分失败: ");
for (BulkResponseItem item : response.items()) {
if (item.error() != null) {
errorMsg.append(item.id())
.append("(")
.append(item.error().reason())
.append("), ");
}
}
logger.warn(errorMsg.toString());
}
}
/**
* 多模型链式推理
*/
@Override
public Map<String, Object> chainedInference(
List<ModelConfig> modelChain, String text) {
// 参数验证
Objects.requireNonNull(modelChain, "模型链不能为空");
if (modelChain.isEmpty()) {
throw new IllegalArgumentException("模型链不能为空");
}
Objects.requireNonNull(text, "输入文本不能为空");
Map<String, Object> document = new HashMap<>();
document.put("original_text", text);
document.put("text", text);
document.put("timestamp", System.currentTimeMillis());
// 按顺序执行每个模型
for (int i = 0; i < modelChain.size(); i++) {
ModelConfig config = modelChain.get(i);
String stageName = "stage_" + (i + 1);
try {
// 确保管道存在
String pipelineId = config.getPipelineId();
pipelineManager.ensurePipelineExists(
pipelineId, config.getModelId(),
config.getSourceField(), config.getTargetField(),
"链式推理管道" + stageName);
// 执行当前阶段推理
document = runPipelineStage(
config.getIndexName(), pipelineId, document, stageName);
// 如果有下一个模型,将当前结果作为下一阶段输入
if (i < modelChain.size() - 1) {
ModelConfig nextConfig = modelChain.get(i + 1);
// 从当前输出中提取数据作为下一阶段输入
Object nextInput = extractFieldFromPath(
document, config.getTargetField());
// 转换为字符串或保持其结构,取决于下一个模型需要
document.put(nextConfig.getSourceField(), nextInput);
}
} catch (Exception e) {
logger.error("链式推理阶段 {} 失败: {}", stageName, e.getMessage(), e);
document.put("error_" + stageName, e.getMessage());
// 链式失败,中断后续处理
break;
}
}
return document;
}
/**
* 执行单个管道阶段
*/
private Map<String, Object> runPipelineStage(
String indexName, String pipelineId,
Map<String, Object> document, String stageName) throws IOException {
try (var client = ESClientUtil.createClient()) {
String documentId = generateDocumentId();
// 索引文档,应用管道
IndexResponse response = ESClientUtil.executeWithRetry(() ->
client.index(IndexRequest.of(builder ->
builder
.index(indexName)
.id(documentId)
.pipeline(pipelineId)
.document(document)
))
);
// 获取处理后的文档
GetResponse<Map> getResponse = ESClientUtil.executeWithRetry(() ->
client.get(g -> g
.index(indexName)
.id(documentId),
Map.class)
);
if (!getResponse.found()) {
throw new InferenceException("阶段 " + stageName + " 处理后文档未找到", pipelineId);
}
// 保留处理阶段标记
Map<String, Object> result = getResponse.source();
result.put("_stage", stageName);
return result;
}
}
/**
* 从嵌套字段路径中提取数据
*/
private Object extractFieldFromPath(Map<String, Object> document, String fieldPath) {
if (fieldPath == null || fieldPath.isEmpty()) {
return null;
}
String[] parts = fieldPath.split("\\.");
Object current = document;
for (String part : parts) {
if (current instanceof Map) {
current = ((Map<?, ?>) current).get(part);
if (current == null) {
return null;
}
} else {
return null;
}
}
return current;
}
/**
* 生成唯一文档ID
*/
private String generateDocumentId() {
return UUID.randomUUID().toString();
}
}
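推理服务的调用示意如下(索引、管道与模型名称均为示例):
java
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Map;
@Component
public class InferenceDemo {
    @Autowired
    private InferenceService inferenceService;
    public void runDemo() {
        // 单条推理
        Map<String, Object> single = inferenceService.inferSingle(
            "inference-demo", "text-classify-pipeline",
            "text-classifier-current", "这家餐厅的服务非常好");
        System.out.println(single.get("prediction"));
        // 批量推理,超过批次上限时服务内部自动分批
        List<Map<String, Object>> batch = inferenceService.inferBatch(
            "inference-demo", "text-classify-pipeline",
            "text-classifier-current",
            List.of("物流很快", "质量一般", "包装破损"));
        System.out.println("完成 " + batch.size() + " 条推理");
    }
}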
这部分代码提供了以下内容:
- 基本的异常类型定义,确保错误处理的明确性
- ES 客户端工具类,带有连接池、重试机制和 SSL 配置
- 模型上传服务,支持大文件流式处理
- 推理管道管理,包括缓存和过期清理
- 系统初始化组件,确保启动时完成必要设置
- 多接口推理服务设计,分离单文本、批量和链式推理
在下一部分,我们将继续深入探讨高级功能,包括:
- 熔断器与故障降级策略
- 模型优化技术
- 性能监控
- Spring Boot 集成
- 实际应用案例