OCR Fusion: EasyOCR/Tesseract/PaddleOCR/TrOCR/GOT

文章目录


前言

OCROptical Character Recognition,光学字符识别)是指对包含文本内容的图像或视频进行处理和识别,并提取其中所包含的的文字及排版信息的过程(摘自维基百科)。根据其应用场景可分为印刷文本识别、手写文本识别、公式文本识别、场景文本识别以及古籍文本识别。

举一个实用的例子:想阅读一本电子书,但该书是扫描版的 PDF 文档,具有文件体积大、文字不可选、无法编辑和可读性差的缺点;我们可以借助OCR将文档识别并转换成轻量的 EPUB 格式,并提升阅读体验。有意义的应用场景还有很多,此处不一一列举。

最近由于实际需求,对之前和时下流行的OCR工具进行了一些货比三家式的接触和使用,尤其是近期(2024年9月)刚出端到端的GOT-OCR2.0,效果惊艳。遂决定在此记录,内容包括EasyOCRTesseractPaddleOCR及其PyTorch移植版、单行手写文本识别TrOCR,以及GOT,对应的repo地址为https://github.com/DaiHaoguang3151/ocr_fusion


一、基类 OCRExecutorBase

为了统一所有这些OCR的使用,我在repo的src/ocr_executor/ocr_executor_base.py文件中定义了基类OCRExecutorBase,代码如下所示。

1)在初始化self.__init__()部分,主要是用来完成模型加载等工作;

2)对图片进行批量的OCR识别则是通过self.execute()方法,其输入paths_images则是批量(image_path, opencv_image)对,这么设计是因为有些OCR工具的输入既可以是图片的路径,也可以是图片本身,比如OpenCV或者PIL读取的图片;

3)self._generate()self.execute()的核心部分,就是直接调用OCR的地方。

python 复制代码
# src/ocr_executor/ocr_executor_base.py
import math
from typing import List

import numpy as np


class OCRExecutorBase:
    def __init__(self):
        pass

    def _generate(self, images: List[np.ndarray]) -> List:
        """
        批量生成
        """
        raise NotImplementedError

    def execute(self, paths_images: List, batch_size=16):
        """
        执行ocr
        """
        results = []
        num = len(paths_images)
        paths = [ele[0] for ele in paths_images]
        # print("IMAGE PATHS: ", paths)
        images = [ele[1] for ele in paths_images]
        iterations = math.ceil(num / batch_size)
        for iter in range(iterations):
            batch_images = images[iter * batch_size: min((iter + 1) * batch_size, num)]
            batch_paths = paths[iter * batch_size: min((iter + 1) * batch_size, num)]

            batch_results = self._generate(batch_images)
            results += batch_results
        print("GENERATED: ", results)
        return results

二、EasyOCR

1.安装

EasyOCR是个流行的轻量化OCR工具,底层依赖于PyTorch,所以需要一起安装。

bash 复制代码
# eg: CUDA 11.7
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117

pip install easyocr

2.模型下载

EasyOCR会使用到文本检测模型文本识别模型,在首次使用时会自动下载相应的模型,但是由于网络原因,很可能报错,需要手动下载模型,参考这篇博客,具体步骤如下:

1)首先找到模型下载存放的路径,默认是~/.EasyOCR;

2)去modelhub中下载权重文件,解压后放置到上述模型存放路径;文本检测模型是CRAFT,文本识别模型是2rd Generation Models下方的english_g2(因为我们选择的语言是英文)。

3.DEMO

EasyOCRExecutor部分代码如下,1)在初始化阶段使用easyocr.Reader加载模型,["en"]表示选择英文,download_enabled=False则表示使用已经下载到本地的模型,而不再去下载;2)重写self._generate()方法,通过self.reader.readtext(image, detail=1)即可获取识别的文本、文本包围盒以及得分。

python 复制代码
# src/ocr_executor/easyocr_executor.py
import math
from typing import List

import numpy as np
import easyocr

from ocr_executor.ocr_executor_base import OCRExecutorBase
from util.util import save_detection

class EasyOCRExecutor(OCRExecutorBase):
    def __init__(self):
        super(EasyOCRExecutor, self).__init__()
        self.reader = easyocr.Reader(["en"], download_enabled=False)

    def _generate(self, images: List[np.ndarray]) -> List:
        """
        批量生成
        """
        results = []
        for image in images:
            # 可以传图片或者文件路径,detail=1返回检测结果
            result = self.reader.readtext(image, detail=1)
            # for detection in result:
            #     bbox, text, score = detection
            #     print(f"Text: {text}, BBox: {bbox}, Score: {score}")
            results.append(result)
        return results

我在src/main.py中完成了所有OCR识别样例图片的脚本,现在就来看一下使用EasyOCRExecutor的demo:

python 复制代码
import os

import numpy as np
import cv2

from ocr_executor.easyocr_executor import EasyOCRExecutor
from util.util import save_detection   # 用于绘制检测和识别结果的

# 构造输入
image_path = "/home/ubuntu/Projects_ubuntu/ocr_fusion/src/images/handwriting.png"
image = cv2.imread(image_path)
paths_images = [(image_path, image)]

output_dir = "./images_output"
if not os.path.exists(output_dir):
    os.makedirs(output_dir, exist_ok=True)

# easyocr
easyocr_executor = EasyOCRExecutor()
result = easyocr_executor.execute(paths_images)[0]   # 取第一个结果,对应第一张图片

# 简单的输出处理,方便喂给save_detection绘制图片
bboxes = []
texts = []
for idx, detection in enumerate(result):
    bbox, text, score = detection
    bbox = [[int(pt[0]), int(pt[1])] for pt in bbox]
    bbox = np.array(bbox).astype(np.int32).reshape(-1, 2)
    bboxes.append(bbox)
    texts.append(text)
save_detection(os.path.join(output_dir, "easyocr.png"), image, bboxes, texts=texts, poly=True)

结果如下,左边是原图,蓝色框是EasyOCR的文本检测框,右边是每个框中识别的文字,效果一般,对于这种简单场景还是会有一些识别错误。

三、Tesseract

1.安装

想要使用Tesseract,需要先安装tesseract-ocr,然后安装pytesseract这个Python包。我尝试了两种安装方式,针对使用conda虚拟环境的情况,我推荐第一种。

  1. 方式1:使用conda命令同时安装tesseractpytesseract,比如我使用环境是python3.7;
bash 复制代码
# 安装参考链接为:
# https://anaconda.org/conda-forge/tesseract
# https://anaconda.org/conda-forge/pytesseract

# 1) 安装:使用conda-forge源,有点慢
conda install conda-forge::tesseract pytesseract

# 2) 测试tesseract
# conda activate your_env
tesseract --version
# tesseract显示的版本有可能是5.3.0,也有可能是4.1.1,问题不大
# tesseract 4.1.1
#  leptonica-1.80.0
#   libgif 5.2.1 : libjpeg 9e : libpng 1.6.39 : libtiff 4.4.0 : zlib 1.2.13 : libwebp 1.2.4 : libopenjp2 2.4.0
#  Found AVX2
#  Found AVX
#  Found FMA
#  Found SSE
#  Found libarchive 3.4.0 zlib/1.2.11 liblzma/5.2.4 bz2lib/1.0.8 liblz4/1.9.2 libzstd/1.4.4

# 查找可执行文件路径
which tesseract
# /home/ubuntu/anaconda3/envs/ocr_env/bin/tesseract

# 3) 测试pytesseract
# conda list找到了pytesseract-0.3.10

# 检查是否安装成功
python
# >>> import pytesseract
# >>> print(pytesseract.get_tesseract_version())   # 4.1.1
  1. 方式2:先在系统层级安装tesseract-ocr,然后在虚拟环境中安装pytesseract;这种安装方式的缺点是没有进行完全的环境隔离,在有些使用场景下,可能会遇到加载动态链接库报错的问题。
bash 复制代码
# 1) 安装tesseract-ocr
sudo apt-get update
sudo apt-get install tesseract-ocr

# 配置环境变量
vim ~/.bashrc
  # 在文件最后添加:
  export PATH=$PATH:/usr/local/bin

# 重新加载文件
source ~/.bashrc

# 验证配置
echo $PATH
# /home/ubuntu/anaconda3/bin:/home/ubuntu/anaconda3/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/usr/local/cuda/bin:/usr/local/bin

# 2) 验证 Tesseract 安装
tesseract --version
# tesseract 4.1.1
#  leptonica-1.79.0
#   libgif 5.1.4 : libjpeg 8d (libjpeg-turbo 2.0.3) : libpng 1.6.37 : libtiff 4.1.0 : zlib 1.2.11 : libwebp 0.6.1 : libopenjp2 2.3.1
#  Found AVX2
#  Found AVX
#  Found FMA
#  Found SSE
#  Found libarchive 3.4.0 zlib/1.2.11 liblzma/5.2.4 bz2lib/1.0.8 liblz4/1.9.2 libzstd/1.4.4

# 查找可执行文件路径
which tesseract
# /usr/bin/tesseract  --> 有点奇怪,PATH=$PATH:/usr/local/bin这个设置好像是无效的,实际路径在/usr/bin/

# 3) 安装pytesseract
pip install pytesseract   # 0.3.10

2.使用问题

当我们开始使用pytesseract时有可能遇到如下报错:

python 复制代码
from PIL import Image

image = Image.open("/home/ubuntu/Projects_ubuntu/TrOCR/images/cropped_image/17.jpg")
printed_text = pytesseract.image_to_string(image, config="--psm 7")

# 报错:
# Traceback (most recent call last):
#   File "/home/ubuntu/Projects_ubuntu/TrOCR/tesseract_ocr.py", line 24, in <module>
#     printed_text = pytesseract.image_to_string(image, config="--psm 7")
#   File "/home/ubuntu/anaconda3/envs/fasttext/lib/python3.7/site-packages/pytesseract/pytesseract.py", line 427, in image_to_string
#     }[output_type]()
#   File "/home/ubuntu/anaconda3/envs/fasttext/lib/python3.7/site-packages/pytesseract/pytesseract.py", line 426, in <lambda>
#     Output.STRING: lambda: run_and_get_output(*args),
#   File "/home/ubuntu/anaconda3/envs/fasttext/lib/python3.7/site-packages/pytesseract/pytesseract.py", line 288, in run_and_get_output
#     run_tesseract(**kwargs)
#   File "/home/ubuntu/anaconda3/envs/fasttext/lib/python3.7/site-packages/pytesseract/pytesseract.py", line 264, in run_tesseract
#     raise TesseractError(proc.returncode, get_errors(error_string))
# pytesseract.pytesseract.TesseractError: (1, 'Error opening data file /home/ubuntu/anaconda3/envs/ocr_env/share/tessdata Please make sure the TESSDATA_PREFIX environment variable is set to your "tessdata" directory. Failed loading language \'eng\' Tesseract couldn\'t load any languages! Could not initialize tesseract.')

查看最后一行报错,实际上是找不到数据,奇怪的是:

1)在/home/ubuntu/anaconda3/envs/ocr_env/share/tessdata文件夹下能找到eng.traineddata

2)/home/ubuntu/anaconda3/envs/ocr_env/share/tessdata并不是报错中所说的文件,而是文件夹。

解决方式是按照提示修改环境变量:

python 复制代码
# python中
import os
os.environ['TESSDATA_PREFIX'] = '/home/ubuntu/anaconda3/envs/fasttext/share/tessdata/'

# 或者bash中
export TESSDATA_PREFIX=/home/ubuntu/anaconda3/envs/fasttext/share/tessdata/

3.DEMO

Tesseract是支持多种形式的OCR,1)比如pytesseract.image_to_string(image, config="--psm 7")是只识别文字,参数--psm可以指定识别模式,比如"--psm 7"表示单行文本识别,可以自己查一下;2)pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)则以字典的形式输出所有信息,包括识别的层次结构级别level(比如level=5可以过滤出单词级别的结果)、包围盒以及置信度conf等等。

python 复制代码
# src/ocr_executor/tesseractocr_executor_base.py
from typing import List

import numpy as np
import pytesseract

from ocr_executor.ocr_executor_base import OCRExecutorBase

class TesseractOCRExecutor(OCRExecutorBase):
    def __init__(self):
        super(TesseractOCRExecutor, self).__init__()
    
    def _generate(self, images: np.ndarray) -> List:
        """
        根据输入图片批量生成文字
        """
        generated = []
        # for image in images:
        #     result = pytesseract.image_to_string(image, config="--psm 7")   # --psm 7 表示单行识别
        #     generated.append(result)

        for image in images:
            data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
            
            result = []
            num = len(data["level"])
            for i in range(num):
                if not data["level"][i] == 5:
                    continue
                (x, y, w, h) = (data['left'][i], data['top'][i], data['width'][i], data['height'][i])
                box = [[x, y], [x + w, y + h]]
                text = data["text"][i]
                conf = data["conf"][i]
                result.append((box, text, conf))

            generated.append(result)
        
        return generated

TesseractOCRExecutor识别demo如下:

python 复制代码
from ocr_executor.tesseractocr_executor import TesseractOCRExecutor

# 数据准备和前面是一样,省略

# tesseractocr
tesseractocr_executor = TesseractOCRExecutor()
result = tesseractocr_executor.execute(paths_images)[0]
bboxes = []
texts = []
for idx, detection in enumerate(result):
    bbox, text, conf = detection
    bboxes.append(bbox)
    texts.append(text)
save_detection(os.path.join(output_dir, "tesseractocr.png"), image, bboxes, texts=texts, poly=False)

结果如下,这里只展示了单词级别的包围盒,你可以根据需求选择你需要的level

python 复制代码
# level 字段可能的值及其含义
# 1: 表示整个页面(Page)。
# 2: 表示块(Block),通常是一个独立的区域,如段落或图像。
# 3: 表示段落(Paragraph)。
# 4: 表示行(Line),即文本行。
# 5: 表示单词(Word),即单个单词。
# 6: 表示字符(Symbol),即单个字符。

四、PaddleOCR

PaddleOCR底层依赖于PaddlePaddle,但是它的安装可能会有点麻烦,并且对于CUDA版本的支持和其他框架比如PyTorch不是很同步。如果对这方面有些头疼的朋友,可以选择PaddleOCRPyTorch移植版,以便快速体验。

1.安装

代码如下(示例):

  1. PaddlePaddle安装(我参考了这篇文章):
    1)使用conda安装paddlepaddle-gpu==2.5.2
bash 复制代码
conda install paddlepaddle-gpu==2.5.2 cudatoolkit=11.7 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/Paddle/ -c conda-forge 

在安装过程中,你会看到提示信息,会下载和安装cudatoolkitcudnn,我们来查看一下具体安装结果:

bash 复制代码
(ocr_env) ubuntu@ubuntu:~$ conda list cudatoolkit
# packages in environment at /home/ubuntu/anaconda3/envs/ocr_env:
#
# Name                    Version                   Build  Channel
cudatoolkit               11.7.1              h4bc3d14_13    conda-forge
(ocr_env) ubuntu@ubuntu:~$ conda list cudnn
# packages in environment at /home/ubuntu/anaconda3/envs/ocr_env:
#
# Name                    Version                   Build  Channel
cudnn                     8.4.1.50             hed8a83a_0    conda-forge

2)按照paddlepaddle官方文档,检查是否安装成功:

bash 复制代码
ubuntu@ubuntu:~/anaconda3/envs/ocr_env/lib$ python
>>> import paddle
>>> paddle.utils.run_check()

# 报错日志:
    # PreconditionNotMetError: Cannot load cudnn shared library. Cannot invoke method cudnnGetVersion.
    #   [Hint: cudnn_dso_handle should not be null.] (at ../paddle/phi/backends/dynload/cudnn.cc:64)
    #   [operator < fill_constant > error]

报错显示,找不到cudnn相关的shared library

3)查看shared library中有没有libcudnn.solibcublas.so

  • 使用命令ls /usr/lib |grep lib查看,发现上述文件不存在;
bash 复制代码
(ocr_env) ubuntu@ubuntu:/usr/lib$ ls /usr/lib |grep lib
klibc
klibc-abS-oVB3xeRN8SFypUWbQvR33nc.so
libGL.so.1
libpsm1
libreoffice
  • 使用命令find $CONDA_PREFIX/lib -name "libcublas.so"手动查找cudatoolkit库文件,这是在当前conda环境中查找
bash 复制代码
(ocr_env) ubuntu@ubuntu:~$ find $CONDA_PREFIX/lib -name "libcublas.so"
/home/ubuntu/anaconda3/envs/ocr_env/lib/libcublas.so   # 找到了文件路径

4)根据这篇博客,创建软链接:

bash 复制代码
(ocr_env) ubuntu@ubuntu:~/anaconda3/envs/ocr_env/lib$ cd /usr/lib
(ocr_env) ubuntu@ubuntu:/usr/lib$ sudo ln -s /home/ubuntu/anaconda3/envs/ocr_env/lib/libcudnn.so.8.4.1 libcudnn.so
(ocr_env) ubuntu@ubuntu:/usr/lib$ sudo ln -s /home/ubuntu/anaconda3/envs/ocr_env/lib/libcublas.so.11.10.3.66 libcublas.so

然后检测相关的lib是否存在:发现已存在

bash 复制代码
(ocr_env) ubuntu@ubuntu:/usr/lib$ ls /usr/lib |grep lib
klibc
klibc-abS-oVB3xeRN8SFypUWbQvR33nc.so
libcublas.so
libcudnn.so
libGL.so.1
libpsm1
libreoffice

5)重复2),再次确认是否安装成功,这次报错如下:

bash 复制代码
(ocr_env) ubuntu@ubuntu:~/anaconda3/envs/ocr_env/lib$ python
>>> import paddle
>>> paddle.utils.run_check()

# 报错
# Running verify PaddlePaddle program ... 
# I0813 17:02:59.401835  6514 interpretercore.cc:237] New Executor is Running.
# W0813 17:02:59.401952  6514 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 8.6, Driver API Version: 12.2, Runtime API Version: 11.7
# W0813 17:02:59.421336  6514 gpu_resources.cc:149] device: 0, cuDNN Version: 8.4.
# python: symbol lookup error: /usr/local/cuda-11.0/targets/x86_64-linux/lib/libcublas.so: undefined symbol: runGemmShortApi, version libcublasLt.so.11

发现找的是/usr/local/cuda-11.0,这应该是我之前安装的系统级的cudatoolkit,说明paddlepaddle-gpu没有去找虚拟环境的cudatoolkit

6)设置环境变量:export LD_LIBRARY_PATH=/home/ubuntu/anaconda3/envs/ocr_env/lib:$LD_LIBRARY_PATH

bash 复制代码
(ocr_env) ubuntu@ubuntu:/usr/local/cuda-11.0/targets/x86_64-linux/lib$ echo $LD_LIBRARY_PATH

(ocr_env) ubuntu@ubuntu:/usr/local/cuda-11.0/targets/x86_64-linux/lib$ conda list paddlepaddle-gpu
# packages in environment at /home/ubuntu/anaconda3/envs/ocr_env:
#
# Name                    Version                   Build  Channel
paddlepaddle-gpu          2.5.2.post117            pypi_0    pypi
(ocr_env) ubuntu@ubuntu:/usr/local/cuda-11.0/targets/x86_64-linux/lib$ export LD_LIBRARY_PATH=/home/ubuntu/anaconda3/envs/ocr_env/lib:$LD_LIBRARY_PATH
(ocr_env) ubuntu@ubuntu:/usr/local/cuda-11.0/targets/x86_64-linux/lib$ echo $LD_LIBRARY_PATH
/home/ubuntu/anaconda3/envs/ocr_env/lib:


# 方法2:永久设置(对所有终端会话有效,如果有其它虚拟环境,不是很推荐):
# echo 'export LD_LIBRARY_PATH=/path/to/library:$LD_LIBRARY_PATH' >> ~/.bashrc
# source ~/.bashrc

7)重复2),这次检查是安装成功的:

bash 复制代码
>>> import paddle
>>> paddle.utils.run_check()
# Running verify PaddlePaddle program ... 
# I0813 17:19:18.161489  8275 interpretercore.cc:237] New Executor is Running.
# W0813 17:19:18.161605  8275 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 8.6, Driver API Version: 12.2, Runtime API Version: 11.7
# W0813 17:19:18.163659  8275 gpu_resources.cc:149] device: 0, cuDNN Version: 8.4.
# I0813 17:19:18.836459  8275 interpreter_util.cc:518] Standalone Executor is Used.
# PaddlePaddle works well on 1 GPU.
  1. PaddleOCR安装:
bash 复制代码
pip install paddleocr   # version = 2.7.0.2

2.DEMO

PaddleOCRExecutor代码如下,1)初始化阶段设置了语言,以及是否识别(可以只检测不识别);2)PaddleOCR自带的draw_ocr可以直接绘制识别结果。

python 复制代码
# src/ocr_executor/paddleocr_executor_base.py
from typing import Dict, List

import numpy as np
import cv2
from PIL import ImageFont
from paddleocr import PaddleOCR, draw_ocr

from ocr_executor.ocr_executor_base import OCRExecutorBase

class PaddleOCRExecutor(OCRExecutorBase):
    def __init__(self, lang: str = "en", rec: bool = True):
        super(PaddleOCRExecutor, self).__init__()
        # 识别语言
        self.lang: str = lang
        # 是否识别
        self.rec = rec
        
        self._init()

    def _init(self):
        # 初始化
        self.model = PaddleOCR(use_angle_cls=True, 
                               lang=self.lang, 
                               det=True, 
                               rec=self.rec)

    def _generate(self, images: np.ndarray) -> List:
        """
        根据输入图片生成文字(可选)和返回相应的box
        images可以传入路径
        """
        # 暂时不能批量处理
        generated = []
        for image in images:
            # result = self.model.ocr(image_path, det=True, rec=self.rec, cls=True)[0]
            if not isinstance(image, np.ndarray):
                image_ = np.array(image)   # TODO: 最好看一下通道顺序
                print("image_.shape: ", image_.shape)
            else:
                image_ = image
            result = self.model.ocr(image_, det=True, rec=self.rec, cls=True)[0]
            # self._draw_single_result(image_, result)

            generated.append(result)
        return generated
    
    def _draw_single_result(self, image, result):
        """
        绘制一张图片的检测和识别结果
        """
        # for line in result:
        #     print(line)
        boxes = [line[0] for line in result]
        txts = [line[1][0] for line in result]
        scores = [line[1][1] for line in result]

        # font = ImageFont.load_default()
        im_show = draw_ocr(image, boxes, txts, scores, font_path="/usr/share/fonts/truetype/ttf-khmeros-core/KhmerOS.ttf")
        cv2.imwrite("./result.jpg", im_show)

PaddleOCRExecutor识别demo如下:

python 复制代码
from ocr_executor.paddleocr_executor import PaddleOCRExecutor

# paddleocr
paddleocr_executor = PaddleOCRExecutor()
result = paddleocr_executor.execute(paths_images)[0]

bboxes = []
texts = []
for idx, detection in enumerate(result):
    bbox, (text, score) = detection
    bbox = np.array(bbox).astype(np.int32).reshape(-1, 2)
    bboxes.append(bbox)
    texts.append(text)
save_detection(os.path.join(output_dir, "paddleocr.png"), image, bboxes, texts=texts, poly=True)

结果如下,检测和识别效果都是不错的,具体来说就是检测包围盒刚好是我们所期望的样子,颗粒度刚好,同时识别也没有出错;个人经验是其检测模型可靠性挺强,可以和其他识别模型结合使用,识别模型适用于简单场景或者要求不是很高的场景;模型都不大,赞一个。

五、PaddleOCR(PyTorch移植版)

上一小节说了,如果你不想安装PaddlePaddle,那么也可以使用好心人提供的PyTorch移植版。小缺点是如果你需要在项目中使用,需要copy和整理一下代码;当然也可以转成onnx格式进行推理。

1.代码整理

由于我当前只是更想用PaddleOCR的检测模型,所以在repo中整理出了ch_ptocr_v3_det_infer以及ch_ptocr_v4_det_infer这两个版本的检测模型及其代码,存放于src/paddle_ocr。这两个模型都可以检测中英等多种语言文本。如果你另有需求或者想多探索一下其他模型,可以去原作者仓库把玩。

2.DEMO

我将检测模型封装成了类TextDetector,具体网络模型我们不在此展开,来说一下ch_ptocr_v3_det_infer以及ch_ptocr_v4_det_infer对应的参数传递。如下所示,如果使用ch_ptocr_v4_det_infer,你需要传入模型路径det_model_path以及配置文件路径det_yaml_path;而如果使用ch_ptocr_v3_det_infer,只需传入det_model_path,因为它的配置并不是以文件形式存在。我这边都写的绝对路径,需要你改成自己的路径;文件都已存在,不需要你自行下载。

python 复制代码
# src/paddle_ocr/text_detector.py
from paddle_ocr.model_args import parse_args

_args = parse_args()

# det v4:
# _args.det_model_path = "/home/ubuntu/Projects_ubuntu/torchocr/src/paddle_ocr/pretrained_models/det_v4/ch_ptocr_v4_det_infer.pth"
# _args.det_yaml_path = "/home/ubuntu/Projects_ubuntu/torchocr/src/paddle_ocr/pretrained_models/det_v4/ch_PP-OCRv4_det_student.yml"

# det v3:
_args.det_model_path = "/home/ubuntu/Projects_ubuntu/torchocr/src/paddle_ocr/pretrained_models/det_v3/ch_ptocr_v3_det_infer.pth"
    

class TextDetector:
    def __init__(self, args=_args, **kwargs):
        # ...

TextDetector检测demo如下:

python 复制代码
from paddle_ocr.text_detector import TextDetector

# paddleocr (pytorch det model)
text_detector = TextDetector()
bboxes, _ = text_detector(image)
bboxes = [np.array(bbox).astype(np.int32).reshape((-1, 2)) for bbox in bboxes.tolist()]
save_detection(os.path.join(output_dir, "paddleocr_pytorch_det.png"), image, bboxes, texts=None, poly=True)

结果如下,你会发现检测结果和原生的PaddleOCR是完全一致的。

六、TrOCR

TrOCR是微软出品,主要用于单行手写文字识别,效果不错;缺点是1)只能支持单行识别,一般需要结合文本检测模型使用;2)该模型使用Transformer序列生成的方式来生成识别的文字的,速度略慢;3)由于是序列生成,没有很好的方式提供识别结果整个的置信度(github上有人提过这样的问题,但没有解决)。

1.安装

bash 复制代码
pip install torch   # 选择合适gpu版本进行安装
pip install transformers

2.模型下载

Hugging Face上下载trocr-base-handwritten模型到本地,如下图所示。

3.DEMO

TrOCRExecutor代码如下,主要组件是处理器TrOCRProcessor(可以同时处理图像和文本)和模型VisionEncoderDecoderModel

python 复制代码
# src/ocr_exeutor/trocr_executor.py
from typing import List

import numpy as np
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from ocr_executor.ocr_executor_base import OCRExecutorBase

DAFAULT_MODEL_PATH = "/home/ubuntu/Projects_ubuntu/TrOCR/trocr_base_handwritten"

class TrOCRExecutor(OCRExecutorBase):
    def __init__(self, model_path: str = DAFAULT_MODEL_PATH):
        super(TrOCRExecutor, self).__init__()
        # ocr模型
        self.model_path = model_path

        self._init()

    def _init(self):
        # 初始化
        if "trocr" in self.model_path:
            self.processor = TrOCRProcessor.from_pretrained(self.model_path)
            self.model = VisionEncoderDecoderModel.from_pretrained(self.model_path)
        else:
            raise NotImplementedError

    def _generate(self, images: np.ndarray) -> List:
        """
        根据输入图片批量生成文字
        """
        pixel_values = self.processor(images=images, return_tensors="pt").pixel_values
        # generated_ids = self.model.generate(pixel_values, output_scores=True, return_dict_in_generate=True)
        generated_ids = self.model.generate(pixel_values)  # TODO
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return generated_text

TrOCRExecutor识别demo如下:

python 复制代码
from ocr_executor.trocr_executor import TrOCRExecutor

# 注意,和上面不同,这边换成了单行手写文本的图片
image_path = "/home/ubuntu/Projects_ubuntu/ocr_fusion/src/images/handwriting_single_line.png"  # trocr只能识别单行手写文本
image = cv2.imread(image_path)
paths_images = [(image_path, image)]

# trocr
trocr_executor = TrOCRExecutor()
text = trocr_executor.execute(paths_images)[0]
save_detection(os.path.join(output_dir, "trocr.png"), image, [], texts=[text], poly=True)

结果如下,个人体验是效果还可以,但是鲁棒性有待商榷,有时候单行文字截取多一些少一些会影响识别效果。

七、GOT

GOT是阶跃星辰、旷世、中国科学院以及清华的作品,效果确实不错,同时它支持多种形式的OCR,比如markdown、音符、分子式等等,通用性强;该模型是端到端的,直接输出文本,但是应该不能输出文本检测框。

1.安装

安装可以参考源仓库(其中提到的Flash-Attention不是必须的),此处不赘述。

2.模型下载

Hugging Face上下载GOT-OCR2_0模型到本地,如下图所示。

3.DEMO

我把源代码中必要的部分抽出来,放到了src/got_ocr中,主体部分封装成了类GOTTextGenerator,为GOTOCRExecutor所用,GOTOCRExecutor代码如下:

python 复制代码
# src/ocr_exeutor/gotocr_executor.py
class GOTOCRExecutor(OCRExecutorBase):
    def __init__(self, model_name: str = _MODEL_NAME):
        super(GOTOCRExecutor, self).__init__()
        # 模型
        self.text_generator = GOTTextGenerator(model_name)
    
    def _generate(self, images: List[np.ndarray]) -> List:
        """
        批量生成
        """
        results = []
        for image in images:
            # 构建模型输入
            input_dict = {
                "image": image,
                "type": "ocr"
            }
            result = self.text_generator.generate(input_dict)
            results.append(result)
        return results

GOTOCRExecutor识别demo如下:

python 复制代码
from ocr_executor.gotocr_executor import GOTOCRExecutor

# gotocr
gotocr_executor = GOTOCRExecutor()
result = gotocr_executor.execute(paths_images)[0]
save_detection(os.path.join(output_dir, "gotocr.png"), image, [], texts=[result], poly=True)

结果如下,我们发现直接让它识别多行文本差强人意,单词之间缺少空格,个人猜测有两个原因:1)图片比较大,模型处理时对图片的压缩比较厉害;2)训练数据的分布上可能有问题,因为论文中隐约可以看出他们收集的手写数据是小片段的拼接出来的,所以可能数据上不是很到位。

需要说明的是,其实中文都正确识别出来了,只是OpenCV绘制文本时默认不支持中文,所以显示的都是问号。

针对上述问题,我们可以换一种思路,就是让GOT负责单行级别的识别,在复杂场景下只要在前面加一个文本检测模型即可。因此,我是用单行手写文本又测试了一次,结果如下,完美解决。


总结

本篇配合repo讲述了一些流行的OCR工具的使用方法,比如从好用的paddleOCROCR2.0的端到端模型GOT。希望能为需要使用OCR工具的朋友提供便利。

相关推荐
只怕自己不够好2 分钟前
《OpenCV 图像缩放、翻转与变换全攻略:从基础操作到高级应用实战》
人工智能·opencv·计算机视觉
网络研究院9 分钟前
国土安全部发布关键基础设施安全人工智能框架
人工智能·安全·框架·关键基础设施
mqiqe14 分钟前
Elasticsearch 分词器
python·elasticsearch
不去幼儿园1 小时前
【MARL】深入理解多智能体近端策略优化(MAPPO)算法与调参
人工智能·python·算法·机器学习·强化学习
想成为高手4992 小时前
生成式AI在教育技术中的应用:变革与创新
人工智能·aigc
YSGZJJ2 小时前
股指期货的套保策略如何精准选择和规避风险?
人工智能·区块链
无脑敲代码,bug漫天飞2 小时前
COR 损失函数
人工智能·机器学习
幽兰的天空2 小时前
Python 中的模式匹配:深入了解 match 语句
开发语言·python
HPC_fac130520678163 小时前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
网易独家音乐人Mike Zhou6 小时前
【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)
c语言·python·单片机·物联网·算法·嵌入式·iot