从零开发短视频电商 PaddleOCR Java推理 (二)优化Translator模型输入和输出

PaddleOCR提供了一系列测试图片,你可以通过点击这里来下载。

值得注意的是,PaddleOCR的模型更新速度远远快于DJL,这导致了一些DJL的优化滞后问题。因此,我们需要采取一些策略来跟上PaddleOCR的最新进展。

针对文本识别模型,你可以参考以下资源:

请注意,PP-OCRv4的识别模型使用的输入形状为3,48,320。此外,PP-OCRv4的识别模型默认使用的rec_algorithmSVTR_LCNet,需要留意其与原始SVTR的区别。

默认的识别模型算法可以在这里找到。

模型的输入和输出在PpWordRecognitionTranslator.java中。

java 复制代码
Criteria<Image, String> criteria = Criteria.builder()
                .optEngine("PaddlePaddle")
                .setTypes(Image.class, String.class)
                .optModelPath(Paths.get("C:\\laker-1"))
                .optTranslator(new PpWordRecognitionTranslator())
                .build();
java 复制代码
public class PpWordRecognitionTranslator implements NoBatchifyTranslator<Image, String> {

    private List<String> table;

    /** 
     * 准备方法,用于加载模型所需的数据。
     * 
     * @param ctx 翻译上下文
     * @throws IOException 如果读取数据时发生错误
     */
    @Override
    public void prepare(TranslatorContext ctx) throws IOException {
        try (InputStream is = ctx.getModel().getArtifact("ppocr_keys_v1.txt").openStream()) {
            // 从文本文件中读取表格数据
            table = Utils.readLines(is, true);
            // 在表格开头添加"blank"
            table.add(0, "blank");
            // 在表格末尾添加空字符串
            table.add("");
        }
    }

    /** 
     * 处理输出方法,将模型的输出转换为字符串。
     * 
     * @param ctx 翻译上下文
     * @param list 模型的输出列表
     * @return 转换后的字符串
     */
    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        StringBuilder sb = new StringBuilder();
        NDArray tokens = list.singletonOrThrow();
        long[] indices = tokens.get(0).argMax(1).toLongArray();
        int lastIdx = 0;
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] > 0 && !(i > 0 && indices[i] == lastIdx)) {
                // 将索引映射为相应的字符串并添加到结果字符串中
                sb.append(table.get((int) indices[i]));
            }
        }
        return sb.toString();
    }

    /** 
     * 处理输入方法,将图像数据转换为模型可接受的格式。
     * 
     * @param ctx 翻译上下文
     * @param input 输入图像
     * @return 转换后的NDList对象
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray img = input.toNDArray(ctx.getNDManager());
        int[] hw = resize32(input.getWidth());
        // 调整图像大小、转换为张量并归一化
        img = NDImageUtils.resize(img, hw[1], hw[0]);
        img = NDImageUtils.toTensor(img).sub(0.5f).div(0.5f);
        // 在第一个维度上添加一个维度,通常用于将单个图像添加到批处理中
        img = img.expandDims(0);
        return new NDList(img);
    }

    private int[] resize32(double w) {
        // Paddle不依赖于宽高比
        // 计算新的图像宽度,确保它是32的倍数
        int width = ((int) Math.max(32, w)) / 32 * 32;
        return new int[]{32, width};
    }
}

这里有很多过时的了。

  • 输入尺寸最新为3,48,320
  • 输出的置信度没输出。

我们就来优化这2点。

1.修改模型加载这个地方

java 复制代码
Criteria<Image, String> criteria = Criteria.builder()
                .optEngine("PaddlePaddle")
                .setTypes(Image.class, String.class)
                .optModelPath(Paths.get("C:\\laker-1"))
                .optTranslator(new PpWordRecognitionTranslator2())
                .build();

2.新增PpWordRecognitionTranslator2类

java 复制代码
    /** 
     * 处理输出方法,将模型的输出转换为字符串。
     * 
     * @param ctx 翻译上下文
     * @param list 模型的输出列表
     * @return 转换后的字符串
     * @throws IOException 如果处理输出时发生错误
     */
    @Override
    public String processOutput(TranslatorContext ctx, NDList list) throws IOException {
        StringBuilder sb = new StringBuilder();
        NDArray tokens = list.singletonOrThrow();
        System.out.println("输出:" + tokens);
        
        // 计算出每行中最大值的索引位置 ND: (20, 97) 即 20行 97列,97列
        // 97列是初始化字典的行数,所以肯定是97列。
        // 20为图片中可能的字符数
        long[] indices = tokens.get(0).argMax(1).toLongArray();
        
        // 字符置信度
        float[] probs = new float[indices.length];

        for (int row = 0; row < indices.length; row++) {
            long dictIndex = indices[row];
            if (dictIndex > 0) { // 剔除 blank
                float[] v = tokens.get(0).get(row).toFloatArray();
                NDArray value = tokens.get(0).get(new NDIndex("" + row + ":" + (row + 1) + "," + dictIndex + ":" + (dictIndex + 1)));
                probs[row] = value.toFloatArray()[0];
                System.out.println(table.get((int) dictIndex) + " " + probs[row]);
                // 剔除置信度小于 0.6的
                if (probs[row] < 0.6f) {
                    continue;
                }
                sb.append(table.get((int) dictIndex));
            }
        }

        return sb.toString();
    }

    /** 
     * 处理输入方法,将图像数据转换为模型可接受的格式。
     * 
     * @param ctx 翻译上下文
     * @param input 输入图像
     * @return 转换后的NDList对象
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDArray img = input.toNDArray(ctx.getNDManager());
        System.out.println(img);
        int[] hw = resize48(input.getWidth(), input.getHeight());
        img = NDImageUtils.resize(img, hw[1], hw[0]);
        // 将图像转换为张量,并进行归一化操作,减去0.5并除以0.5。
        img = NDImageUtils.toTensor(img).sub(0.5f).div(0.5f);
        // 在张量的第一个维度上添加一个维度,通常用于将单个图像添加到批处理中。
        img = img.expandDims(0);
        System.out.println("输入 :" + img);
        return new NDList(img);
    }

    /** 
     * 计算新的图像宽度和高度,确保宽度不超过48,并按比例调整高度。
     * 
     * @param w 图像原始宽度
     * @param h 图像原始高度
     * @return 包含新的宽度和高度的整数数组
     */
    private int[] resize48(double w, double h) {
        double maxWhRatio = w / h;
        int imgW = (int) (48 * maxWhRatio);
        // 检查按比例调整高度后是否超过了目标宽度
        int resizedW = (int) Math.ceil(48 * maxWhRatio);
        return new int[]{48, resizedW};
    }

输出为

ini 复制代码
// 原始图片 高 39,宽105 
: (39, 105, 3) cpu() uint8
[ Exceed max print size ]
// resize后的 高 48,宽130
输入 :ND: (1, 3, 48, 130) cpu() float32
[ Exceed max print size ]
// 字符识别个数为16个,字符是97个的英文字典
输出:softmax_2.tmp_0: (1, 16, 97) cpu() float32
[ Exceed max print size ]
[ 0.932977
1 0.99981683
+ 0.99966896
1 0.99980944
= 0.9967675
2 0.9998343
] 0.9975802

[1+1=2]
相关推荐
JH30739 小时前
SpringBoot 优雅处理金额格式化:拦截器+自定义注解方案
java·spring boot·spring
Coder_Boy_10 小时前
技术让开发更轻松的底层矛盾
java·大数据·数据库·人工智能·深度学习
2401_8362358610 小时前
中安未来SDK15:以AI之眼,解锁企业档案的数字化基因
人工智能·科技·深度学习·ocr·生活
invicinble11 小时前
对tomcat的提供的功能与底层拓扑结构与实现机制的理解
java·tomcat
较真的菜鸟11 小时前
使用ASM和agent监控属性变化
java
黎雁·泠崖11 小时前
【魔法森林冒险】5/14 Allen类(三):任务进度与状态管理
java·开发语言
qq_124987075312 小时前
基于SSM的动物保护系统的设计与实现(源码+论文+部署+安装)
java·数据库·spring boot·毕业设计·ssm·计算机毕业设计
Coder_Boy_12 小时前
基于SpringAI的在线考试系统-考试系统开发流程案例
java·数据库·人工智能·spring boot·后端
Mr_sun.12 小时前
Day06——权限认证-项目集成
java
瑶山12 小时前
Spring Cloud微服务搭建四、集成RocketMQ消息队列
java·spring cloud·微服务·rocketmq·dashboard