model.onnx在语义匹配系统中的应用实践
📋 文档概览
本文档是 model.onnx 深度分析系列的第二篇,将深入解析实际项目中如何使用ONNX模型进行语义匹配。
文档目标:
- ✅ 深入理解VectorEncoder类的完整实现
- ✅ 掌握批量推理的性能优化技巧
- ✅ 学习GPU/CUDA加速配置方法
- ✅ 理解双缓存策略的设计思路
- ✅ 掌握ONNX Runtime的实战技巧
🏗️ 一、系统架构全景图
1.1 语义匹配系统整体架构
配置层
数据层
AI推理层
应用层
用户层
前端界面
API请求
Controller层
Service层
SimilarityService
VectorEncoder
ONNX Runtime
model.onnx
HNSW索引
向量数据库
双缓存系统
Caffeine缓存
application-similarity.yml
1.2 核心组件职责
| 组件 | 职责 | 关键类 |
|---|---|---|
| VectorEncoder | 文本向量化编码 | VectorEncoder.java |
| ONNX Runtime | 模型推理执行 | OrtSession, OrtEnvironment |
| Tokenizer | 文本预处理分词 | 内置在VectorEncoder中 |
| Cache System | 向量缓存管理 | Caffeine双缓存 |
| HNSW Index | 高效向量检索 | HNSW算法实现 |
| Configuration | 系统配置管理 | SimilarityProperties |
💻 二、VectorEncoder核心实现解析
2.1 类结构与生命周期管理
完整的类依赖关系
java
@Component
@Slf4j
public class VectorEncoder {
// ========== 依赖注入 ==========
@Autowired
private SimilarityProperties properties;
@Autowired
@Qualifier("vectorCache") // 动态缓存(用户查询)
private Cache<String, float[]> vectorCache;
@Autowired
@Qualifier("preloadVectorCache") // 预加载缓存(系统字段)
private Cache<String, float[]> preloadVectorCache;
@Autowired
private ApplicationEventPublisher eventPublisher;
// ========== ONNX核心组件 ==========
private OrtEnvironment environment; // ONNX运行环境(单例)
private OrtSession session; // ONNX会话(模型实例)
private boolean modelLoaded = false; // 模型加载状态
private final Object modelLock = new Object(); // 推理同步锁
// ========== 模型就绪管理 ==========
private final AtomicBoolean modelReady = new AtomicBoolean(false);
private final CountDownLatch modelReadyLatch = new CountDownLatch(1);
// ========== 词汇表 ==========
private final Map<String, Integer> vocab = new ConcurrentHashMap<>();
// ========== BERT特殊Token ==========
private final int CLS_TOKEN = 101; // [CLS]
private final int SEP_TOKEN = 102; // [SEP]
private final int PAD_TOKEN = 0; // [PAD]
private final int UNK_TOKEN = 100; // [UNK]
}
2.2 分阶段初始化设计(性能优化关键)
设计思路
传统的单阶段初始化会导致应用启动缓慢(加载409MB模型需要5-7秒),影响用户体验。因此采用两阶段异步初始化:
缓存系统 ONNX Runtime VectorEncoder Spring容器 缓存系统 ONNX Runtime VectorEncoder Spring容器 阶段1:应用启动时(同步,快速) 阶段2:应用启动后(异步,不阻塞) @PostConstruct init() 创建OrtEnvironment 加载词汇表 (vocab.txt) 初始化完成 (耗时: 100-200ms) @EventListener ApplicationReadyEvent 加载model.onnx (409MB) 模型加载完成 (耗时: 5-7秒) 设置modelReady=true 模型预热 (推理5个样本) 填充预加载缓存 发布VectorModelReadyEvent
实现代码详解
阶段1:轻量级初始化(@PostConstruct)
java
@PostConstruct
public void init() {
try {
log.info("开始初始化向量编码器(阶段1:词汇表加载)...");
// 1. 创建ONNX Runtime环境(单例,轻量级)
// 设置日志级别为ERROR,避免控制台乱码
environment = OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR
);
// 2. 加载词汇表(vocab.txt,约21KB)
loadVocabulary();
log.info("词汇表加载完成,词汇量:{},等待应用启动后加载ONNX模型",
vocab.size());
} catch (Exception e) {
log.error("词汇表加载失败", e);
throw new RuntimeException("词汇表加载失败", e);
}
}
阶段2:重量级模型加载(@EventListener + @Async)
java
@EventListener(ApplicationReadyEvent.class)
@Async("similarityExecutor") // 异步执行,不阻塞主线程
public void initModelAsync() {
try {
long startTime = System.currentTimeMillis();
log.info("开始异步加载ONNX模型(阶段2:模型加载与预热)...");
// 1. 加载ONNX模型(409MB,耗时5-7秒)
loadModel();
// 2. 立即标记模型就绪(此时已可用)
modelLoaded = true;
modelReady.set(true);
modelReadyLatch.countDown();
log.info("模型加载完成,已标记为就绪状态");
// 3. 模型预热(推理几个样本,加速后续推理)
warmupModel();
long duration = System.currentTimeMillis() - startTime;
log.info("ONNX模型异步加载完成,耗时:{}ms,模型维度:{}",
duration, properties.getModel().getVectorDimension());
// 4. 发布模型就绪事件(触发索引构建等后续任务)
VectorModelReadyEvent event = new VectorModelReadyEvent(
this,
properties.getModel().getVectorDimension(),
vocab.size(),
duration
);
eventPublisher.publishEvent(event);
log.info("已发布向量模型就绪事件,触发索引构建");
} catch (Exception e) {
log.error("ONNX模型异步加载失败", e);
// 不抛异常,允许应用继续运行(但相似度功能不可用)
}
}
性能对比
| 初始化策略 | 启动时间 | 功能可用时间 | 用户体验 |
|---|---|---|---|
| 单阶段同步 | 8-10秒 | 8-10秒 | ❌ 差(等待时间长) |
| 两阶段异步 | 1-2秒 | 7-9秒(后台加载) | ✅ 好(快速启动) |
2.3 模型加载详细流程
关键代码解析
java
private void loadModel() throws OrtException, IOException {
String modelPath = properties.getModel().getActualPath();
log.info("加载ONNX模型:{}", modelPath);
try {
// ========== 步骤1:加载模型文件 ==========
byte[] modelBytes;
if (modelPath.startsWith("classpath:")) {
// 从classpath加载(JAR包内资源)
String resourcePath = modelPath.substring("classpath:".length());
ClassPathResource resource = new ClassPathResource(resourcePath);
modelBytes = readStreamToByteArray(resource.getInputStream());
} else if (modelPath.startsWith("/")) {
// 从绝对路径加载(转为classpath路径)
String resourcePath = modelPath.substring(1); // 去掉开头的/
ClassPathResource resource = new ClassPathResource(resourcePath);
if (resource.exists()) {
modelBytes = readStreamToByteArray(resource.getInputStream());
} else {
throw new FileNotFoundException("模型文件不存在: " + resourcePath);
}
} else {
// 从文件系统加载
File modelFile = new File(modelPath);
if (!modelFile.exists()) {
throw new FileNotFoundException("模型文件不存在: " + modelPath);
}
try (FileInputStream fis = new FileInputStream(modelFile)) {
modelBytes = readStreamToByteArray(fis);
}
}
// ========== 步骤2:创建ONNX会话选项 ==========
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
// 设置优化级别
sessionOptions.setOptimizationLevel(
OrtSession.SessionOptions.OptLevel.BASIC_OPT
);
// ========== 步骤3:配置硬件加速(GPU/CPU) ==========
configureGpuAcceleration(sessionOptions);
// ========== 步骤4:创建ONNX会话(加载模型) ==========
session = environment.createSession(modelBytes, sessionOptions);
modelLoaded = true;
log.info("ONNX模型加载成功,大小:{}MB,路径:{}",
modelBytes.length / 1024 / 1024, modelPath);
} catch (IOException e) {
throw SimilarityException.modelLoadError(modelPath, e);
}
}
内存优化技巧
java
// ❌ 错误做法:直接读取文件到内存
File file = new File("model.onnx");
byte[] bytes = Files.readAllBytes(file.toPath()); // 一次性读取409MB
// ✅ 正确做法:流式读取,避免内存峰值
private byte[] readStreamToByteArray(InputStream inputStream) throws IOException {
try (ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
byte[] data = new byte[8192]; // 8KB缓冲区
int nRead;
while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
buffer.write(data, 0, nRead);
}
return buffer.toByteArray();
}
}
2.4 GPU加速配置详解
GPU加速架构
ONNX Runtime GPU加速栈
│
├── 应用层:VectorEncoder
│ └── configureGpuAcceleration()
│
├── ONNX Runtime层
│ ├── CUDA Execution Provider
│ ├── cuDNN算子库
│ └── cuBLAS线性代数库
│
├── CUDA层
│ ├── CUDA Toolkit (12.x)
│ ├── CUDA Runtime
│ └── CUDA Driver
│
└── 硬件层
└── NVIDIA GPU (算力 >= 3.5)
完整的GPU配置实现
java
private void configureGpuAcceleration(OrtSession.SessionOptions sessionOptions)
throws OrtException {
SimilarityProperties.Performance.Gpu gpuConfig =
properties.getPerformance().getGpu();
// ========== 判断1:用户是否启用GPU ==========
if (!gpuConfig.getEnabled()) {
log.info("GPU加速未启用,使用CPU执行");
configureCpuExecution(sessionOptions);
return;
}
// ========== 判断2:检测CUDA环境兼容性 ==========
log.info("GPU加速已启用,开始环境检测...");
CudaCompatibilityResult compatibilityResult = checkCudaCompatibility();
if (!compatibilityResult.isCompatible()) {
// CUDA环境不兼容,记录警告
log.warn("CUDA环境检测发现问题: {}",
compatibilityResult.getFailureReason());
log.warn("将尝试配置GPU,如果失败会抛出异常");
} else {
log.info("CUDA环境检测通过");
}
// ========== 判断3:尝试配置CUDA执行提供程序 ==========
boolean gpuConfigured = tryConfigureGpuAcceleration(
sessionOptions,
gpuConfig,
compatibilityResult.getEnvironmentInfo()
);
if (!gpuConfigured) {
// GPU配置失败,抛出详细错误信息
String errorMessage = "GPU加速已启用但配置失败";
log.error(errorMessage);
log.error("错误详情: CUDA执行提供程序配置失败");
log.error("可能的解决方案:");
log.error(" 1. 在 NVIDIA 官网下载与 CUDA 12.x 匹配的 cuDNN 9.x");
log.error(" 2. 将 cuDNN DLL 文件复制到 CUDA bin 目录");
log.error(" 3. 将 CUDA bin 目录添加到系统 PATH 环境变量");
log.error(" 4. 重启 IDE、终端或服务");
throw new SimilarityException("GPU_CONFIGURATION_FAILED",
"GPU配置失败,请检查CUDA和cuDNN环境");
}
log.info("GPU加速配置成功,已启用CUDA执行");
}
CUDA环境检测实现
java
private CudaCompatibilityResult checkCudaCompatibility() {
log.info("=== 开始CUDA环境兼容性检测 ===");
try {
// 1. 检测CUDA环境信息
CudaEnvironmentInfo envInfo = detectCudaEnvironment();
logCudaEnvironmentStatus(envInfo);
// 2. 检查GPU可用性
if (!envInfo.isGpuAvailable()) {
return CudaCompatibilityResult.failure("未检测到NVIDIA GPU", envInfo);
}
// 3. 检查CUDA Toolkit
if (!envInfo.isCudaToolkitAvailable()) {
return CudaCompatibilityResult.failure(
"未检测到CUDA Toolkit,请安装: " +
"https://developer.nvidia.com/cuda-downloads",
envInfo
);
}
// 4. 检查ONNX Runtime GPU支持
if (!checkOnnxRuntimeGpuSupport()) {
return CudaCompatibilityResult.failure(
"ONNX Runtime不支持GPU或addCUDA方法不可用",
envInfo
);
}
log.info("CUDA环境兼容性检测通过");
return CudaCompatibilityResult.success(envInfo);
} catch (Exception e) {
log.error("CUDA兼容性检测过程中发生异常: {}", e.getMessage());
return CudaCompatibilityResult.failure(
"检测过程异常: " + e.getMessage(),
new CudaEnvironmentInfo()
);
}
}
GPU信息检测示例
java
private CudaEnvironmentInfo detectCudaEnvironment() {
CudaEnvironmentInfo info = new CudaEnvironmentInfo();
// 检测NVIDIA GPU(通过nvidia-smi命令)
try {
Process process = Runtime.getRuntime().exec(
"nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits"
);
BufferedReader reader = new BufferedReader(
new InputStreamReader(process.getInputStream())
);
String line = reader.readLine();
if (line != null && !line.trim().isEmpty()) {
String[] parts = line.split(",");
if (parts.length >= 2) {
info.setGpuAvailable(true);
info.setGpuName(parts[0].trim());
info.setGpuMemoryMB(Integer.parseInt(parts[1].trim()));
}
}
process.waitFor();
} catch (Exception e) {
log.debug("检测NVIDIA GPU失败: {}", e.getMessage());
}
// 检测CUDA Toolkit版本(通过nvcc命令)
try {
Process process = Runtime.getRuntime().exec("nvcc --version");
BufferedReader reader = new BufferedReader(
new InputStreamReader(process.getInputStream())
);
String line;
while ((line = reader.readLine()) != null) {
if (line.contains("release")) {
// 解析版本号:Cuda compilation tools, release 12.6, V12.6.85
String[] parts = line.split(",");
for (String part : parts) {
if (part.trim().startsWith("release")) {
info.setCudaToolkitAvailable(true);
info.setCudaVersion(part.trim().replace("release ", ""));
break;
}
}
break;
}
}
process.waitFor();
} catch (Exception e) {
log.debug("检测CUDA Toolkit失败: {}", e.getMessage());
}
return info;
}
2.5 核心编码流程:从文本到向量
完整的编码流程
java
public float[] encode(String text) throws Exception {
// ========== 步骤0:等待模型就绪 ==========
if (!modelReady.get()) {
log.warn("ONNX模型尚未就绪,等待异步加载完成...");
boolean ready = modelReadyLatch.await(60, TimeUnit.SECONDS);
if (!ready) {
throw new IllegalStateException("ONNX模型加载超时(等待60秒)");
}
log.info("ONNX模型已就绪,继续执行编码");
}
// ========== 步骤1:处理空文本 ==========
if (text == null || text.trim().isEmpty()) {
return new float[properties.getModel().getVectorDimension()]; // 零向量
}
// ========== 步骤2:检查缓存 ==========
String cacheKey = generateCacheKey(text); // MD5哈希
float[] cachedVector = vectorCache.getIfPresent(cacheKey);
if (cachedVector != null) {
log.debug("使用缓存向量: {} (缓存命中)", text);
return cachedVector;
}
// ========== 步骤3:尝试命中预加载缓存 ==========
float[] preloadedVector = preloadVectorCache.getIfPresent(cacheKey);
if (preloadedVector != null) {
log.info("使用预加载缓存向量: {}", text);
vectorCache.put(cacheKey, preloadedVector); // 同步到动态缓存
return preloadedVector;
}
log.debug("编码新文本: {} (缓存未命中)", text);
// ========== 步骤4:执行ONNX推理 ==========
float[] vector = performEncoding(text);
// ========== 步骤5:存入缓存 ==========
vectorCache.put(cacheKey, vector);
log.debug("向量编码完成并缓存: {}", text);
return vector;
}
performEncoding详细实现
java
private float[] performEncoding(String text) throws Exception {
// ========== 步骤1:Tokenize文本 ==========
long[] inputIds = tokenize(text); // [128]
long[] attentionMask = createAttentionMask(inputIds); // [128]
long[] tokenTypeIds = new long[inputIds.length]; // [128],全0
// ========== 步骤2:创建ONNX输入张量 ==========
synchronized (modelLock) {
Map<String, OnnxTensor> inputs = new HashMap<>();
try (OnnxTensor inputTensor = OnnxTensor.createTensor(
environment, new long[][]{inputIds});
OnnxTensor attentionTensor = OnnxTensor.createTensor(
environment, new long[][]{attentionMask});
OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(
environment, new long[][]{tokenTypeIds})) {
inputs.put("input_ids", inputTensor);
inputs.put("attention_mask", attentionTensor);
inputs.put("token_type_ids", tokenTypeTensor);
// ========== 步骤3:执行ONNX推理 ==========
try (OrtSession.Result result = session.run(inputs)) {
return processModelOutputOnly(result, text);
}
}
}
}
Tokenization实现(完整的BERT分词)
java
private long[] tokenize(String text) {
List<Integer> tokens = new ArrayList<>();
// 1. 添加[CLS]标记
tokens.add(CLS_TOKEN); // 101
// 2. 文本预处理
String normalizedText = text.trim();
// 3. 使用贪心最长匹配进行分词
int i = 0;
while (i < normalizedText.length()) {
boolean matched = false;
// 从最长可能的子串开始匹配(最多10个字符)
for (int j = Math.min(i + 10, normalizedText.length()); j > i; j--) {
String subStr = normalizedText.substring(i, j);
Integer tokenId = vocab.get(subStr);
if (tokenId != null) {
tokens.add(tokenId);
i = j;
matched = true;
break;
}
}
// 如果没有匹配到,尝试单字符匹配
if (!matched) {
char c = normalizedText.charAt(i);
String charStr = String.valueOf(c);
Integer tokenId = vocab.get(charStr);
if (tokenId != null) {
tokens.add(tokenId);
} else {
tokens.add(UNK_TOKEN); // 100,未知token
}
i++;
}
}
// 4. 添加[SEP]标记
tokens.add(SEP_TOKEN); // 102
// 5. 截断或填充到固定长度128
int maxLength = properties.getTokenizer().getMaxLength(); // 128
while (tokens.size() < maxLength) {
tokens.add(PAD_TOKEN); // 0
}
if (tokens.size() > maxLength) {
tokens = tokens.subList(0, maxLength);
tokens.set(maxLength - 1, SEP_TOKEN); // 确保最后是[SEP]
}
// 6. 转换为long数组
long[] result = tokens.stream().mapToLong(Integer::longValue).toArray();
log.debug("文本 '{}' tokenize结果前10个: {}", text,
Arrays.toString(Arrays.copyOf(result, Math.min(10, result.length))));
return result;
}
Tokenization示例
输入文本:"中国首都北京"
步骤1:添加[CLS] token
tokens = [101]
步骤2-3:贪心最长匹配分词
"中" -> token_id = 704 -> tokens = [101, 704]
"国" -> token_id = 1744 -> tokens = [101, 704, 1744]
"首" -> token_id = 7674 -> tokens = [101, 704, 1744, 7674]
"都" -> token_id = 7662 -> tokens = [101, 704, 1744, 7674, 7662]
"北" -> token_id = 738 -> tokens = [101, 704, 1744, 7674, 7662, 738]
"京" -> token_id = 776 -> tokens = [101, 704, 1744, 7674, 7662, 738, 776]
步骤4:添加[SEP] token
tokens = [101, 704, 1744, 7674, 7662, 738, 776, 102]
步骤5:填充到长度128
tokens = [101, 704, 1744, 7674, 7662, 738, 776, 102, 0, 0, ..., 0]
(长度: 128)
对应的attention_mask:
mask = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ..., 0]
(1表示真实token,0表示padding)
🚀 三、批量推理优化技术
3.1 为什么需要批量推理?
单个推理 vs 批量推理性能对比
场景:编码1000个文本
单个推理(batch_size=1):
├── 推理次数: 1000次
├── 总耗时: 1000 * 25ms = 25,000ms (25秒)
├── GPU利用率: ~15%(大量时间浪费在数据传输)
└── 内存拷贝: 1000次 CPU→GPU
批量推理(batch_size=128):
├── 推理次数: 8次 (1000/128=7.8,向上取整)
├── 总耗时: 8 * 150ms = 1,200ms (1.2秒)
├── GPU利用率: ~85%(充分利用GPU并行计算)
├── 内存拷贝: 8次 CPU→GPU
└── 性能提升: 25秒 → 1.2秒,提速 20.8倍
批量推理的核心优势
- 减少数据传输开销:CPU→GPU的数据传输是瓶颈
- 提高GPU利用率:GPU擅长并行计算,批量处理更高效
- 摊薄固定开销:模型加载、context切换等固定开销被分摊
3.2 批量编码实现详解
java
public Map<String, float[]> batchEncode(List<SimilarityItem> items) {
if (!modelLoaded) {
log.error("ONNX模型未加载,批量编码失败");
return createEmptyVectors(items);
}
Map<String, float[]> vectors = new ConcurrentHashMap<>();
// ========== 步骤1:检查缓存,分离已缓存和未缓存的项 ==========
List<SimilarityItem> uncachedItems = new ArrayList<>();
Map<String, String> itemTextMap = new HashMap<>();
for (SimilarityItem item : items) {
String text = item.getChineseName();
String cacheKey = generateCacheKey(text);
float[] cachedVector = vectorCache.getIfPresent(cacheKey);
if (cachedVector != null) {
// 缓存命中
vectors.put(item.getId(), cachedVector);
} else {
// 缓存未命中,需要编码
uncachedItems.add(item);
itemTextMap.put(item.getId(), text);
}
}
log.info("批量编码开始,总数:{},缓存命中:{},需要编码:{}",
items.size(), vectors.size(), uncachedItems.size());
if (uncachedItems.isEmpty()) {
log.debug("所有向量都已缓存,无需编码");
return vectors;
}
// ========== 步骤2:对未缓存的项进行批量编码 ==========
boolean batchEnabled = properties.getPerformance().getOnnxBatch().getEnabled();
int onnxBatchSize = batchEnabled ?
Math.min(
properties.getPerformance().getOnnxBatch().getSize(), // 配置的批量大小
uncachedItems.size()
) : 1;
// 分批处理
List<List<SimilarityItem>> batches =
CollectionUtils.partitionList(uncachedItems, onnxBatchSize);
for (List<SimilarityItem> batch : batches) {
try {
if (batchEnabled && batch.size() > 1) {
// 使用批量ONNX推理
Map<String, float[]> batchVectors = batchEncodeOnnx(batch);
vectors.putAll(batchVectors);
} else {
// 使用单个编码
for (SimilarityItem item : batch) {
try {
float[] vector = encode(item.getChineseName());
vectors.put(item.getId(), vector);
} catch (Exception ex) {
log.error("单个编码失败: {}", item.getChineseName(), ex);
vectors.put(item.getId(),
new float[properties.getModel().getVectorDimension()]);
}
}
}
} catch (Exception e) {
log.error("批量ONNX编码失败,回退到单个编码", e);
// 降级:回退到单个编码
for (SimilarityItem item : batch) {
try {
float[] vector = encode(item.getChineseName());
vectors.put(item.getId(), vector);
} catch (Exception ex) {
log.error("单个编码也失败: {}", item.getChineseName(), ex);
vectors.put(item.getId(),
new float[properties.getModel().getVectorDimension()]);
}
}
}
}
log.debug("批量编码完成,成功:{},缓存统计:{}",
vectors.size(), getCacheStats());
return vectors;
}
3.3 批量ONNX推理核心实现
java
private Map<String, float[]> batchEncodeOnnx(List<SimilarityItem> items)
throws Exception {
if (items.isEmpty()) {
return new HashMap<>();
}
log.debug("开始批量ONNX推理,批大小:{}", items.size());
// ========== 步骤1:批量tokenize所有文本 ==========
List<String> texts = items.stream()
.map(SimilarityItem::getChineseName)
.collect(Collectors.toList());
List<long[]> allInputIds = new ArrayList<>();
List<long[]> allAttentionMasks = new ArrayList<>();
List<long[]> allTokenTypeIds = new ArrayList<>();
for (String text : texts) {
long[] inputIds = tokenize(text);
long[] attentionMask = createAttentionMask(inputIds);
long[] tokenTypeIds = new long[inputIds.length]; // 全0
allInputIds.add(inputIds);
allAttentionMasks.add(attentionMask);
allTokenTypeIds.add(tokenTypeIds);
}
// ========== 步骤2:转换为二维数组 ==========
long[][] inputIdsArray = allInputIds.toArray(new long[0][]);
long[][] attentionMaskArray = allAttentionMasks.toArray(new long[0][]);
long[][] tokenTypeIdsArray = allTokenTypeIds.toArray(new long[0][]);
// 形状:[batch_size, 128]
// 例如:batch_size=128时,inputIdsArray.shape = [128, 128]
// ========== 步骤3:执行批量ONNX推理 ==========
Map<String, float[]> results = new HashMap<>();
synchronized (modelLock) {
try (OnnxTensor inputTensor = OnnxTensor.createTensor(
environment, inputIdsArray);
OnnxTensor attentionTensor = OnnxTensor.createTensor(
environment, attentionMaskArray);
OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(
environment, tokenTypeIdsArray)) {
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input_ids", inputTensor);
inputs.put("attention_mask", attentionTensor);
inputs.put("token_type_ids", tokenTypeTensor);
try (OrtSession.Result result = session.run(inputs)) {
// ========== 步骤4:处理批量输出 ==========
OnnxValue outputValue = result.get(0);
float[][][] batchOutput = (float[][][]) outputValue.getValue();
// 形状:[batch_size, 128, 768]
// ========== 步骤5:提取每个文本的向量并缓存 ==========
for (int i = 0; i < items.size(); i++) {
SimilarityItem item = items.get(i);
String text = texts.get(i);
// 提取[CLS] token的向量表示(第一个token)
float[] clsVector = batchOutput[i][0].clone();
// clsVector.shape = [768]
// 归一化向量(L2归一化)
normalize(clsVector);
// 存入缓存
String cacheKey = generateCacheKey(text);
vectorCache.put(cacheKey, clsVector);
results.put(item.getId(), clsVector);
}
}
}
}
log.debug("批量ONNX推理完成,处理了{}个文本", items.size());
return results;
}
批量推理内存布局
输入张量形状(batch_size=3的示例):
input_ids: [3, 128] int64
attention_mask: [3, 128] int64
token_type_ids: [3, 128] int64
输出张量形状:
output: [3, 128, 768] float32
内存布局可视化:
[
[ [0.12, 0.34, ..., 0.78], // 文本1, token 0 ([CLS])
[0.22, 0.44, ..., 0.88], // 文本1, token 1
...
[0.02, 0.04, ..., 0.08] // 文本1, token 127
],
[ [0.15, 0.37, ..., 0.79], // 文本2, token 0 ([CLS])
...
],
[ [0.18, 0.39, ..., 0.81], // 文本3, token 0 ([CLS])
...
]
]
提取结果:
文本1向量 = output[0][0] = [0.12, 0.34, ..., 0.78] (768维)
文本2向量 = output[1][0] = [0.15, 0.37, ..., 0.79] (768维)
文本3向量 = output[2][0] = [0.18, 0.39, ..., 0.81] (768维)
💾 四、双缓存策略设计
4.1 为什么需要双缓存?
业务场景分析
字典关系管理系统的向量使用模式:
类型1:系统预置数据(预加载缓存)
├── 标准字段名称(如:"客户名称"、"订单编号")
├── 词根向量(如:"名称"、"编号"、"日期")
├── 特点:数量固定、频繁使用、永不过期
└── 数量:约5000-10000个
类型2:用户查询数据(动态缓存)
├── 用户输入的搜索词(如:"客户姓名"、"订单号")
├── 临时生成的映射字段
├── 特点:数量不定、偶尔使用、定时过期
└── 数量:约1000-50000个
单缓存的问题
java
// ❌ 问题:使用单一缓存(有过期时间)
Cache<String, float[]> singleCache = Caffeine.newBuilder()
.maximumSize(50000)
.expireAfterWrite(1, TimeUnit.HOURS) // 1小时过期
.build();
// 后果:
// 1. 系统预置的向量也会过期 -> 重新计算浪费资源
// 2. 高频字段向量频繁失效 -> 性能下降
// 3. 无法区分重要数据和临时数据 -> 缓存效率低
双缓存解决方案
java
// ✅ 解决方案:双缓存架构
Cache<String, float[]> preloadVectorCache = Caffeine.newBuilder()
.maximumSize(50000)
.build(); // 永不过期
Cache<String, float[]> vectorCache = Caffeine.newBuilder()
.maximumSize(50000)
.expireAfterWrite(1, TimeUnit.HOURS) // 1小时过期
.build();
4.2 双缓存配置详解
配置文件
yaml
# application-similarity.yml
similarity:
cache:
# 动态缓存配置
enabled: true
max-size: 50000 # 最多缓存5万个向量
expire-minutes: 3600 # 60小时后过期(0=永不过期)
# 预加载缓存配置
preload:
never-expire: true # 永不过期
max-size: 50000 # 最多缓存5万个向量
key-prefix: "preload:" # 缓存键前缀(用于区分)
缓存Bean配置
java
@Configuration
public class CacheConfig {
@Bean
@Qualifier("vectorCache")
public Cache<String, float[]> vectorCache(SimilarityProperties properties) {
Caffeine<Object, Object> builder = Caffeine.newBuilder()
.maximumSize(properties.getCache().getMaxSize())
.recordStats(); // 启用统计
// 根据配置决定是否过期
Integer expireMinutes = properties.getCache().getExpireMinutes();
if (expireMinutes != null && expireMinutes > 0) {
builder.expireAfterWrite(expireMinutes, TimeUnit.MINUTES);
}
return builder.build();
}
@Bean
@Qualifier("preloadVectorCache")
public Cache<String, float[]> preloadVectorCache(
SimilarityProperties properties) {
// 预加载缓存:永不过期
return Caffeine.newBuilder()
.maximumSize(properties.getCache().getPreload().getMaxSize())
.recordStats()
.build();
}
}
4.3 双缓存使用策略
查询流程(encode方法)
java
public float[] encode(String text) throws Exception {
String cacheKey = generateCacheKey(text);
// ========== 策略1:优先查询动态缓存 ==========
float[] cachedVector = vectorCache.getIfPresent(cacheKey);
if (cachedVector != null) {
log.debug("使用缓存向量: {} (动态缓存命中)", text);
return cachedVector;
}
// ========== 策略2:降级查询预加载缓存 ==========
float[] preloadedVector = preloadVectorCache.getIfPresent(cacheKey);
if (preloadedVector != null) {
log.info("使用预加载缓存向量: {} (预加载缓存命中)", text);
// 同步到动态缓存(热点数据提升)
vectorCache.put(cacheKey, preloadedVector);
return preloadedVector;
}
// ========== 策略3:缓存未命中,执行推理 ==========
log.debug("编码新文本: {} (缓存未命中)", text);
float[] vector = performEncoding(text);
// 存入动态缓存
vectorCache.put(cacheKey, vector);
return vector;
}
预加载流程(preloadEncode方法)
java
public float[] preloadEncode(String text) throws Exception {
String cacheKey = generateCacheKey(text);
// ========== 只查询预加载缓存 ==========
float[] cachedVector = preloadVectorCache.getIfPresent(cacheKey);
if (cachedVector != null) {
log.debug("使用预加载缓存向量: {} (预加载缓存命中)", text);
return cachedVector;
}
// ========== 缓存未命中,执行推理 ==========
log.debug("预加载编码新文本: {} (预加载缓存未命中)", text);
float[] vector = performEncoding(text);
// ========== 存入预加载缓存(永不过期) ==========
preloadVectorCache.put(cacheKey, vector);
log.debug("预加载向量编码完成并缓存: {}", text);
return vector;
}
4.4 缓存统计与监控
java
public Map<String, Object> getCacheStats() {
Map<String, Object> stats = new HashMap<>();
// ========== 动态缓存统计 ==========
if (vectorCache != null) {
CacheStats dynamicStats = vectorCache.stats();
Map<String, Object> dynamicCacheStats = new HashMap<>();
dynamicCacheStats.put("hitCount", dynamicStats.hitCount());
dynamicCacheStats.put("missCount", dynamicStats.missCount());
dynamicCacheStats.put("hitRate", dynamicStats.hitRate());
dynamicCacheStats.put("evictionCount", dynamicStats.evictionCount());
dynamicCacheStats.put("estimatedSize", vectorCache.estimatedSize());
stats.put("dynamicCache", dynamicCacheStats);
}
// ========== 预加载缓存统计 ==========
if (preloadVectorCache != null) {
CacheStats preloadStats = preloadVectorCache.stats();
Map<String, Object> preloadCacheStats = new HashMap<>();
preloadCacheStats.put("hitCount", preloadStats.hitCount());
preloadCacheStats.put("missCount", preloadStats.missCount());
preloadCacheStats.put("hitRate", preloadStats.hitRate());
preloadCacheStats.put("evictionCount", preloadStats.evictionCount());
preloadCacheStats.put("estimatedSize", preloadVectorCache.estimatedSize());
stats.put("preloadCache", preloadCacheStats);
}
// ========== 总计统计 ==========
stats.put("totalSize", getTotalCacheSize());
stats.put("status", "双缓存策略已启用");
return stats;
}
缓存统计示例输出
json
{
"dynamicCache": {
"hitCount": 15234,
"missCount": 3421,
"hitRate": 0.816,
"evictionCount": 124,
"estimatedSize": 8523
},
"preloadCache": {
"hitCount": 45678,
"missCount": 0,
"hitRate": 1.0,
"evictionCount": 0,
"estimatedSize": 5240
},
"totalSize": 13763,
"status": "双缓存策略已启用"
}
缓存性能指标
| 指标 | 计算公式 | 理想值 |
|---|---|---|
| 命中率 | hitCount / (hitCount + missCount) | > 0.80 |
| 驱逐次数 | evictionCount | < 100 |
| 缓存大小 | estimatedSize | < maxSize |