Paddle-OCR根据垂直类场景自定义数据微调PP-OCRv4模型

Paddle-OCR根据垂直类场景自定义数据微调PP-OCRv4模型

1 文本检测模型微调

数据准备:

  • 加入少量真实数据(检测任务>=500张, 识别任务>=5000张),会大幅提升垂类场景的检测与识别效果
  • 在模型微调时,加入真实通用场景数据,可以进一步提升模型精度与泛化性能
  • 在图像检测任务中,增大图像的预测尺度,能够进一步提升较小文字区域的检测效果
  • 在模型微调时,需要适当调整超参数(学习率,batch size最为重要),以获得更优的微调效果。
  • 数据标注:单行文本标注格式,建议标注的检测框与实际语义内容一致。

1-1 数据准备

训练集&校验集

PaddleOCR 中的文本检测算法支持的标注文件格式如下,中间用"\t"分隔:

" 图像文件名                    json.dumps编码的图像标注信息"
ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]

json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 points 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 transcription 表示当前文本框的文字,当其内容为"###"时,表示该文本框无效,在训练时会跳过。

公开数据集
数据集名称 图片下载地址 PaddleOCR 标注下载地址
ICDAR 2015 https://rrc.cvc.uab.es/?ch=4\&com=downloads train / test
ctw1500 https://paddleocr.bj.bcebos.com/dataset/ctw1500.zip 图片下载地址中已包含
total text https://paddleocr.bj.bcebos.com/dataset/total_text.tar 图片下载地址中已包含
td tr https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar 图片下载地址中已包含

1-2 下载预训练模型

ch_PP-OCRv4_det_train

1-3 参数配置

配置文件: configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml

yaml 复制代码
Global:
  debug: false
  use_gpu: true
  epoch_num: &epoch_num 500
  log_smooth_window: 20
  print_batch_step: 100
  save_model_dir: ./output/ch_PP-OCRv4
  save_epoch_step: 10
  eval_batch_step:
  - 0
  - 1500
  cal_metric_during_train: false
  checkpoints:
  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPLCNetV3_x0_75_ocr_det.pdparams
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./checkpoints/det_db/predicts_db.txt
  distributed: true

Architecture:
  model_type: det
  algorithm: DB
  Transform: null
  Backbone:
    name: PPLCNetV3
    scale: 0.75
    det: True
  Neck:
    name: RSEFPN
    out_channels: 96
    shortcut: True
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001 #(8*8c)
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 5.0e-05

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - CopyPaste: null
    - IaaAugment:
        augmenter_args:
        - type: Fliplr
          args:
            p: 0.5
        - type: Affine
          args:
            rotate:
            - -10
            - 10
        - type: Resize
          args:
            size:
            - 0.5
            - 3
    - EastRandomCropData:
        size:
        - 640
        - 640
        max_tries: 50
        keep_ratio: true
    - MakeBorderMap:
        shrink_ratio: 0.4
        thresh_min: 0.3
        thresh_max: 0.7
        total_epoch: *epoch_num
    - MakeShrinkMap:
        shrink_ratio: 0.4
        min_text_size: 8
        total_epoch: *epoch_num
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - threshold_map
        - threshold_mask
        - shrink_map
        - shrink_mask
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - DetLabelEncode: null
    - DetResizeForTest:
    - NormalizeImage:
        scale: 1./255.
        mean:
        - 0.485
        - 0.456
        - 0.406
        std:
        - 0.229
        - 0.224
        - 0.225
        order: hwc
    - ToCHWImage: null
    - KeepKeys:
        keep_keys:
        - image
        - shape
        - polys
        - ignore_tags
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2
profiler_options: null

1-4 训练

bash 复制代码
python tools/train.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml \
     -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained

1-5 评估

python tools/eval.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"

1-6 推理

yaml 复制代码
python tools/infer_det.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy"

1-7 导出

yaml 复制代码
python3 tools/export_model.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.pretrained_model="./output/det_db/best_accuracy" Global.save_inference_dir="./output/det_db_inference/"

2 文本识别模型微调

2-1 数据准备

训练集&校验集

建议将训练图片放入同一个文件夹,并用一个txt文件(rec_gt_train.txt)记录图片路径和标签,txt文件里的内容如下:

注意: txt文件中默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错。

bash 复制代码
" 图像文件名                 图像标注信息 "

train_data/rec/train/word_001.jpg   简单可依赖
train_data/rec/train/word_002.jpg   用科技让复杂的世界更简单
...

最终训练集应有如下文件结构:

bash 复制代码
|-train_data
  |-rec
    |- rec_gt_train.txt
    |- train
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...

除上述单张图像为一行格式之外,PaddleOCR也支持对离线增广后的数据进行训练,为了防止相同样本在同一个batch中被多次采样,我们可以将相同标签对应的图片路径写在一行中,以列表的形式给出,在训练中,PaddleOCR会随机选择列表中的一张图片进行训练。对应地,标注文件的格式如下。

bash 复制代码
["11.jpg", "12.jpg"]   简单可依赖
["21.jpg", "22.jpg", "23.jpg"]   用科技让复杂的世界更简单
3.jpg   ocr

上述示例标注文件中,"11.jpg"和"12.jpg"的标签相同,都是简单可依赖,在训练的时候,对于该行标注,会随机选择其中的一张图片进行训练。

如果有通用真实场景数据加进来,建议每个epoch中,垂类场景数据与真实场景的数据量保持在1:1左右。

比如:您自己的垂类场景识别数据量为1W,数据标签文件为vertical.txt,收集到的通用场景识别数据量为10W,数据标签文件为general.txt

那么,可以设置label_file_listratio_list参数如下所示。每个epoch中,vertical.txt中会进行全采样(采样比例为1.0),包含1W条数据;general.txt中会按照0.1的采样比例进行采样,包含10W*0.1=1W条数据,最终二者的比例为1:1

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    label_file_list:
    - vertical.txt
    - general.txt
    ratio_list: [1.0, 0.1]
字典

需要提供一个自定义字典({word_dict_name}.txt),使模型在训练时,可以将所有出现的字符映射为字典的索引。

因此字典需要包含所有希望被正确识别的字符,{word_dict_name}.txt需要写成如下格式,并以 utf-8 编码格式保存:

l
d
a
d
r
n

word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,"and" 将被映射成 [2 5 1]

  • 内置字典

    PaddleOCR内置了一部分字典,可以按需使用。

    ppocr/utils/ppocr_keys_v1.txt 是一个包含6623个字符的中文字典

    ppocr/utils/ic15_dict.txt 是一个包含36个字符的英文字典

    ppocr/utils/en_dict.txt 是一个包含96个字符的英文字典

公开数据集
数据集名称 图片下载地址 PaddleOCR 标注下载地址
en benchmark(MJ, SJ, IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE.) DTRB LMDB格式,可直接用lmdb_dataset.py加载
ICDAR 2015 http://rrc.cvc.uab.es/?ch=4\&com=downloads train/ test
多语言数据集 百度网盘 提取码:frgi google drive 图片下载地址中已包含

2-2 下载预训练模型

ch_PP-OCRv4_rec_train

2-3 参数配置

配置文件:configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml

yaml 复制代码
Global:
  debug: false
  use_gpu: true
  epoch_num: 200
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v4
  save_epoch_step: 10
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3.txt


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05


Architecture:
  model_type: rec
  algorithm: SVTR_LCNet
  Transform:
  Backbone:
    name: PPLCNetV3
    scale: 0.95
  Head:
    name: MultiHead
    head_list:
      - CTCHead:
          Neck:
            name: svtr
            dims: 120
            depth: 2
            hidden_dims: 120
            kernel_size: [1, 3]
            use_guide: True
          Head:
            fc_decay: 0.00001
      - NRTRHead:
          nrtr_dim: 384
          max_text_length: *max_text_length

Loss:
  name: MultiLoss
  loss_config_list:
    - CTCLoss:
    - NRTRLoss:

PostProcess:  
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: MultiScaleDataSet
    ds_width: false
    data_dir: ./train_data/
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/train_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
        max_text_length: *max_text_length
    - RecAug:
    - MultiLabelEncode:
        gtc_encode: NRTRLabelEncode
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_gtc
        - length
        - valid_ratio
  sampler:
    name: MultiScaleSampler
    scales: [[320, 32], [320, 48], [320, 64]]
    first_bs: &bs 192
    fix_bs: false
    divided_factor: [8, 16] # w, h
    is_training: True
  loader:
    shuffle: true
    batch_size_per_card: *bs
    drop_last: true
    num_workers: 8
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
    - ./train_data/val_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
        gtc_encode: NRTRLabelEncode
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_gtc
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4

2-4 训练

bash 复制代码
python tools/train.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml \
     -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained

2-5 评估

bash 复制代码
python tools/eval.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml -o Global.checkpoints={path/to/weights}/best_accuracy

2-6 推理

bash 复制代码
python tools/infer_rec.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg

2-7 导出

yaml 复制代码
python tools/export_model.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml -o Global.pretrained_model=./pretrain_models/en_PP-OCRv3_rec_train/best_accuracy  Global.save_inference_dir=./inference/en_PP-OCRv3_rec/

3 文本方向分类器微调

3-1 数据准备

训练集&校验集

首先建议将训练图片放入同一个文件夹,并用一个txt文件(cls_gt_train.txt)记录图片路径和标签。

注意: 默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错

0和180分别表示图片的角度为0度和180度

" 图像文件名                 图像标注信息 "
train/cls/train/word_001.jpg   0
train/cls/train/word_002.jpg   180

最终训练集应有如下文件结构:

bash 复制代码
|-train_data
    |-cls
        |- cls_gt_train.txt
        |- train
            |- word_001.png
            |- word_002.jpg
            |- word_003.jpg
            | ...

3-2 下载预训练模型

ch_ppocr_mobile_v2.0_cls_train

3-3 参数配置

将准备好的txt文件和图片文件夹路径分别写入配置文件的 Train/Eval.dataset.label_file_listTrain/Eval.dataset.data_dir 字段下,Train/Eval.dataset.data_dir字段下的路径和文件里记载的图片名构成了图片的绝对路径。

3-4 训练

bash 复制代码
python tools/train.py -c configs/cls/cls_mv3.yml

3-5 评估

bash 复制代码
python tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy

3-6 推理

bash 复制代码
python tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/ch/word_1.jpg
相关推荐
红米煮粥9 小时前
OpenCV-OCR
人工智能·opencv·ocr
菜就多练_08281 天前
《深度学习》OpenCV 摄像头OCR 过程及案例解析
人工智能·深度学习·opencv·ocr
OCR_wintone4212 天前
中安未来 OCR—— 开启高效驾驶证识别新时代
人工智能·汽车·ocr
OCR_wintone4212 天前
中安未来 OCR—— 开启文字识别新时代
人工智能·深度学习·ocr
OCR_wintone4212 天前
翔云 OCR:发票识别与验真
人工智能·深度学习·ocr
OCR_wintone4213 天前
中安未来 OCR:引领智能报关新时代
ocr
Maxx Space6 天前
828华为云征文|部署开源超轻量中文OCR项目 TrWebOCR
docker·开源·华为云·github·ocr
编程乐趣6 天前
tesseract:一个.Net版本的开源OCR项目
ocr·.net
吃什么芹菜卷6 天前
机器学习:opencv--摄像头OCR
人工智能·笔记·opencv·计算机视觉·ocr
翔云API7 天前
回执单识别-银行回单识别API-文字识别OCR API
ocr