OCR经典神经网络(一)文本识别算法CRNN算法原理及其在icdar15数据集上的应用

OCR经典神经网络(一)文本识别算法CRNN算法原理及其在icdar15数据集上的应用

  • 文本识别是OCR(Optical Character Recognition)的一个子任务,其任务为:识别一个固定区域的的文本内容。

    • 在OCR的两阶段方法里,文本识别模型接在文本检测(如DB算法)后面,将图像信息转换为文字信息
    • 具体来讲:如下图所示,文本识别模型的输入是一张经过文本检测后的文本行图片,输出图片中的文字内容和置信度。
    shell 复制代码
    ('实力活力', 0.9861845970153809)
  • 文本识别的应用场景很多,如:文档识别、路标识别、车牌识别、工业编号识别等等。下表展示了主流的算法类别和主要论文。

算法类别 主要思路 主要论文
传统算法 滑动窗口、字符提取、动态规划 -
ctc 基于ctc的方法,序列不对齐,更快速识别 CRNN, Rosetta
Attention 基于attention的方法,应用于非常规文本 RARE, DAN, PREN
Transformer 基于transformer的方法 SRN, NRTR, Master, ABINet
校正 校正模块学习文本边界并校正成水平方向 RARE, ASTER, SAR
分割 基于分割的方法,提取字符位置再做分类 Text Scanner, Mask TextSpotter

1 CRNN网络结构

如下图所示,CRNN由CNN+RNN+CTC三部分组成。

  • 特征提取部分使用主流的卷积结构,常用的有ResNet、MobileNet、VGG等,原始论文中使用的VGG。
  • 由于文本识别任务的特殊性,输入数据中存在大量的上下文信息,卷积神经网络的卷积核特性使其更关注于局部信息,缺乏长依赖的建模能力,因此仅使用卷积网络很难挖掘到文本之间的上下文联系。为了解决这一问题,CRNN文本识别算法引入了双向 LSTM用来增强上下文建模,通过实验证明双向LSTM模块可以有效的提取出图片中的上下文信息。
  • 时间步输出是24个,但是图片中字符数不一定都是24,长短不一(注:是0-9数字以及a-z字母组合,还有一个blank标识符,总共37类)。因此,最终将输出的特征序列输入到CTC模块,直接解码序列结果。该结构被验证有效,并广泛应用在文本识别任务中。

1.1 CNN结构

CNN结构采用的是VGG的结构,并且对VGG网络做了一些微调:

  • 为了能将CNN提取的特征作为输入,输入到RNN网络中,将第三和第四个maxpooling的核尺度从 2 × 2 2 × 2 2×2改为了 1 × 2 1 × 2 1×2;

    • 将第三和第四个maxpooling改变的原因:为了方便的将CNN的提取特征作为RNN的输入。
    • 首先要注意的是这个网络的输入为W × 32,也就是说该网络对输入图片的宽没有特殊的要求,但是高都必须resize到32。
    • 文中举例说明:如果一张包含10个字符的图片大小为100 × 32,经过上述的CNN网络得到的特征尺度为24 × 1(忽略通道数),这样得到一个序列。每一列特征对应原图的一个矩形区域(如下图所示),这样就很方便作为RNN的输入进行下一步的计算了,而且每个特征与输入有一个一对一的对应关系。
  • 为了加速网络的训练,在第五和第六个卷积层后面加上了BN层;

1.2 RNN结构

为了防止训练时梯度的消失,采用了LSTM神经单元作为RNN的单元。作者认为对于序列的预测,序列的前向信息和后向信息都有助于序列的预测,所以作者采用了双向RNN网络

1.3 CTC转录层

  • 如果使用传统的loss function,需要对齐训练样本,有24个时间步,就需要有24个对应的标签,在该任务中显然不合适,除非可以把图片中的每一个字符都单独检测出来,一个字符对应一个标签,但这需要很强大的文字检测算法,而CTCLoss不需要对齐样本。

  • 24个时间步得到24个标签,再进行一个β变换,才得到最终标签。

    • 24个时间步可以看作原图中分成24列,每一列输出一个标签,有时一个字母占据好几列,例如字母S占据三列,则这三列输出类别都应该是S,有的列没有字母,则输出空白类别。
    • 得到最终类别时将连续重复的字符去重(空白符两侧的相同字符不去重,因为真实标签中可能存在连续重复字符,例如green中的两个连续的e不应该去重,则生成标签的时候就该是类似e-e这种,则不会去重),最终去除空白符即可得到最终标签。
  • 由于CTCLoss计算有些复杂,具体可参考:CTC算法详解。深度学习框架中,一般都集成了CTC Loss,我们直接使用即可。

python 复制代码
def ctc_loss():
    import torch
    # https://pytorch.org/docs/1.13/generated/torch.nn.CTCLoss.html#torch.nn.CTCLoss
    # Target are to be padded
    T = 50  # Input sequence length
    C = 20  # Number of classes (including blank)
    N = 2   # Batch size
    S = 30  # Target sequence length of longest target in batch (padding length)
    S_min = 10  # 目标序列的最小长度,这里仅用于生成随机目标长度时的范围限制

     # Initialize random batch of input vectors, for *size = (T,N,C)
    # 对输入进行了softmax操作并取了对数,以符合CTC损失函数的输入要求
    # input shape = (50, 2, 20)
    input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()

     # Initialize random batch of targets (0 = blank, 1:C = classes)
    # target shape = (2, 30)
    target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)

    # 长度为N的张量,表示每个输入序列的实际长度
    # 这里假设所有输入序列都是完整的T长度
    input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
    # 长度为N的张量,表示每个目标序列的实际长度。
    # 这里通过随机生成一个介于S_min和S之间的整数来模拟不同长度的目标序列
    target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)

    ctc_loss = nn.CTCLoss()
    # input是模型的输出(对数概率),target是目标序列
    # input_lengths和target_lengths分别指定了输入序列和目标序列的实际长度。
    loss = ctc_loss(input, target, input_lengths, target_lengths)
    loss.backward()

if __name__ == '__main__':
    ctc_loss()

1.4 CRNN网络的简单实现

使用pytorch框架实现CRNN网络如下:

python 复制代码
import torch.nn as nn
from torchinfo import summary

class CRNN(nn.Module):

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super(CRNN, self).__init__()

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)

        # 如果接双向lstm输出,则要 *2,固定用法
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)

        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    # CNN主干网络
    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        # 超参设置
        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i + 1]

            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )

            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))

            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)

        conv_relu(2)
        conv_relu(3)
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (256, img_height // 8, img_width // 4)

        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (512, img_height // 16, img_width // 4)

        conv_relu(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)

        output_channel, output_height, output_width = \
            channels[-1], img_height // 16 - 1, img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    # CNN+LSTM前向计算
    def forward(self, images):
        # shape of images: (batch, channel, height, width)

        conv = self.cnn(images)
        batch, channel, height, width = conv.size()

        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)

        # 卷积接全连接。全连接输入形状为(width, batch, channel*height),
        # 输出形状为(width, batch, hidden_layer),分别对应时序长度,batch,特征数,符合LSTM输入要求
        seq = self.map_to_seq(conv)

        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)

if __name__ == '__main__':
    net = CRNN(img_channel=3, img_height=32, img_width=100, num_class=37)
    summary(net, input_size=(1, 3, 32, 100))
shell 复制代码
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
CRNN                                     [24, 1, 37]               --
├─Sequential: 1-1                        [1, 512, 1, 24]           --
│    └─Conv2d: 2-1                       [1, 64, 32, 100]          1,792
│    └─ReLU: 2-2                         [1, 64, 32, 100]          --
│    └─MaxPool2d: 2-3                    [1, 64, 16, 50]           --
│    └─Conv2d: 2-4                       [1, 128, 16, 50]          73,856
│    └─ReLU: 2-5                         [1, 128, 16, 50]          --
│    └─MaxPool2d: 2-6                    [1, 128, 8, 25]           --
│    └─Conv2d: 2-7                       [1, 256, 8, 25]           295,168
│    └─ReLU: 2-8                         [1, 256, 8, 25]           --
│    └─Conv2d: 2-9                       [1, 256, 8, 25]           590,080
│    └─ReLU: 2-10                        [1, 256, 8, 25]           --
│    └─MaxPool2d: 2-11                   [1, 256, 4, 25]           --
│    └─Conv2d: 2-12                      [1, 512, 4, 25]           1,180,160
│    └─BatchNorm2d: 2-13                 [1, 512, 4, 25]           1,024
│    └─ReLU: 2-14                        [1, 512, 4, 25]           --
│    └─Conv2d: 2-15                      [1, 512, 4, 25]           2,359,808
│    └─BatchNorm2d: 2-16                 [1, 512, 4, 25]           1,024
│    └─ReLU: 2-17                        [1, 512, 4, 25]           --
│    └─MaxPool2d: 2-18                   [1, 512, 2, 25]           --
│    └─Conv2d: 2-19                      [1, 512, 1, 24]           1,049,088
│    └─ReLU: 2-20                        [1, 512, 1, 24]           --
├─Linear: 1-2                            [24, 1, 64]               32,832
├─LSTM: 1-3                              [24, 1, 512]              659,456
├─LSTM: 1-4                              [24, 1, 512]              1,576,960
├─Linear: 1-5                            [24, 1, 37]               18,981
==========================================================================================
Total params: 7,840,229
Trainable params: 7,840,229
Non-trainable params: 0
Total mult-adds (M): 675.96
==========================================================================================
Input size (MB): 0.04
Forward/backward pass size (MB): 5.23
Params size (MB): 31.36
Estimated Total Size (MB): 36.63
==========================================================================================

2 CRNN在icdar15数据集上的微调(paddleocr)

shell 复制代码
# git拉取下来,解压
git clone https://gitee.com/paddlepaddle/PaddleOCR

# 然后进入PaddleOCR目录,安装PaddleOCR第三方依赖
pip install -r requirements.txt
  • 我们在paddleocr/tests目录下,创建py文件进行下图的测试

python 复制代码
from paddleocr import PaddleOCR

ocr = PaddleOCR()

# 默认会下载官方训练好的模型,并将下载的模型放到用户目录下(我这里是:C:\\Users\\Undo/.paddleocr)
result = ocr.ocr(img='./rec_img.png'
                 , det=False  # 文本检测器,默认算法为DBNet
                 , rec=True   # 方向分类器
                 , cls=True   # 文本识别,默认为CRNN模型
                 )

for line in result:
    print(line)
shell 复制代码
[('实力活力', 0.7939199805259705)]

2.1 CRNN网络的搭建

2.1.1 backbone

  • MobileNet模型是Google针对手机等嵌入式设备提出的一种轻量化深度神经网络,使用的核心思想是:深度可分离卷积。MobileNet系列中主要包括MobileNet V1、MobileNet V2、MobileNet V3。

    • 论文链接:Searching for MobileNetV3

    • MobileNetV3 有两个版本,MobileNetV3-Small 与 MobileNetV3-Large 分别对应对计算和存储要求低和高的版本。

    • MobileNetV3继续采用了轻量级的深度可分离卷积和残差块等结构,依然是由多个模块组成,但是每个模块得到了优化和升级,包括瓶颈结构、SE模块和NL模块(如下图)。

    • 整体来说MobileNetV3有两大创新点:

      • 互补搜索技术组合:由资源受限的NAS执行模块级搜索,NetAdapt执行局部搜索。

      • 网络结构改进(如下图):将最后一步的平均池化层前移并移除最后一个卷积层,引入h-swish激活函数。

  • PaddleOCR 使用 MobileNetV3 作为骨干网络,paddleocr/configs/rec/rec_icdar15_train.yml中默认使用MobileNetV3-Large版本。同样,需要对MobileNetV3的结构做一些微调,主要对下采样进行修改。

yaml 复制代码
# paddleocr/configs/rec/rec_icdar15_train.yml
Architecture:
  model_type: rec
  algorithm: CRNN         # 使用CRNN模型
  Transform:
  Backbone:               
    name: MobileNetV3     # ppocr/modeling/backbones/rec_mobilenet_v3.py
    scale: 0.5            # 这里进行了缩放
    model_name: large     # 骨干网络使用MobileNetV3的large版本
  Neck:                   
    name: SequenceEncoder # ppocr/modeling/necks/rnn.py
    encoder_type: rnn
    hidden_size: 96
  Head:
    name: CTCHead         # ppocr/modeling/heads/rec_ctc_head.py
    fc_decay: 0

Loss:
  name: CTCLoss           # ppocr/losses/rec_ctc_loss.py 损失计算

PostProcess:
  name: CTCLabelDecode    # ppocr/postprocess/rec_postprocess.py  后处理

Metric:
  name: RecMetric         # ppocr/metrics/rec_metric.py 指标评估
  main_indicator: acc
python 复制代码
# ppocr/modeling/backbones/rec_mobilenet_v3.py

from paddle import nn

from ppocr.modeling.backbones.det_mobilenet_v3 import (
    ResidualUnit,
    ConvBNLayer,
    make_divisible,
)

__all__ = ["MobileNetV3"]


class MobileNetV3(nn.Layer):
    def __init__(
        self,
        in_channels=3,
        model_name="small",
        scale=0.5,
        large_stride=None,
        small_stride=None,
        disable_se=False,
        **kwargs,
    ):
        super(MobileNetV3, self).__init__()
        self.disable_se = disable_se
        if small_stride is None:
            small_stride = [2, 2, 2, 2]
        if large_stride is None:
            large_stride = [1, 2, 2, 2]

        assert isinstance(
            large_stride, list
        ), "large_stride type must " "be list but got {}".format(type(large_stride))
        assert isinstance(
            small_stride, list
        ), "small_stride type must " "be list but got {}".format(type(small_stride))
        assert (
            len(large_stride) == 4
        ), "large_stride length must be " "4 but got {}".format(len(large_stride))
        assert (
            len(small_stride) == 4
        ), "small_stride length must be " "4 but got {}".format(len(small_stride))

        if model_name == "large":
            cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, False, "relu", large_stride[0]],            # 不进行下采样
                [3, 64, 24, False, "relu", (large_stride[1], 1)],       # step 2 高下采样2倍,宽不变
                [3, 72, 24, False, "relu", 1],
                [5, 72, 40, True, "relu", (large_stride[2], 1)],        # step 3 高下采样2倍,宽不变
                [5, 120, 40, True, "relu", 1],
                [5, 120, 40, True, "relu", 1],
                [3, 240, 80, False, "hardswish", 1],
                [3, 200, 80, False, "hardswish", 1],
                [3, 184, 80, False, "hardswish", 1],
                [3, 184, 80, False, "hardswish", 1],
                [3, 480, 112, True, "hardswish", 1],
                [3, 672, 112, True, "hardswish", 1],
                [5, 672, 160, True, "hardswish", (large_stride[3], 1)],  # step 4 高下采样2倍,宽不变
                [5, 960, 160, True, "hardswish", 1],
                [5, 960, 160, True, "hardswish", 1],
            ]
            cls_ch_squeeze = 960
        elif model_name == "small":
            cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, True, "relu", (small_stride[0], 1)],
                [3, 72, 24, False, "relu", (small_stride[1], 1)],
                [3, 88, 24, False, "relu", 1],
                [5, 96, 40, True, "hardswish", (small_stride[2], 1)],
                [5, 240, 40, True, "hardswish", 1],
                [5, 240, 40, True, "hardswish", 1],
                [5, 120, 48, True, "hardswish", 1],
                [5, 144, 48, True, "hardswish", 1],
                [5, 288, 96, True, "hardswish", (small_stride[3], 1)],
                [5, 576, 96, True, "hardswish", 1],
                [5, 576, 96, True, "hardswish", 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError(
                "mode[" + model_name + "_model] is not implemented!"
            )

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert (
            scale in supported_scale
        ), "supported scales are {} but input scale is {}".format(
            supported_scale, scale
        )

        inplanes = 16
        # conv1
        self.conv1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=make_divisible(inplanes * scale),
            kernel_size=3,
            stride=2, # step 1 高下采样2倍,宽下采样2倍
            padding=1,
            groups=1,
            if_act=True,
            act="hardswish",
        )
        i = 0
        block_list = []
        inplanes = make_divisible(inplanes * scale)
        for k, exp, c, se, nl, s in cfg:
            se = se and not self.disable_se
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=make_divisible(scale * c),
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl,
                )
            )
            inplanes = make_divisible(scale * c)
            i += 1
        self.blocks = nn.Sequential(*block_list)

        self.conv2 = ConvBNLayer(
            in_channels=inplanes,
            out_channels=make_divisible(scale * cls_ch_squeeze),
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            if_act=True,
            act="hardswish",
        )

        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)  # step 5 高下采样2倍,宽下采样2倍
        self.out_channels = make_divisible(scale * cls_ch_squeeze)

    def forward(self, x):
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x


if __name__ == '__main__':
    import paddle as torch
    x = torch.rand((1, 3, 32, 100))
    net = MobileNetV3(in_channels=3, model_name='large', scale=0.5)
    # [1, 480, 1, 25]
    print(net(x).shape)

2.1.2 neck

neck 部分将backbone输出的视觉特征图转换为1维向量输入送到 LSTM 网络中,输出序列特征

python 复制代码
# ppocr/modeling/necks/rnn.py
class SequenceEncoder(nn.Layer):
    def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)
        self.out_channels = self.encoder_reshape.out_channels
        ......

    def forward(self, x):
        if self.encoder_type != "svtr":
            # 1、neck部分将backbone输出的视觉特征图转换为1维向量
            # (bs, channels, H, W) -> (bs, H * W, channels),即(bs, seq_len, embedding_dim)
            # reshape后的x = (bs, 25, 480) , 480 = channels(960) * scale(0.5)
            x = self.encoder_reshape(x)

            # 2、转换后,送到LSTM网络中,输出序列特征
            # self.encoder
            # EncoderWithRNN(
            #   (lstm): LSTM(480, 96, num_layers=2
            #     (0): BiRNN(
            #       (cell_fw): LSTMCell(480, 96)
            #       (cell_bw): LSTMCell(480, 96)
            #     )
            #     (1): BiRNN(
            #       (cell_fw): LSTMCell(192, 96)
            #       (cell_bw): LSTMCell(192, 96)
            #     )
            #   )
            # )
            if not self.only_reshape:
                x = self.encoder(x)  # [bs, 25, 96 * 2]
            return x
        else:
            x = self.encoder(x)
            x = self.encoder_reshape(x)
            return x

2.1.3 head

预测头部分由全连接层和softmax组成,用于计算序列特征时间步上的标签概率分布

python 复制代码
class CTCHead(nn.Layer):
    def __init__(
        self,
        in_channels,
        out_channels,
        fc_decay=0.0004,
        mid_channels=None,
        return_feats=False,
        **kwargs,
    ):
        super(CTCHead, self).__init__()
        if mid_channels is None:
            weight_attr, bias_attr = get_para_bias_attr(
                l2_decay=fc_decay, k=in_channels
            )
            # Linear(in_features=192, out_features=37, dtype=float32)
            self.fc = nn.Linear(
                in_channels, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
            )
        else:
            ......
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, targets=None):
        if self.mid_channels is None:
            # (bs, 25, 192) -> predicts (bs, 25, 37)
            predicts = self.fc(x)
        else:
            x = self.fc1(x)
            predicts = self.fc2(x)

        if self.return_feats:
            result = (x, predicts)
        else:
            result = predicts
        if not self.training:  # 非训练时,经过SoftMax,可以得到各时间步上的概率最大的预测结果
            predicts = F.softmax(predicts, axis=2)
            result = predicts

        return result

2.2 数据集加载及模型训练

2.2.1 数据集的下载

提供一份处理过的icdar15数据集:

下载地址:https://pan.baidu.com/s/1VP2Y_IhxAUwQABDmbXrgIg 提取码: ek25

数据集应有如下文件结构:

|-train_data
  |-ic15_data
    |- rec_gt_train.txt
    |- train
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...
    |- rec_gt_test.txt
    |- test
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...

其中txt文件里的内容如下:

" 图像文件名         图像标注信息 "

train/word_1.png	Genaxis Theatre
train/word_2.png	[06]
...

下载完数据集后,我们复制一份paddleocr/configs/rec/rec_icdar15_train.yml文件到paddleocr\tests\configs进行修改:

yaml 复制代码
......
Train: # 修改训练集的路径及其他相关信息
  dataset:
    name: SimpleDataSet
    data_dir: D:/python/datas/cv/ic15_data  
    label_file_list: ["D:/python/datas/cv/ic15_data/rec_gt_train.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True             # 是否打乱
    batch_size_per_card: 256  # 批次大小
    drop_last: True           # 最后1个批次是否删除
    num_workers: 0            # 读取数据的进程数
    use_shared_memory: False

Eval: # 修改验证集的路径及其他相关信息
  dataset:
    name: SimpleDataSet
    data_dir: D:/python/datas/cv/ic15_data
    label_file_list: ["D:/python/datas/cv/ic15_data/rec_gt_test.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 0
    use_shared_memory: False

2.2.2 模型的训练与预测

python 复制代码
def train_rec():
    from tools.train import program, set_seed, main
    # 配置文件的源地址地址: paddleocr/configs/rec/rec_icdar15_train.yml
    config, device, logger, vdl_writer = program.preprocess(is_train=True)

    ###############修改配置(也可在yml文件中修改)##################
    # 加载预训练模型
    # https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
    # 或者 https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_train.tar
    config["Global"]["pretrained_model"] = r"D:\python\models\layout_ocr\rec_mv3_none_bilstm_ctc_v2.0_train\best_accuracy"
    # 字典路径(这里只支持26个小写字母+10个数字)
    config["Global"]["character_dict_path"] = r"D:\python\py_works\paddleocr\ppocr\utils\ic15_dict.txt"
    # 评估频率
    config["Global"]["eval_batch_step"] = [0, 200]
    # log的打印频率
    config["Global"]["print_batch_step"] = 50
    # 训练的epochs
    config["Global"]["epoch_num"] = 1
    # 随机种子
    seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
    set_seed(seed)

    ###############模型训练##################
    main(config, device, logger, vdl_writer, seed)

    
def infer_rec_img():
    # 加载自己训练的模型
    from tools.infer_rec import main, program

    config, device, logger, vdl_writer = program.preprocess()
    config["Global"]["use_gpu"] = False
    config["Global"]["infer_img"] = r"D:/python/py_works/paddleocr/tests/rec_img_slow.png"
    config["Global"]["checkpoints"] = r"D:\python\py_works\paddleocr\tests\output\rec\ic15\best_accuracy"
    config["Global"]["character_dict_path"] = r"D:\python\py_works\paddleocr\ppocr\utils\ic15_dict.txt"
    # 这里为了能python文件执行,加了add_config这个参数,源码中没有
    main(add_config=(config, device, logger, vdl_writer))    

if __name__ == '__main__':
    # 模型训练
    train_rec()
    # 加载自己模型进行推理
    # ppocr INFO: 	 result: slow	0.863487958908081
    infer_rec_img()

2.2.3 数据集的加载

python 复制代码
# paddleocr/tools/train.py
def main(config, device, logger, vdl_writer, seed):
    # init dist environment
    if config["Global"]["distributed"]:
        dist.init_parallel_env()

    global_config = config["Global"]

    # build dataloader
    set_signal_handlers()
    # 1、创建dataloader
    train_dataloader = build_dataloader(config, "Train", device, logger, seed)

    ......
    if config["Eval"]:
        valid_dataloader = build_dataloader(config, "Eval", device, logger, seed)
    else:
        valid_dataloader = None
    step_pre_epoch = len(train_dataloader)

    # 2、后处理程序
    # build post process
    post_process_class = build_post_process(config["PostProcess"], global_config)

    # 3、模型构建
    # build model
    .....
    model = build_model(config["Architecture"])

    use_sync_bn = config["Global"].get("use_sync_bn", False)
    if use_sync_bn:
        model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info("convert_sync_batchnorm")

    model = apply_to_static(model, config, logger)

    # 4、构建损失函数
    # build loss
    loss_class = build_loss(config["Loss"])

    # 5、构建优化器
    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config["Optimizer"],
        epochs=config["Global"]["epoch_num"],
        step_each_epoch=len(train_dataloader),
        model=model,
    )

    # 6、创建评估函数
    # build metric
    eval_class = build_metric(config["Metric"])
    ......

    # 7、加载预训练模型
    # load pretrain model
    pre_best_model_dict = load_model(
        config, model, optimizer, config["Architecture"]["model_type"]
    )

    if config["Global"]["distributed"]:
        model = paddle.DataParallel(model)

    # 8、模型训练
    # start train
    program.train(
        config,
        train_dataloader,
        valid_dataloader,
        device,
        model,
        loss_class,
        optimizer,
        lr_scheduler,
        post_process_class,
        eval_class,
        pre_best_model_dict,
        logger,
        step_pre_epoch,
        vdl_writer,
        scaler,
        amp_level,
        amp_custom_black_list,
        amp_custom_white_list,
        amp_dtype,
    )
  • 通过build_dataloader加载数据集
python 复制代码
# paddleocr/ppocr/data/__init__.py
def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)

    support_dict = [
        "SimpleDataSet", # 配置文件中为SimpleDataSet
        "LMDBDataSet",
        "PGDataSet",
        "PubTabDataSet",
        "LMDBDataSetSR",
        "LMDBDataSetTableMaster",
        "MultiScaleDataSet",
        "TextDetDataset",
        "TextRecDataset",
        "MSTextRecDataset",
        "PubTabTableRecDataset",
        "KieDataset",
        "LaTeXOCRDataSet",
    ]
    module_name = config[mode]["dataset"]["name"]
    assert module_name in support_dict, Exception(
        "DataSet only support {}".format(support_dict)
    )
    assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test."
    # 1、创建dataset
    dataset = eval(module_name)(config, mode, logger, seed)
    ......
    # 2、创建data_loader
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=use_shared_memory,
        collate_fn=collate_fn,
    )

    return data_loader
  • SimpleDataSet中定义了数据的预处理:

    • 1、通过ppocr.data.imaug.operators.DecodeImage 将读取的image(二进制数据)转换为numpy数组, 即(高度H、宽度W、通道数C)

    • 2、通过ppocr.data.imaug.label_ops.CTCLabelEncode将label进行标签编码+one-hot编码(data["length"] + data["label"] + data["label_ace"])

    • 3、通过ppocr.data.imaug.rec_img_aug.RecResizeImg将image进行缩放+归一化+padding, image shape=(3, 32, 100)

    • 4、通过ppocr.data.imaug.operators.KeepKeys 仅仅将['image', 'label', 'length']保存到list中

python 复制代码
class SimpleDataSet(Dataset):
    ......
    
    def __getitem__(self, idx):
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode("utf-8")
            substr = data_line.strip("\n").strip("\r").split(self.delimiter)
            file_name = substr[0]
            file_name = self._try_parse_filename_list(file_name)
            label = substr[1]
            img_path = os.path.join(self.data_dir, file_name)
            data = {"img_path": img_path, "label": label}
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            with open(data["img_path"], "rb") as f:
                img = f.read()
                data["image"] = img
            data["ext_data"] = self.get_ext_data()
            # 数据transform
            # ppocr.data.imaug.operators.DecodeImage    1、将读取的image(二进制数据)转换为numpy数组, 即(高度H、宽度W、通道数C)
            # ppocr.data.imaug.label_ops.CTCLabelEncode 2、将label进行标签编码+one-hot编码(data["length"] + data["label"] + data["label_ace"])
            # ppocr.data.imaug.rec_img_aug.RecResizeImg 3、将image进行缩放+归一化+padding, image shape=(3, 32, 100)
            # ppocr.data.imaug.operators.KeepKeys       4、仅仅将['image', 'label', 'length']保存到list中
            outs = transform(data, self.ops)
        except:
           ......
        return outs

    def __len__(self):
        return len(self.data_idx_order_list)

2.2.4 后处理函数

通过build_post_process构建后处理,

python 复制代码
# ppocr/postprocess/rec_postprocess.py  后处理
class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
        if isinstance(preds, tuple) or isinstance(preds, list):
            preds = preds[-1]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        # 1、获取各个时间步上的最大索引值以及概率
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        # 2、将索引值进行解码(转换为文字)
        text = self.decode(
            preds_idx,
            preds_prob,
            is_remove_duplicate=True,
            return_word_box=return_word_box,
        )
        if return_word_box:
            for rec_idx, rec in enumerate(text):
                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
                max_wh_ratio = kwargs["max_wh_ratio"]
                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
        if label is None:
            # 3、推理时,直接返回text[('解码后的文字', 文字置信度的平均值)]
            return text
        # 4、训练时, 还需要将label进行解码
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["blank"] + dict_character
        return dict_character
  • 核心函数就是BaseRecLabelDecode类中的decode函数
python 复制代码
    # ppocr/postprocess/rec_postprocess.py中的BaseRecLabelDecode类
    def decode(
        self,
        text_index,
        text_prob=None,
        is_remove_duplicate=False,
        return_word_box=False,
    ):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens() # 忽略tokens, 其中[0]代表ctc中的blank位
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                # 1、合并blank之间相同的字符,即【当前位置索引】和【下一位置索引】不相同的就保留
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token
            # 2、将解码的结果存在char_list中
            char_list = [
                self.character[text_id] for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]
            # 3、char_list合并为字符串
            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            if return_word_box:
               ......
            else:
                # 置信度为每个识别的字符置信度的平均值
                result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

2.2.5 CTC Loss

  • 通过build_loss函数构建CTC Loss
  • CRNN 模型的损失函数为 CTC loss, 飞桨集成了常用的 Loss 函数,只需调用实现即可
python 复制代码
#  paddleocr/ppocr/losses/rec_ctc_loss.py
class CTCLoss(nn.Layer):
    def __init__(self, use_focal_loss=False, **kwargs):
        super(CTCLoss, self).__init__()
         # blank 是 ctc 的无意义连接符
        self.loss_func = nn.CTCLoss(blank=0, reduction="none")
        self.use_focal_loss = use_focal_loss

    def forward(self, predicts, batch):
        if isinstance(predicts, (list, tuple)):
            predicts = predicts[-1]
        # 转置模型 head 层的预测结果,沿channel层排列 
        # (bs, 25, 37) -> (25, bs, 37)
        predicts = predicts.transpose((1, 0, 2)) 
        N, B, _ = predicts.shape
        # [N, N, ..., N]一共bs个,每个长度都为N
        preds_lengths = paddle.to_tensor(
            [N] * B, dtype="int64", place=paddle.CPUPlace()
        )
        # batch一个list
        # batch[0]为bs个预处理好image Tensor
        # batch[1]为bs个编码好的token序列,即label,shape = (bs, seq_len)
        # batch[2]为bs个token序列的实际长度(因为有填充)
        labels = batch[1].astype("int32")
        label_lengths = batch[2].astype("int64")
        # 计算损失函数
        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
        if self.use_focal_loss:
            weight = paddle.exp(-loss)
            weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
            weight = paddle.square(weight)
            loss = paddle.multiply(loss, weight)
        loss = loss.mean()
        return {"loss": loss}    

其他训练细节诸如:构建优化器、创建评估函数、加载预训练模型、模型训练等,大家可以查看源码,不再赘述。

相关推荐
蹉跎x1 小时前
力扣1358. 包含所有三种字符的子字符串数目
数据结构·算法·leetcode·职场和发展
巫师不要去魔法部乱说2 小时前
PyCharm专项训练4 最小生成树算法
算法·pycharm
IT猿手2 小时前
最新高性能多目标优化算法:多目标麋鹿优化算法(MOEHO)求解GLSMOP1-GLSMOP9及工程应用---盘式制动器设计,提供完整MATLAB代码
开发语言·算法·机器学习·matlab·强化学习
阿七想学习2 小时前
数据结构《排序》
java·数据结构·学习·算法·排序算法
王老师青少年编程2 小时前
gesp(二级)(12)洛谷:B3955:[GESP202403 二级] 小杨的日字矩阵
c++·算法·矩阵·gesp·csp·信奥赛
Kenneth風车3 小时前
【机器学习(九)】分类和回归任务-多层感知机(Multilayer Perceptron,MLP)算法-Sentosa_DSML社区版 (1)111
算法·机器学习·分类
盼小辉丶3 小时前
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
深度学习·神经网络·tensorflow
eternal__day3 小时前
数据结构(哈希表(中)纯概念版)
java·数据结构·算法·哈希算法·推荐算法
APP 肖提莫3 小时前
MyBatis-Plus分页拦截器,源码的重构(重构total总数的计算逻辑)
java·前端·算法
OTWOL3 小时前
两道数组有关的OJ练习题
c语言·开发语言·数据结构·c++·算法