model.onnx 深度分析报告(第2篇)

model.onnx在语义匹配系统中的应用实践


📋 文档概览

本文档是 model.onnx 深度分析系列的第二篇,将深入解析实际项目中如何使用ONNX模型进行语义匹配。

文档目标

  • ✅ 深入理解VectorEncoder类的完整实现
  • ✅ 掌握批量推理的性能优化技巧
  • ✅ 学习GPU/CUDA加速配置方法
  • ✅ 理解双缓存策略的设计思路
  • ✅ 掌握ONNX Runtime的实战技巧

🏗️ 一、系统架构全景图

1.1 语义匹配系统整体架构

配置层
数据层
AI推理层
应用层
用户层
前端界面
API请求
Controller层
Service层
SimilarityService
VectorEncoder
ONNX Runtime
model.onnx
HNSW索引
向量数据库
双缓存系统
Caffeine缓存
application-similarity.yml


1.2 核心组件职责

组件 职责 关键类
VectorEncoder 文本向量化编码 VectorEncoder.java
ONNX Runtime 模型推理执行 OrtSession, OrtEnvironment
Tokenizer 文本预处理分词 内置在VectorEncoder中
Cache System 向量缓存管理 Caffeine双缓存
HNSW Index 高效向量检索 HNSW算法实现
Configuration 系统配置管理 SimilarityProperties

💻 二、VectorEncoder核心实现解析

2.1 类结构与生命周期管理

完整的类依赖关系
java 复制代码
@Component
@Slf4j
public class VectorEncoder {
    // ========== 依赖注入 ==========
    @Autowired
    private SimilarityProperties properties;

    @Autowired
    @Qualifier("vectorCache")           // 动态缓存(用户查询)
    private Cache<String, float[]> vectorCache;

    @Autowired
    @Qualifier("preloadVectorCache")    // 预加载缓存(系统字段)
    private Cache<String, float[]> preloadVectorCache;

    @Autowired
    private ApplicationEventPublisher eventPublisher;

    // ========== ONNX核心组件 ==========
    private OrtEnvironment environment;    // ONNX运行环境(单例)
    private OrtSession session;            // ONNX会话(模型实例)
    private boolean modelLoaded = false;   // 模型加载状态
    private final Object modelLock = new Object();  // 推理同步锁

    // ========== 模型就绪管理 ==========
    private final AtomicBoolean modelReady = new AtomicBoolean(false);
    private final CountDownLatch modelReadyLatch = new CountDownLatch(1);

    // ========== 词汇表 ==========
    private final Map<String, Integer> vocab = new ConcurrentHashMap<>();

    // ========== BERT特殊Token ==========
    private final int CLS_TOKEN = 101;  // [CLS]
    private final int SEP_TOKEN = 102;  // [SEP]
    private final int PAD_TOKEN = 0;    // [PAD]
    private final int UNK_TOKEN = 100;  // [UNK]
}

2.2 分阶段初始化设计(性能优化关键)

设计思路

传统的单阶段初始化会导致应用启动缓慢(加载409MB模型需要5-7秒),影响用户体验。因此采用两阶段异步初始化
缓存系统 ONNX Runtime VectorEncoder Spring容器 缓存系统 ONNX Runtime VectorEncoder Spring容器 阶段1:应用启动时(同步,快速) 阶段2:应用启动后(异步,不阻塞) @PostConstruct init() 创建OrtEnvironment 加载词汇表 (vocab.txt) 初始化完成 (耗时: 100-200ms) @EventListener ApplicationReadyEvent 加载model.onnx (409MB) 模型加载完成 (耗时: 5-7秒) 设置modelReady=true 模型预热 (推理5个样本) 填充预加载缓存 发布VectorModelReadyEvent

实现代码详解

阶段1:轻量级初始化(@PostConstruct)

java 复制代码
@PostConstruct
public void init() {
    try {
        log.info("开始初始化向量编码器(阶段1:词汇表加载)...");

        // 1. 创建ONNX Runtime环境(单例,轻量级)
        // 设置日志级别为ERROR,避免控制台乱码
        environment = OrtEnvironment.getEnvironment(
            OrtLoggingLevel.ORT_LOGGING_LEVEL_ERROR
        );

        // 2. 加载词汇表(vocab.txt,约21KB)
        loadVocabulary();

        log.info("词汇表加载完成,词汇量:{},等待应用启动后加载ONNX模型",
            vocab.size());

    } catch (Exception e) {
        log.error("词汇表加载失败", e);
        throw new RuntimeException("词汇表加载失败", e);
    }
}

阶段2:重量级模型加载(@EventListener + @Async)

java 复制代码
@EventListener(ApplicationReadyEvent.class)
@Async("similarityExecutor")  // 异步执行,不阻塞主线程
public void initModelAsync() {
    try {
        long startTime = System.currentTimeMillis();
        log.info("开始异步加载ONNX模型(阶段2:模型加载与预热)...");

        // 1. 加载ONNX模型(409MB,耗时5-7秒)
        loadModel();

        // 2. 立即标记模型就绪(此时已可用)
        modelLoaded = true;
        modelReady.set(true);
        modelReadyLatch.countDown();
        log.info("模型加载完成,已标记为就绪状态");

        // 3. 模型预热(推理几个样本,加速后续推理)
        warmupModel();

        long duration = System.currentTimeMillis() - startTime;
        log.info("ONNX模型异步加载完成,耗时:{}ms,模型维度:{}",
            duration, properties.getModel().getVectorDimension());

        // 4. 发布模型就绪事件(触发索引构建等后续任务)
        VectorModelReadyEvent event = new VectorModelReadyEvent(
            this,
            properties.getModel().getVectorDimension(),
            vocab.size(),
            duration
        );
        eventPublisher.publishEvent(event);
        log.info("已发布向量模型就绪事件,触发索引构建");

    } catch (Exception e) {
        log.error("ONNX模型异步加载失败", e);
        // 不抛异常,允许应用继续运行(但相似度功能不可用)
    }
}
性能对比
初始化策略 启动时间 功能可用时间 用户体验
单阶段同步 8-10秒 8-10秒 ❌ 差(等待时间长)
两阶段异步 1-2秒 7-9秒(后台加载) ✅ 好(快速启动)

2.3 模型加载详细流程

关键代码解析
java 复制代码
private void loadModel() throws OrtException, IOException {
    String modelPath = properties.getModel().getActualPath();
    log.info("加载ONNX模型:{}", modelPath);

    try {
        // ========== 步骤1:加载模型文件 ==========
        byte[] modelBytes;
        if (modelPath.startsWith("classpath:")) {
            // 从classpath加载(JAR包内资源)
            String resourcePath = modelPath.substring("classpath:".length());
            ClassPathResource resource = new ClassPathResource(resourcePath);
            modelBytes = readStreamToByteArray(resource.getInputStream());
        } else if (modelPath.startsWith("/")) {
            // 从绝对路径加载(转为classpath路径)
            String resourcePath = modelPath.substring(1); // 去掉开头的/
            ClassPathResource resource = new ClassPathResource(resourcePath);
            if (resource.exists()) {
                modelBytes = readStreamToByteArray(resource.getInputStream());
            } else {
                throw new FileNotFoundException("模型文件不存在: " + resourcePath);
            }
        } else {
            // 从文件系统加载
            File modelFile = new File(modelPath);
            if (!modelFile.exists()) {
                throw new FileNotFoundException("模型文件不存在: " + modelPath);
            }
            try (FileInputStream fis = new FileInputStream(modelFile)) {
                modelBytes = readStreamToByteArray(fis);
            }
        }

        // ========== 步骤2:创建ONNX会话选项 ==========
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();

        // 设置优化级别
        sessionOptions.setOptimizationLevel(
            OrtSession.SessionOptions.OptLevel.BASIC_OPT
        );

        // ========== 步骤3:配置硬件加速(GPU/CPU) ==========
        configureGpuAcceleration(sessionOptions);

        // ========== 步骤4:创建ONNX会话(加载模型) ==========
        session = environment.createSession(modelBytes, sessionOptions);
        modelLoaded = true;

        log.info("ONNX模型加载成功,大小:{}MB,路径:{}",
            modelBytes.length / 1024 / 1024, modelPath);

    } catch (IOException e) {
        throw SimilarityException.modelLoadError(modelPath, e);
    }
}
内存优化技巧
java 复制代码
// ❌ 错误做法:直接读取文件到内存
File file = new File("model.onnx");
byte[] bytes = Files.readAllBytes(file.toPath());  // 一次性读取409MB

// ✅ 正确做法:流式读取,避免内存峰值
private byte[] readStreamToByteArray(InputStream inputStream) throws IOException {
    try (ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
        byte[] data = new byte[8192];  // 8KB缓冲区
        int nRead;
        while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
            buffer.write(data, 0, nRead);
        }
        return buffer.toByteArray();
    }
}

2.4 GPU加速配置详解

GPU加速架构
复制代码
ONNX Runtime GPU加速栈
│
├── 应用层:VectorEncoder
│   └── configureGpuAcceleration()
│
├── ONNX Runtime层
│   ├── CUDA Execution Provider
│   ├── cuDNN算子库
│   └── cuBLAS线性代数库
│
├── CUDA层
│   ├── CUDA Toolkit (12.x)
│   ├── CUDA Runtime
│   └── CUDA Driver
│
└── 硬件层
    └── NVIDIA GPU (算力 >= 3.5)
完整的GPU配置实现
java 复制代码
private void configureGpuAcceleration(OrtSession.SessionOptions sessionOptions)
    throws OrtException {

    SimilarityProperties.Performance.Gpu gpuConfig =
        properties.getPerformance().getGpu();

    // ========== 判断1:用户是否启用GPU ==========
    if (!gpuConfig.getEnabled()) {
        log.info("GPU加速未启用,使用CPU执行");
        configureCpuExecution(sessionOptions);
        return;
    }

    // ========== 判断2:检测CUDA环境兼容性 ==========
    log.info("GPU加速已启用,开始环境检测...");
    CudaCompatibilityResult compatibilityResult = checkCudaCompatibility();

    if (!compatibilityResult.isCompatible()) {
        // CUDA环境不兼容,记录警告
        log.warn("CUDA环境检测发现问题: {}",
            compatibilityResult.getFailureReason());
        log.warn("将尝试配置GPU,如果失败会抛出异常");
    } else {
        log.info("CUDA环境检测通过");
    }

    // ========== 判断3:尝试配置CUDA执行提供程序 ==========
    boolean gpuConfigured = tryConfigureGpuAcceleration(
        sessionOptions,
        gpuConfig,
        compatibilityResult.getEnvironmentInfo()
    );

    if (!gpuConfigured) {
        // GPU配置失败,抛出详细错误信息
        String errorMessage = "GPU加速已启用但配置失败";
        log.error(errorMessage);
        log.error("错误详情: CUDA执行提供程序配置失败");
        log.error("可能的解决方案:");
        log.error("  1. 在 NVIDIA 官网下载与 CUDA 12.x 匹配的 cuDNN 9.x");
        log.error("  2. 将 cuDNN DLL 文件复制到 CUDA bin 目录");
        log.error("  3. 将 CUDA bin 目录添加到系统 PATH 环境变量");
        log.error("  4. 重启 IDE、终端或服务");
        throw new SimilarityException("GPU_CONFIGURATION_FAILED",
            "GPU配置失败,请检查CUDA和cuDNN环境");
    }

    log.info("GPU加速配置成功,已启用CUDA执行");
}
CUDA环境检测实现
java 复制代码
private CudaCompatibilityResult checkCudaCompatibility() {
    log.info("=== 开始CUDA环境兼容性检测 ===");

    try {
        // 1. 检测CUDA环境信息
        CudaEnvironmentInfo envInfo = detectCudaEnvironment();
        logCudaEnvironmentStatus(envInfo);

        // 2. 检查GPU可用性
        if (!envInfo.isGpuAvailable()) {
            return CudaCompatibilityResult.failure("未检测到NVIDIA GPU", envInfo);
        }

        // 3. 检查CUDA Toolkit
        if (!envInfo.isCudaToolkitAvailable()) {
            return CudaCompatibilityResult.failure(
                "未检测到CUDA Toolkit,请安装: " +
                "https://developer.nvidia.com/cuda-downloads",
                envInfo
            );
        }

        // 4. 检查ONNX Runtime GPU支持
        if (!checkOnnxRuntimeGpuSupport()) {
            return CudaCompatibilityResult.failure(
                "ONNX Runtime不支持GPU或addCUDA方法不可用",
                envInfo
            );
        }

        log.info("CUDA环境兼容性检测通过");
        return CudaCompatibilityResult.success(envInfo);

    } catch (Exception e) {
        log.error("CUDA兼容性检测过程中发生异常: {}", e.getMessage());
        return CudaCompatibilityResult.failure(
            "检测过程异常: " + e.getMessage(),
            new CudaEnvironmentInfo()
        );
    }
}
GPU信息检测示例
java 复制代码
private CudaEnvironmentInfo detectCudaEnvironment() {
    CudaEnvironmentInfo info = new CudaEnvironmentInfo();

    // 检测NVIDIA GPU(通过nvidia-smi命令)
    try {
        Process process = Runtime.getRuntime().exec(
            "nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits"
        );
        BufferedReader reader = new BufferedReader(
            new InputStreamReader(process.getInputStream())
        );
        String line = reader.readLine();

        if (line != null && !line.trim().isEmpty()) {
            String[] parts = line.split(",");
            if (parts.length >= 2) {
                info.setGpuAvailable(true);
                info.setGpuName(parts[0].trim());
                info.setGpuMemoryMB(Integer.parseInt(parts[1].trim()));
            }
        }
        process.waitFor();
    } catch (Exception e) {
        log.debug("检测NVIDIA GPU失败: {}", e.getMessage());
    }

    // 检测CUDA Toolkit版本(通过nvcc命令)
    try {
        Process process = Runtime.getRuntime().exec("nvcc --version");
        BufferedReader reader = new BufferedReader(
            new InputStreamReader(process.getInputStream())
        );
        String line;
        while ((line = reader.readLine()) != null) {
            if (line.contains("release")) {
                // 解析版本号:Cuda compilation tools, release 12.6, V12.6.85
                String[] parts = line.split(",");
                for (String part : parts) {
                    if (part.trim().startsWith("release")) {
                        info.setCudaToolkitAvailable(true);
                        info.setCudaVersion(part.trim().replace("release ", ""));
                        break;
                    }
                }
                break;
            }
        }
        process.waitFor();
    } catch (Exception e) {
        log.debug("检测CUDA Toolkit失败: {}", e.getMessage());
    }

    return info;
}

2.5 核心编码流程:从文本到向量

完整的编码流程
java 复制代码
public float[] encode(String text) throws Exception {
    // ========== 步骤0:等待模型就绪 ==========
    if (!modelReady.get()) {
        log.warn("ONNX模型尚未就绪,等待异步加载完成...");
        boolean ready = modelReadyLatch.await(60, TimeUnit.SECONDS);
        if (!ready) {
            throw new IllegalStateException("ONNX模型加载超时(等待60秒)");
        }
        log.info("ONNX模型已就绪,继续执行编码");
    }

    // ========== 步骤1:处理空文本 ==========
    if (text == null || text.trim().isEmpty()) {
        return new float[properties.getModel().getVectorDimension()]; // 零向量
    }

    // ========== 步骤2:检查缓存 ==========
    String cacheKey = generateCacheKey(text);  // MD5哈希
    float[] cachedVector = vectorCache.getIfPresent(cacheKey);
    if (cachedVector != null) {
        log.debug("使用缓存向量: {} (缓存命中)", text);
        return cachedVector;
    }

    // ========== 步骤3:尝试命中预加载缓存 ==========
    float[] preloadedVector = preloadVectorCache.getIfPresent(cacheKey);
    if (preloadedVector != null) {
        log.info("使用预加载缓存向量: {}", text);
        vectorCache.put(cacheKey, preloadedVector);  // 同步到动态缓存
        return preloadedVector;
    }

    log.debug("编码新文本: {} (缓存未命中)", text);

    // ========== 步骤4:执行ONNX推理 ==========
    float[] vector = performEncoding(text);

    // ========== 步骤5:存入缓存 ==========
    vectorCache.put(cacheKey, vector);

    log.debug("向量编码完成并缓存: {}", text);
    return vector;
}
performEncoding详细实现
java 复制代码
private float[] performEncoding(String text) throws Exception {
    // ========== 步骤1:Tokenize文本 ==========
    long[] inputIds = tokenize(text);              // [128]
    long[] attentionMask = createAttentionMask(inputIds);  // [128]
    long[] tokenTypeIds = new long[inputIds.length];       // [128],全0

    // ========== 步骤2:创建ONNX输入张量 ==========
    synchronized (modelLock) {
        Map<String, OnnxTensor> inputs = new HashMap<>();

        try (OnnxTensor inputTensor = OnnxTensor.createTensor(
                environment, new long[][]{inputIds});
             OnnxTensor attentionTensor = OnnxTensor.createTensor(
                environment, new long[][]{attentionMask});
             OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(
                environment, new long[][]{tokenTypeIds})) {

            inputs.put("input_ids", inputTensor);
            inputs.put("attention_mask", attentionTensor);
            inputs.put("token_type_ids", tokenTypeTensor);

            // ========== 步骤3:执行ONNX推理 ==========
            try (OrtSession.Result result = session.run(inputs)) {
                return processModelOutputOnly(result, text);
            }
        }
    }
}
Tokenization实现(完整的BERT分词)
java 复制代码
private long[] tokenize(String text) {
    List<Integer> tokens = new ArrayList<>();

    // 1. 添加[CLS]标记
    tokens.add(CLS_TOKEN);  // 101

    // 2. 文本预处理
    String normalizedText = text.trim();

    // 3. 使用贪心最长匹配进行分词
    int i = 0;
    while (i < normalizedText.length()) {
        boolean matched = false;

        // 从最长可能的子串开始匹配(最多10个字符)
        for (int j = Math.min(i + 10, normalizedText.length()); j > i; j--) {
            String subStr = normalizedText.substring(i, j);
            Integer tokenId = vocab.get(subStr);

            if (tokenId != null) {
                tokens.add(tokenId);
                i = j;
                matched = true;
                break;
            }
        }

        // 如果没有匹配到,尝试单字符匹配
        if (!matched) {
            char c = normalizedText.charAt(i);
            String charStr = String.valueOf(c);
            Integer tokenId = vocab.get(charStr);

            if (tokenId != null) {
                tokens.add(tokenId);
            } else {
                tokens.add(UNK_TOKEN);  // 100,未知token
            }
            i++;
        }
    }

    // 4. 添加[SEP]标记
    tokens.add(SEP_TOKEN);  // 102

    // 5. 截断或填充到固定长度128
    int maxLength = properties.getTokenizer().getMaxLength();  // 128
    while (tokens.size() < maxLength) {
        tokens.add(PAD_TOKEN);  // 0
    }
    if (tokens.size() > maxLength) {
        tokens = tokens.subList(0, maxLength);
        tokens.set(maxLength - 1, SEP_TOKEN);  // 确保最后是[SEP]
    }

    // 6. 转换为long数组
    long[] result = tokens.stream().mapToLong(Integer::longValue).toArray();

    log.debug("文本 '{}' tokenize结果前10个: {}", text,
        Arrays.toString(Arrays.copyOf(result, Math.min(10, result.length))));

    return result;
}
Tokenization示例
复制代码
输入文本:"中国首都北京"

步骤1:添加[CLS] token
tokens = [101]

步骤2-3:贪心最长匹配分词
"中"     -> token_id = 704   -> tokens = [101, 704]
"国"     -> token_id = 1744  -> tokens = [101, 704, 1744]
"首"     -> token_id = 7674  -> tokens = [101, 704, 1744, 7674]
"都"     -> token_id = 7662  -> tokens = [101, 704, 1744, 7674, 7662]
"北"     -> token_id = 738   -> tokens = [101, 704, 1744, 7674, 7662, 738]
"京"     -> token_id = 776   -> tokens = [101, 704, 1744, 7674, 7662, 738, 776]

步骤4:添加[SEP] token
tokens = [101, 704, 1744, 7674, 7662, 738, 776, 102]

步骤5:填充到长度128
tokens = [101, 704, 1744, 7674, 7662, 738, 776, 102, 0, 0, ..., 0]
         (长度: 128)

对应的attention_mask:
mask = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ..., 0]
       (1表示真实token,0表示padding)

🚀 三、批量推理优化技术

3.1 为什么需要批量推理?

单个推理 vs 批量推理性能对比
复制代码
场景:编码1000个文本

单个推理(batch_size=1):
├── 推理次数: 1000次
├── 总耗时: 1000 * 25ms = 25,000ms (25秒)
├── GPU利用率: ~15%(大量时间浪费在数据传输)
└── 内存拷贝: 1000次 CPU→GPU

批量推理(batch_size=128):
├── 推理次数: 8次 (1000/128=7.8,向上取整)
├── 总耗时: 8 * 150ms = 1,200ms (1.2秒)
├── GPU利用率: ~85%(充分利用GPU并行计算)
├── 内存拷贝: 8次 CPU→GPU
└── 性能提升: 25秒 → 1.2秒,提速 20.8倍
批量推理的核心优势
  1. 减少数据传输开销:CPU→GPU的数据传输是瓶颈
  2. 提高GPU利用率:GPU擅长并行计算,批量处理更高效
  3. 摊薄固定开销:模型加载、context切换等固定开销被分摊

3.2 批量编码实现详解

java 复制代码
public Map<String, float[]> batchEncode(List<SimilarityItem> items) {
    if (!modelLoaded) {
        log.error("ONNX模型未加载,批量编码失败");
        return createEmptyVectors(items);
    }

    Map<String, float[]> vectors = new ConcurrentHashMap<>();

    // ========== 步骤1:检查缓存,分离已缓存和未缓存的项 ==========
    List<SimilarityItem> uncachedItems = new ArrayList<>();
    Map<String, String> itemTextMap = new HashMap<>();

    for (SimilarityItem item : items) {
        String text = item.getChineseName();
        String cacheKey = generateCacheKey(text);
        float[] cachedVector = vectorCache.getIfPresent(cacheKey);

        if (cachedVector != null) {
            // 缓存命中
            vectors.put(item.getId(), cachedVector);
        } else {
            // 缓存未命中,需要编码
            uncachedItems.add(item);
            itemTextMap.put(item.getId(), text);
        }
    }

    log.info("批量编码开始,总数:{},缓存命中:{},需要编码:{}",
        items.size(), vectors.size(), uncachedItems.size());

    if (uncachedItems.isEmpty()) {
        log.debug("所有向量都已缓存,无需编码");
        return vectors;
    }

    // ========== 步骤2:对未缓存的项进行批量编码 ==========
    boolean batchEnabled = properties.getPerformance().getOnnxBatch().getEnabled();
    int onnxBatchSize = batchEnabled ?
        Math.min(
            properties.getPerformance().getOnnxBatch().getSize(),  // 配置的批量大小
            uncachedItems.size()
        ) : 1;

    // 分批处理
    List<List<SimilarityItem>> batches =
        CollectionUtils.partitionList(uncachedItems, onnxBatchSize);

    for (List<SimilarityItem> batch : batches) {
        try {
            if (batchEnabled && batch.size() > 1) {
                // 使用批量ONNX推理
                Map<String, float[]> batchVectors = batchEncodeOnnx(batch);
                vectors.putAll(batchVectors);
            } else {
                // 使用单个编码
                for (SimilarityItem item : batch) {
                    try {
                        float[] vector = encode(item.getChineseName());
                        vectors.put(item.getId(), vector);
                    } catch (Exception ex) {
                        log.error("单个编码失败: {}", item.getChineseName(), ex);
                        vectors.put(item.getId(),
                            new float[properties.getModel().getVectorDimension()]);
                    }
                }
            }
        } catch (Exception e) {
            log.error("批量ONNX编码失败,回退到单个编码", e);
            // 降级:回退到单个编码
            for (SimilarityItem item : batch) {
                try {
                    float[] vector = encode(item.getChineseName());
                    vectors.put(item.getId(), vector);
                } catch (Exception ex) {
                    log.error("单个编码也失败: {}", item.getChineseName(), ex);
                    vectors.put(item.getId(),
                        new float[properties.getModel().getVectorDimension()]);
                }
            }
        }
    }

    log.debug("批量编码完成,成功:{},缓存统计:{}",
        vectors.size(), getCacheStats());
    return vectors;
}

3.3 批量ONNX推理核心实现

java 复制代码
private Map<String, float[]> batchEncodeOnnx(List<SimilarityItem> items)
    throws Exception {

    if (items.isEmpty()) {
        return new HashMap<>();
    }

    log.debug("开始批量ONNX推理,批大小:{}", items.size());

    // ========== 步骤1:批量tokenize所有文本 ==========
    List<String> texts = items.stream()
        .map(SimilarityItem::getChineseName)
        .collect(Collectors.toList());

    List<long[]> allInputIds = new ArrayList<>();
    List<long[]> allAttentionMasks = new ArrayList<>();
    List<long[]> allTokenTypeIds = new ArrayList<>();

    for (String text : texts) {
        long[] inputIds = tokenize(text);
        long[] attentionMask = createAttentionMask(inputIds);
        long[] tokenTypeIds = new long[inputIds.length];  // 全0

        allInputIds.add(inputIds);
        allAttentionMasks.add(attentionMask);
        allTokenTypeIds.add(tokenTypeIds);
    }

    // ========== 步骤2:转换为二维数组 ==========
    long[][] inputIdsArray = allInputIds.toArray(new long[0][]);
    long[][] attentionMaskArray = allAttentionMasks.toArray(new long[0][]);
    long[][] tokenTypeIdsArray = allTokenTypeIds.toArray(new long[0][]);

    // 形状:[batch_size, 128]
    // 例如:batch_size=128时,inputIdsArray.shape = [128, 128]

    // ========== 步骤3:执行批量ONNX推理 ==========
    Map<String, float[]> results = new HashMap<>();

    synchronized (modelLock) {
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(
                environment, inputIdsArray);
             OnnxTensor attentionTensor = OnnxTensor.createTensor(
                environment, attentionMaskArray);
             OnnxTensor tokenTypeTensor = OnnxTensor.createTensor(
                environment, tokenTypeIdsArray)) {

            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put("input_ids", inputTensor);
            inputs.put("attention_mask", attentionTensor);
            inputs.put("token_type_ids", tokenTypeTensor);

            try (OrtSession.Result result = session.run(inputs)) {
                // ========== 步骤4:处理批量输出 ==========
                OnnxValue outputValue = result.get(0);
                float[][][] batchOutput = (float[][][]) outputValue.getValue();
                // 形状:[batch_size, 128, 768]

                // ========== 步骤5:提取每个文本的向量并缓存 ==========
                for (int i = 0; i < items.size(); i++) {
                    SimilarityItem item = items.get(i);
                    String text = texts.get(i);

                    // 提取[CLS] token的向量表示(第一个token)
                    float[] clsVector = batchOutput[i][0].clone();
                    // clsVector.shape = [768]

                    // 归一化向量(L2归一化)
                    normalize(clsVector);

                    // 存入缓存
                    String cacheKey = generateCacheKey(text);
                    vectorCache.put(cacheKey, clsVector);

                    results.put(item.getId(), clsVector);
                }
            }
        }
    }

    log.debug("批量ONNX推理完成,处理了{}个文本", items.size());
    return results;
}
批量推理内存布局
复制代码
输入张量形状(batch_size=3的示例):
input_ids:       [3, 128] int64
attention_mask:  [3, 128] int64
token_type_ids:  [3, 128] int64

输出张量形状:
output: [3, 128, 768] float32

内存布局可视化:
[
  [ [0.12, 0.34, ..., 0.78],  // 文本1, token 0 ([CLS])
    [0.22, 0.44, ..., 0.88],  // 文本1, token 1
    ...
    [0.02, 0.04, ..., 0.08]   // 文本1, token 127
  ],
  [ [0.15, 0.37, ..., 0.79],  // 文本2, token 0 ([CLS])
    ...
  ],
  [ [0.18, 0.39, ..., 0.81],  // 文本3, token 0 ([CLS])
    ...
  ]
]

提取结果:
文本1向量 = output[0][0] = [0.12, 0.34, ..., 0.78]  (768维)
文本2向量 = output[1][0] = [0.15, 0.37, ..., 0.79]  (768维)
文本3向量 = output[2][0] = [0.18, 0.39, ..., 0.81]  (768维)

💾 四、双缓存策略设计

4.1 为什么需要双缓存?

业务场景分析
复制代码
字典关系管理系统的向量使用模式:

类型1:系统预置数据(预加载缓存)
├── 标准字段名称(如:"客户名称"、"订单编号")
├── 词根向量(如:"名称"、"编号"、"日期")
├── 特点:数量固定、频繁使用、永不过期
└── 数量:约5000-10000个

类型2:用户查询数据(动态缓存)
├── 用户输入的搜索词(如:"客户姓名"、"订单号")
├── 临时生成的映射字段
├── 特点:数量不定、偶尔使用、定时过期
└── 数量:约1000-50000个
单缓存的问题
java 复制代码
// ❌ 问题:使用单一缓存(有过期时间)
Cache<String, float[]> singleCache = Caffeine.newBuilder()
    .maximumSize(50000)
    .expireAfterWrite(1, TimeUnit.HOURS)  // 1小时过期
    .build();

// 后果:
// 1. 系统预置的向量也会过期 -> 重新计算浪费资源
// 2. 高频字段向量频繁失效 -> 性能下降
// 3. 无法区分重要数据和临时数据 -> 缓存效率低
双缓存解决方案
java 复制代码
// ✅ 解决方案:双缓存架构
Cache<String, float[]> preloadVectorCache = Caffeine.newBuilder()
    .maximumSize(50000)
    .build();  // 永不过期

Cache<String, float[]> vectorCache = Caffeine.newBuilder()
    .maximumSize(50000)
    .expireAfterWrite(1, TimeUnit.HOURS)  // 1小时过期
    .build();

4.2 双缓存配置详解

配置文件
yaml 复制代码
# application-similarity.yml
similarity:
  cache:
    # 动态缓存配置
    enabled: true
    max-size: 50000              # 最多缓存5万个向量
    expire-minutes: 3600         # 60小时后过期(0=永不过期)

    # 预加载缓存配置
    preload:
      never-expire: true         # 永不过期
      max-size: 50000            # 最多缓存5万个向量
      key-prefix: "preload:"     # 缓存键前缀(用于区分)
缓存Bean配置
java 复制代码
@Configuration
public class CacheConfig {

    @Bean
    @Qualifier("vectorCache")
    public Cache<String, float[]> vectorCache(SimilarityProperties properties) {
        Caffeine<Object, Object> builder = Caffeine.newBuilder()
            .maximumSize(properties.getCache().getMaxSize())
            .recordStats();  // 启用统计

        // 根据配置决定是否过期
        Integer expireMinutes = properties.getCache().getExpireMinutes();
        if (expireMinutes != null && expireMinutes > 0) {
            builder.expireAfterWrite(expireMinutes, TimeUnit.MINUTES);
        }

        return builder.build();
    }

    @Bean
    @Qualifier("preloadVectorCache")
    public Cache<String, float[]> preloadVectorCache(
        SimilarityProperties properties) {

        // 预加载缓存:永不过期
        return Caffeine.newBuilder()
            .maximumSize(properties.getCache().getPreload().getMaxSize())
            .recordStats()
            .build();
    }
}

4.3 双缓存使用策略

查询流程(encode方法)
java 复制代码
public float[] encode(String text) throws Exception {
    String cacheKey = generateCacheKey(text);

    // ========== 策略1:优先查询动态缓存 ==========
    float[] cachedVector = vectorCache.getIfPresent(cacheKey);
    if (cachedVector != null) {
        log.debug("使用缓存向量: {} (动态缓存命中)", text);
        return cachedVector;
    }

    // ========== 策略2:降级查询预加载缓存 ==========
    float[] preloadedVector = preloadVectorCache.getIfPresent(cacheKey);
    if (preloadedVector != null) {
        log.info("使用预加载缓存向量: {} (预加载缓存命中)", text);
        // 同步到动态缓存(热点数据提升)
        vectorCache.put(cacheKey, preloadedVector);
        return preloadedVector;
    }

    // ========== 策略3:缓存未命中,执行推理 ==========
    log.debug("编码新文本: {} (缓存未命中)", text);
    float[] vector = performEncoding(text);

    // 存入动态缓存
    vectorCache.put(cacheKey, vector);
    return vector;
}
预加载流程(preloadEncode方法)
java 复制代码
public float[] preloadEncode(String text) throws Exception {
    String cacheKey = generateCacheKey(text);

    // ========== 只查询预加载缓存 ==========
    float[] cachedVector = preloadVectorCache.getIfPresent(cacheKey);
    if (cachedVector != null) {
        log.debug("使用预加载缓存向量: {} (预加载缓存命中)", text);
        return cachedVector;
    }

    // ========== 缓存未命中,执行推理 ==========
    log.debug("预加载编码新文本: {} (预加载缓存未命中)", text);
    float[] vector = performEncoding(text);

    // ========== 存入预加载缓存(永不过期) ==========
    preloadVectorCache.put(cacheKey, vector);
    log.debug("预加载向量编码完成并缓存: {}", text);
    return vector;
}

4.4 缓存统计与监控

java 复制代码
public Map<String, Object> getCacheStats() {
    Map<String, Object> stats = new HashMap<>();

    // ========== 动态缓存统计 ==========
    if (vectorCache != null) {
        CacheStats dynamicStats = vectorCache.stats();
        Map<String, Object> dynamicCacheStats = new HashMap<>();
        dynamicCacheStats.put("hitCount", dynamicStats.hitCount());
        dynamicCacheStats.put("missCount", dynamicStats.missCount());
        dynamicCacheStats.put("hitRate", dynamicStats.hitRate());
        dynamicCacheStats.put("evictionCount", dynamicStats.evictionCount());
        dynamicCacheStats.put("estimatedSize", vectorCache.estimatedSize());
        stats.put("dynamicCache", dynamicCacheStats);
    }

    // ========== 预加载缓存统计 ==========
    if (preloadVectorCache != null) {
        CacheStats preloadStats = preloadVectorCache.stats();
        Map<String, Object> preloadCacheStats = new HashMap<>();
        preloadCacheStats.put("hitCount", preloadStats.hitCount());
        preloadCacheStats.put("missCount", preloadStats.missCount());
        preloadCacheStats.put("hitRate", preloadStats.hitRate());
        preloadCacheStats.put("evictionCount", preloadStats.evictionCount());
        preloadCacheStats.put("estimatedSize", preloadVectorCache.estimatedSize());
        stats.put("preloadCache", preloadCacheStats);
    }

    // ========== 总计统计 ==========
    stats.put("totalSize", getTotalCacheSize());
    stats.put("status", "双缓存策略已启用");

    return stats;
}
缓存统计示例输出
json 复制代码
{
  "dynamicCache": {
    "hitCount": 15234,
    "missCount": 3421,
    "hitRate": 0.816,
    "evictionCount": 124,
    "estimatedSize": 8523
  },
  "preloadCache": {
    "hitCount": 45678,
    "missCount": 0,
    "hitRate": 1.0,
    "evictionCount": 0,
    "estimatedSize": 5240
  },
  "totalSize": 13763,
  "status": "双缓存策略已启用"
}
缓存性能指标
指标 计算公式 理想值
命中率 hitCount / (hitCount + missCount) > 0.80
驱逐次数 evictionCount < 100
缓存大小 estimatedSize < maxSize
相关推荐
川西胖墩墩16 小时前
团队协作泳道图制作工具 PC中文免费
大数据·论文阅读·人工智能·架构·流程图
Codebee16 小时前
ooder SkillFlow:破解 AI 编程冲击,重构企业级开发全流程
人工智能
TOPGUS16 小时前
黑帽GEO手法揭秘:AI搜索阴影下的新型搜索劫持与风险
人工智能·搜索引擎·chatgpt·aigc·谷歌·数字营销
Sammyyyyy16 小时前
Symfony AI 正式发布,PHP 原生 AI 时代开启
开发语言·人工智能·后端·php·symfony·servbay
汽车仪器仪表相关领域16 小时前
光轴精准测量,安全照明保障——NHD-8101/8000型远近光检测仪项目实战分享
数据库·人工智能·安全·压力测试·可用性测试
WJSKad123516 小时前
基于yolov5-RepNCSPELAN的商品价格标签识别系统实现
人工智能·yolo·目标跟踪
早日退休!!!16 小时前
现代公司开发AI编译器的多元技术路线(非LLVM方向全解析)
人工智能
Sahadev_16 小时前
向量搜索:让电脑理解你的搜索意图
人工智能
大模型真好玩16 小时前
大模型训练全流程实战指南(一)——为什么要学习大模型训练?
人工智能·pytorch·python·大模型·deep learning