开源通用验证码识别OCR —— DdddOcr 源码赏析(一)

### 文章目录

  • [@[toc]](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [前言](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [DdddOcr](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [环境准备](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [安装DdddOcr](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [使用示例](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [源码分析](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [实例化DdddOcr](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [实例化过程](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [分类识别](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [分类识别过程](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)
  • [未完待续](#文章目录 @[toc] 前言 DdddOcr 环境准备 安装DdddOcr 使用示例 源码分析 实例化DdddOcr 实例化过程 分类识别 分类识别过程 未完待续)

前言

DdddOcr 源码赏析

DdddOcr

DdddOcr是开源的通用验证码识别OCR
官方传送门

环境准备

安装DdddOcr

bash 复制代码
pip install ddddocr

使用示例

示例图片如下

Python3 复制代码
import ddddocr

ocr = ddddocr.DdddOcr(show_ad=False)

image = open("example.png", "rb").read()
result = ocr.classification(image)
print(result)
# 识别结果 aFtf

源码分析

我们以实例代码为例,分析源码里面都做了什么

实例化DdddOcr

python3 复制代码
ocr = ddddocr.DdddOcr(show_ad=False)

对应源码如下

python 复制代码
class DdddOcr(object):
    def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta: bool = False,
                 use_gpu: bool = False,
                 device_id: int = 0, show_ad=True, import_onnx_path: str = "", charsets_path: str = ""):
        if show_ad:
            print("欢迎使用ddddocr,本项目专注带动行业内卷,个人博客:wenanzhe.com")
            print("训练数据支持来源于:http://146.56.204.113:19199/preview")
            print("爬虫框架feapder可快速一键接入,快速开启爬虫之旅:https://github.com/Boris-code/feapder")
            print(
                "谷歌reCaptcha验证码 / hCaptcha验证码 / funCaptcha验证码商业级识别接口:https://yescaptcha.com/i/NSwk7i")
        if not hasattr(Image, 'ANTIALIAS'):
            setattr(Image, 'ANTIALIAS', Image.LANCZOS)
        self.use_import_onnx = False
        self.__word = False
        self.__resize = []
        self.__charset_range = []
        self.__channel = 1
        if import_onnx_path != "":
            det = False
            ocr = False
            self.__graph_path = import_onnx_path
            with open(charsets_path, 'r', encoding="utf-8") as f:
                info = json.loads(f.read())
            self.__charset = info['charset']
            self.__word = info['word']
            self.__resize = info['image']
            self.__channel = info['channel']
            self.use_import_onnx = True

        if det:
            ocr = False
            self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_det.onnx')
            self.__charset = []

实例化过程

1 show_ad

先来一波广告推广,开源不易,尤其是DdddOcr这么良心的开源Ocr,大家多多支持DdddOcr

2 ANTIALIAS 判断

python3 复制代码
if not hasattr(Image, 'ANTIALIAS'):
    setattr(Image, 'ANTIALIAS', Image.LANCZOS)

Image.LANCZOS,这是一种图像重采样过滤器,通常用于图像缩放时减少锯齿状边缘和模糊。

这段代码的作用主要是向后兼容或者为旧代码提供一种便捷的访问方式,使得即使PIL或Pillow库的官方API中没有直接提供ANTIALIAS这个属性,开发者也可以通过这种方式来使用LANCZOS过滤器进行图像缩放等操作。

3 然后初始化一些变量

python3 复制代码
self.use_import_onnx = False
        self.__word = False
        self.__resize = []
        self.__charset_range = []
        self.__channel = 1

4 判断是否使用自己的Ocr模型

python3 复制代码
if import_onnx_path != "":
     det = False
      ocr = False
      self.__graph_path = import_onnx_path
      with open(charsets_path, 'r', encoding="utf-8") as f:
          info = json.loads(f.read())
      self.__charset = info['charset']
      self.__word = info['word']
      self.__resize = info['image']
      self.__channel = info['channel']
      self.use_import_onnx = True

如果使用自己的Ocr模型,通过import_onnx_path指定模型路径,同时charsets_path指定字符集信息

5 是否启用目标检测

python3 复制代码
if det:
    ocr = False
    self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_det.onnx')
    self.__charset = []
     

1.6 是否启用ocr

beta为True表示启用新的ocr模型, 为False启用老的ocr模型

python3 复制代码
if ocr:
    if not beta:
        self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_old.onnx')
        self.__charset = [....]
    else:
        self.__graph_path = os.path.join(os.path.dirname(__file__), 'common.onnx')
        self.__charset = [...]

6 是否启用GPU

python3 复制代码
 if use_gpu:
    self.__providers = [
          ('CUDAExecutionProvider', {
              'device_id': device_id,
              'arena_extend_strategy': 'kNextPowerOfTwo',
              'cuda_mem_limit': 2 * 1024 * 1024 * 1024,
              'cudnn_conv_algo_search': 'EXHAUSTIVE',
              'do_copy_in_default_stream': True,
          }),
      ]
  else:
      self.__providers = [
          'CPUExecutionProvider',
      ]

这里根据use_gpu来决定是使用GPU还是CPU作为计算提供者(ExecutionProvider)

如果use_gpu为True,即决定使用GPU进行计算,那么会创建一个名为CUDAExecutionProvider的提供者配置列表,并设置了一系列与CUDA(GPU计算平台)相关的参数。这些参数包括:

  1. device_id:指定使用的GPU设备的ID,这允许在多GPU环境中选择特定的GPU进行计算。
  2. arena_extend_strategy:内存分配策略,这里设置为'kNextPowerOfTwo',意味着内存分配时会向上取到最近的2的幂次方大小,这有助于减少内存碎片。
  3. cuda_mem_limit:限制CUDA设备(GPU)的内存使用量,这里设置为2GB(2 * 1024 * 1024 * 1024字节)。
  4. cudnn_conv_algo_search:指定卷积算法搜索策略,'EXHAUSTIVE'表示使用穷举搜索策略来找到最佳的卷积算法,这可能会增加预处理时间但可能提高执行效率。
  5. do_copy_in_default_stream:指定是否在默认流中执行数据复制操作,这里设置为True。
    如果use_gpu为False,即决定使用CPU进行计算,那么会简单地设置计算提供者列表为仅包含一个'CPUExecutionProvider'的列表。

7 加载onnx模型

python3 复制代码
self.__ort_session = onnxruntime.InferenceSession(self.__graph_path, providers=self.__providers)

❓疑问❓

从代码来看只能加载一种模型,ocr模型(新/旧)、det模型、自己的onnx模型,三种模型三选一,这里self.__graph_path指定模型路径时,却使用了3个if, 而不是if-elif-else结构,个人感觉不太合理, 只能说瑕不掩瑜

源码结构如下

python 复制代码
if import_onnx_path != "":
   	self.__graph_path = import_onnx_path
if det:
	self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_det.onnx')
if ocr:
	if not beta:
		self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_old.onnx')
	else:
        self.__graph_path = os.path.join(os.path.dirname(__file__), 'common.onnx')

分类识别

python3 复制代码
image = open("example.jpg", "rb").read()
result = ocr.classification(image)
print(result)

对应源码如下

python 复制代码
def classification(self, img, png_fix: bool = False, probability=False):
        if self.det:
            raise TypeError("当前识别类型为目标检测")
        if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)):
            raise TypeError("未知图片类型")
        if isinstance(img, bytes):
            image = Image.open(io.BytesIO(img))
        elif isinstance(img, Image.Image):
            image = img.copy()
        elif isinstance(img, str):
            image = base64_to_image(img)
        else:
            assert isinstance(img, pathlib.PurePath)
            image = Image.open(img)
        if not self.use_import_onnx:
            image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L')
        else:
            if self.__resize[0] == -1:
                if self.__word:
                    image = image.resize((self.__resize[1], self.__resize[1]), Image.ANTIALIAS)
                else:
                    image = image.resize((int(image.size[0] * (self.__resize[1] / image.size[1])), self.__resize[1]),
                                         Image.ANTIALIAS)
            else:
                image = image.resize((self.__resize[0], self.__resize[1]), Image.ANTIALIAS)
            if self.__channel == 1:
                image = image.convert('L')
            else:
                if png_fix:
                    image = png_rgba_black_preprocess(image)
                else:
                    image = image.convert('RGB')
        image = np.array(image).astype(np.float32)
        image = np.expand_dims(image, axis=0) / 255.
        if not self.use_import_onnx:
            image = (image - 0.5) / 0.5
        else:
            if self.__channel == 1:
                image = (image - 0.456) / 0.224
            else:
                image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
                image = image[0]
                image = image.transpose((2, 0, 1))

        ort_inputs = {'input1': np.array([image]).astype(np.float32)}
        ort_outs = self.__ort_session.run(None, ort_inputs)
        result = []

        last_item = 0

        if self.__word:
            for item in ort_outs[1]:
                result.append(self.__charset[item])
        else:
            if not self.use_import_onnx:
                # 概率输出仅限于使用官方模型
                if probability:
                    ort_outs = ort_outs[0]
                    ort_outs = np.exp(ort_outs) / np.sum(np.exp(ort_outs))
                    ort_outs_sum = np.sum(ort_outs, axis=2)
                    ort_outs_probability = np.empty_like(ort_outs)
                    for i in range(ort_outs.shape[0]):
                        ort_outs_probability[i] = ort_outs[i] / ort_outs_sum[i]
                    ort_outs_probability = np.squeeze(ort_outs_probability).tolist()
                    result = {}
                    if len(self.__charset_range) == 0:
                        # 返回全部
                        result['charsets'] = self.__charset
                        result['probability'] = ort_outs_probability
                    else:
                        result['charsets'] = self.__charset_range
                        probability_result_index = []
                        for item in self.__charset_range:
                            if item in self.__charset:
                                probability_result_index.append(self.__charset.index(item))
                            else:
                                # 未知字符
                                probability_result_index.append(-1)
                        probability_result = []
                        for item in ort_outs_probability:
                            probability_result.append([item[i] if i != -1 else -1 for i in probability_result_index ])
                        result['probability'] = probability_result
                    return result
                else:
                    last_item = 0
                    argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
                    for item in argmax_result:
                        if item == last_item:
                            continue
                        else:
                            last_item = item
                        if item != 0:
                            result.append(self.__charset[item])
                    return ''.join(result)

            else:
                last_item = 0
                for item in ort_outs[0][0]:
                    if item == last_item:
                        continue
                    else:
                        last_item = item
                    if item != 0:
                        result.append(self.__charset[item])
                return ''.join(result)

分类识别过程

1 目标检测任务不支持分类

python3 复制代码
if self.det:
	raise TypeError("当前识别类型为目标检测")
  1. 图片格式转换
python3 复制代码
 if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)):
            raise TypeError("未知图片类型")
        if isinstance(img, bytes):
            image = Image.open(io.BytesIO(img))
        elif isinstance(img, Image.Image):
            image = img.copy()
        elif isinstance(img, str):
            image = base64_to_image(img)
        else:
            assert isinstance(img, pathlib.PurePath)
            image = Image.open(img)

未完待续

明天见

相关推荐
cdut_suye1 分钟前
Linux工具使用指南:从apt管理、gcc编译到makefile构建与gdb调试
java·linux·运维·服务器·c++·人工智能·python
ONE米球兔22 分钟前
OCR(四)windows 环境基于c++的 paddle ocr 编译【GPU版本】
ocr·paddle
dundunmm24 分钟前
机器学习之scikit-learn(简称 sklearn)
python·算法·机器学习·scikit-learn·sklearn·分类算法
古希腊掌管学习的神24 分钟前
[机器学习]sklearn入门指南(1)
人工智能·python·算法·机器学习·sklearn
一道微光38 分钟前
Mac的M2芯片运行lightgbm报错,其他python包可用,x86_x64架构运行
开发语言·python·macos
HelloGitHub44 分钟前
跟着 8.6k Star 的开源数据库,搞 RAG!
开源·github
GitCode官方1 小时前
GitCode 光引计划投稿 | GoIoT:开源分布式物联网开发平台
分布式·开源·gitcode
Eric.Lee20211 小时前
ubuntu paddle ocr 部署bug问题解决
ubuntu·ocr·paddle
m0_748256781 小时前
WebGIS实战开源项目:智慧机场三维可视化(学习笔记)
笔记·学习·开源
四口鲸鱼爱吃盐1 小时前
Pytorch | 利用AI-FGTM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python