Classic OCR Neural Networks (1): The CRNN Text Recognition Algorithm and Its Application to the ICDAR15 Dataset
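- Text recognition is a subtask of OCR (Optical Character Recognition): its job is to recognize the text content within a fixed region.
- In the two-stage OCR pipeline, the text recognition model follows text detection (e.g., the DB algorithm) and converts image information into text.
- Concretely, as the figure below shows, the input to the text recognition model is a text-line image cropped by text detection, and the output is the text in the image together with a confidence score.
shell
('实力活力', 0.9861845970153809)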
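- Text recognition has many applications, such as document recognition, road sign recognition, license plate recognition, and industrial serial number recognition. The table below lists the mainstream algorithm categories and their key papers.
- This article looks at the CRNN model, proposed in 2015 by Prof. Xiang Bai's team at Huazhong University of Science and Technology.
- Paper: https://arxiv.org/pdf/1507.05717
- Baidu's open-source paddleocr includes this algorithm: https://github.com/PaddlePaddle/PaddleOCR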
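Category | Main idea | Key papers |
---|---|---|
Traditional | Sliding windows, character extraction, dynamic programming | - |
CTC | CTC-based; no sequence alignment needed, faster recognition | CRNN, Rosetta |
Attention | Attention-based; suited to irregular text | RARE, DAN, PREN |
Transformer | Transformer-based | SRN, NRTR, Master, ABINet |
Rectification | A rectification module learns the text boundary and rectifies the text to horizontal | RARE, ASTER, SAR |
Segmentation | Segmentation-based; locate the characters, then classify them | Text Scanner, Mask TextSpotter |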
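1 CRNN Network Architecture
As the figure below shows, CRNN consists of three parts: CNN + RNN + CTC.
- The feature extraction part uses a mainstream convolutional backbone, commonly ResNet, MobileNet, or VGG; the original paper uses VGG.
- Text recognition inputs carry a lot of contextual information, but convolution kernels focus on local information and lack the ability to model long-range dependencies, so a CNN alone struggles to capture the context between characters. To address this, CRNN adds a bidirectional LSTM to strengthen context modeling; experiments show that the bidirectional LSTM module effectively extracts contextual information from the image.
- The network outputs 24 time steps, but the number of characters in an image varies and is not necessarily 24 (note: the classes are the digits 0-9, the letters a-z, and one blank token, 37 classes in total). The output feature sequence is therefore fed to the CTC module, which decodes the result directly. This architecture has proven effective and is widely used in text recognition.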
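1.1 CNN Structure
The CNN follows the VGG architecture, with a few adjustments:
- So that the features extracted by the CNN can conveniently serve as input to the RNN, the kernel size of the third and fourth max-pooling layers is changed from 2 × 2 to 1 × 2 (width × height), so these two layers halve the height but keep the width.
- Note first that the network input is W × 32: there is no special requirement on the image width, but the height must always be resized to 32.
- The article gives an example: an image of size 100 × 32 containing 10 characters yields, after the CNN above, a feature map of size 24 × 1 (ignoring the channel dimension), which forms a sequence. Each feature column corresponds to one rectangular region of the original image (see the figure below), which makes it convenient to pass on to the RNN, and each feature maps one-to-one to a region of the input.
- To speed up training, BN layers are added after the fifth and sixth convolution layers.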
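The size arithmetic above can be checked with a short sketch. It assumes the layer layout of the PyTorch implementation in section 1.4 (3 × 3 convolutions with padding 1 that keep the size, four pooling layers, and a final 2 × 2 convolution without padding); `conv_out` and `crnn_feature_shape` are illustrative names, not functions from any library.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling layer along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def crnn_feature_shape(height, width):
    """Trace (height, width) through the size-changing layers of the backbone.

    Assumption: 3x3 convolutions use padding 1 and stride 1, so they leave the
    spatial size unchanged and are skipped here.
    """
    # (kernel_h, kernel_w, stride_h, stride_w) for each size-changing layer
    layers = [
        (2, 2, 2, 2),  # pooling0: halves height and width
        (2, 2, 2, 2),  # pooling1: halves height and width
        (2, 1, 2, 1),  # pooling2: halves height only
        (2, 1, 2, 1),  # pooling3: halves height only
        (2, 2, 1, 1),  # final 2x2 convolution, no padding
    ]
    for kernel_h, kernel_w, stride_h, stride_w in layers:
        height = conv_out(height, kernel_h, stride_h)
        width = conv_out(width, kernel_w, stride_w)
    return height, width


if __name__ == "__main__":
    print(crnn_feature_shape(32, 100))  # (1, 24)
```

Because only the first two pooling layers shrink the width, a 100-pixel-wide image keeps 24 columns, which become the 24 time steps of the sequence.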
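1.2 RNN Structure
To prevent vanishing gradients during training, LSTM cells are used as the RNN units. The authors argue that both the forward and the backward context of a sequence help prediction, so they use a bidirectional RNN.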
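1.3 CTC Transcription Layer
- With a conventional loss function, the training samples must be aligned: 24 time steps would need 24 corresponding labels. That is clearly impractical for this task unless every character in the image can be detected individually, one label per character, which would demand a very strong text detection algorithm. CTCLoss does not require aligned samples.
- The 24 time steps produce 24 labels, and a β transform then yields the final label.
- The 24 time steps can be viewed as splitting the original image into 24 columns, each of which outputs one label. A letter sometimes spans several columns; for example, if the letter S spans three columns, all three should output the class S. Columns that contain no letter output the blank class.
- To obtain the final label, consecutive repeated characters are merged (identical characters on either side of a blank are not merged, because the ground-truth label may contain consecutive repeats; for example, the two consecutive e's in "green" must not be merged, so the per-step labels should look like e-e), and finally the blanks are removed.
- Computing CTCLoss is somewhat involved; for details see "CTC算法详解" (CTC Algorithm Explained). Deep learning frameworks generally ship a CTC loss, so we can use it directly.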
python
def ctc_loss():
    import torch
    from torch import nn  # needed for nn.CTCLoss below

    # https://pytorch.org/docs/1.13/generated/torch.nn.CTCLoss.html#torch.nn.CTCLoss
    # Targets are to be padded
    T = 50  # Input sequence length
    C = 20  # Number of classes (including blank)
    N = 2  # Batch size
    S = 30  # Target sequence length of longest target in batch (padding length)
    S_min = 10  # Minimum target length; only bounds the random target lengths below
    # Initialize random batch of input vectors, for *size = (T,N,C)
    # Apply softmax and take the log, because CTCLoss expects log-probabilities
    # input shape = (50, 2, 20)
    input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
    # Initialize random batch of targets (0 = blank, 1:C = classes)
    # target shape = (2, 30)
    target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
    # Tensor of length N holding the actual length of each input sequence;
    # here every input sequence is assumed to use the full length T
    input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
    # Tensor of length N holding the actual length of each target sequence;
    # random integers in [S_min, S) simulate targets of different lengths
    target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
    criterion = nn.CTCLoss()
    # input is the model output (log-probabilities), target is the label sequence;
    # input_lengths and target_lengths give the actual lengths of both
    loss = criterion(input, target, input_lengths, target_lengths)
    loss.backward()


if __name__ == '__main__':
    ctc_loss()
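To make the collapsing rule from this section concrete, here is a minimal pure-Python sketch. It is not how frameworks compute the loss (they use dynamic programming); it brute-forces every path, so it only works for tiny inputs. The names `collapse` and `label_probability` and the use of `-` as the blank symbol are assumptions made for illustration.

```python
from itertools import product

BLANK = "-"  # symbol standing in for the CTC blank in this sketch


def collapse(path):
    """Apply the collapsing rule: merge consecutive repeats, then drop blanks."""
    merged = []
    previous = None
    for symbol in path:
        if symbol != previous:  # keep only the first symbol of each run
            merged.append(symbol)
        previous = symbol
    return "".join(s for s in merged if s != BLANK)


def label_probability(probs, alphabet, label):
    """Sum the probabilities of all paths that collapse to `label`.

    probs[t][k] is the probability of alphabet[k] at time step t.
    Enumerates len(alphabet) ** T paths, so use it only on toy examples.
    """
    total = 0.0
    for path in product(range(len(alphabet)), repeat=len(probs)):
        if collapse(alphabet[k] for k in path) == label:
            p = 1.0
            for t, k in enumerate(path):
                p *= probs[t][k]
            total += p
    return total


if __name__ == "__main__":
    print(collapse("gg-r-ee-e-nn"))  # green
    print(collapse("gg-r-eee-nn"))   # gren (no blank between the e's)
    uniform = [[1 / 3] * 3 for _ in range(2)]  # 2 time steps, 3 classes
    # Paths "aa", "a-" and "-a" all collapse to "a": 3 of the 9 paths
    print(label_probability(uniform, "-ab", "a"))
```

The CTC loss for one sample is the negative logarithm of this summed probability, which is why no per-time-step alignment of the label is needed.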
1.4 A Simple CRNN Implementation
A CRNN implemented with the PyTorch framework looks like this:
python
import torch.nn as nn
from torchinfo import summary


class CRNN(nn.Module):
    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super(CRNN, self).__init__()
        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)
        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)
        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)
        # A bidirectional LSTM outputs 2 * rnn_hidden features, hence the factor 2
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)
        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    # CNN backbone
    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        assert img_height % 16 == 0
        assert img_width % 4 == 0
        # Hyperparameters
        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]
        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i + 1]
            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )
            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))
            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)
        conv_relu(2)
        conv_relu(3)
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (256, img_height // 8, img_width // 4)
        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (512, img_height // 16, img_width // 4)
        conv_relu(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)
        output_channel, output_height, output_width = \
            channels[-1], img_height // 16 - 1, img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    # Forward pass: CNN + LSTM
    def forward(self, images):
        # shape of images: (batch, channel, height, width)
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()
        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)
        # Linear layer after the CNN. Input shape: (width, batch, channel*height);
        # output shape: (width, batch, hidden_layer), i.e. (sequence length, batch,
        # features), which is the layout the LSTM expects
        seq = self.map_to_seq(conv)
        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)
        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)


if __name__ == '__main__':
    net = CRNN(img_channel=3, img_height=32, img_width=100, num_class=37)
    summary(net, input_size=(1, 3, 32, 100))
shell
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
CRNN [24, 1, 37] --
├─Sequential: 1-1 [1, 512, 1, 24] --
│ └─Conv2d: 2-1 [1, 64, 32, 100] 1,792
│ └─ReLU: 2-2 [1, 64, 32, 100] --
│ └─MaxPool2d: 2-3 [1, 64, 16, 50] --
│ └─Conv2d: 2-4 [1, 128, 16, 50] 73,856
│ └─ReLU: 2-5 [1, 128, 16, 50] --
│ └─MaxPool2d: 2-6 [1, 128, 8, 25] --
│ └─Conv2d: 2-7 [1, 256, 8, 25] 295,168
│ └─ReLU: 2-8 [1, 256, 8, 25] --
│ └─Conv2d: 2-9 [1, 256, 8, 25] 590,080
│ └─ReLU: 2-10 [1, 256, 8, 25] --
│ └─MaxPool2d: 2-11 [1, 256, 4, 25] --
│ └─Conv2d: 2-12 [1, 512, 4, 25] 1,180,160
│ └─BatchNorm2d: 2-13 [1, 512, 4, 25] 1,024
│ └─ReLU: 2-14 [1, 512, 4, 25] --
│ └─Conv2d: 2-15 [1, 512, 4, 25] 2,359,808
│ └─BatchNorm2d: 2-16 [1, 512, 4, 25] 1,024
│ └─ReLU: 2-17 [1, 512, 4, 25] --
│ └─MaxPool2d: 2-18 [1, 512, 2, 25] --
│ └─Conv2d: 2-19 [1, 512, 1, 24] 1,049,088
│ └─ReLU: 2-20 [1, 512, 1, 24] --
├─Linear: 1-2 [24, 1, 64] 32,832
├─LSTM: 1-3 [24, 1, 512] 659,456
├─LSTM: 1-4 [24, 1, 512] 1,576,960
├─Linear: 1-5 [24, 1, 37] 18,981
==========================================================================================
Total params: 7,840,229
Trainable params: 7,840,229
Non-trainable params: 0
Total mult-adds (M): 675.96
==========================================================================================
Input size (MB): 0.04
Forward/backward pass size (MB): 5.23
Params size (MB): 31.36
Estimated Total Size (MB): 36.63
==========================================================================================
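The parameter counts in the summary can be reproduced by hand. The sketch below uses the standard formulas for convolution, batch-norm, linear, and bidirectional LSTM layers as parameterized by PyTorch; the helper names are illustrative and are not part of torchinfo.

```python
def conv2d_params(in_channels, out_channels, kernel):
    """Weights (kernel * kernel * in_channels per filter) plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * out_channels


def linear_params(in_features, out_features):
    """Weight matrix plus bias vector."""
    return (in_features + 1) * out_features


def bilstm_params(input_size, hidden_size):
    """One bidirectional LSTM layer as parameterized by torch.nn.LSTM."""
    # 4 gates, each with input weights, recurrent weights and two bias vectors
    per_direction = 4 * hidden_size * (input_size + hidden_size + 2)
    return 2 * per_direction


def crnn_total_params(img_channel=3, num_class=37):
    channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
    kernels = [3, 3, 3, 3, 3, 3, 2]
    total = sum(
        conv2d_params(channels[i], channels[i + 1], kernels[i])
        for i in range(len(kernels))
    )
    total += 2 * (2 * 512)              # two BatchNorm2d layers: scale + shift
    total += linear_params(512, 64)     # map_to_seq (feature height is 1)
    total += bilstm_params(64, 256)     # rnn1
    total += bilstm_params(512, 256)    # rnn2 takes the 2 * 256 outputs of rnn1
    total += linear_params(512, num_class)
    return total


if __name__ == "__main__":
    print(bilstm_params(64, 256))  # 659456
    print(crnn_total_params())     # 7840229
```

Most of the parameters sit in the last three convolution layers and the second LSTM, which is one reason lighter backbones such as MobileNetV3 are attractive for deployment.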
2 Fine-tuning CRNN on the ICDAR15 Dataset (paddleocr)
- Here we use Baidu's open-source paddleocr to get a deeper understanding of the CRNN model:
- paddleocr repository: https://github.com/PaddlePaddle/PaddleOCR
- List of algorithms integrated in paddleocr: https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/overview.md
shell
# Clone the repository
git clone https://gitee.com/paddlepaddle/PaddleOCR
# Then enter the PaddleOCR directory and install its third-party dependencies
cd PaddleOCR
pip install -r requirements.txt
- In the paddleocr/tests directory, create a Python file to run a test on the image shown below:
python
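from paddleocr import PaddleOCR

ocr = PaddleOCR()
# By default the official pretrained models are downloaded into the user directory
# (here: C:\\Users\\Undo/.paddleocr)
result = ocr.ocr(
    img='./rec_img.png',
    det=False,  # text detector (default algorithm: DBNet); disabled here
    rec=True,   # text recognizer (default model: CRNN)
    cls=True,   # direction classifier
)
for line in result:
    print(line)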
shell
[('实力活力', 0.7939199805259705)]
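2.1 Building the CRNN Network
2.1.1 backbone
- MobileNet is a family of lightweight deep neural networks proposed by Google for phones and other embedded devices; its core idea is the depthwise separable convolution. The family mainly includes MobileNet V1, MobileNet V2, and MobileNet V3.
- MobileNetV3 comes in two versions, MobileNetV3-Small and MobileNetV3-Large, aimed at low and high compute and storage budgets respectively.
- MobileNetV3 keeps lightweight structures such as depthwise separable convolutions and residual blocks. It is still composed of multiple modules, but each has been optimized and upgraded, including the bottleneck structure, the SE module, and the NL module (see the figure below).
- Overall, MobileNetV3 has two main innovations:
- A combination of complementary search techniques: resource-constrained NAS performs module-level search, and NetAdapt performs local search.
- Network structure improvements (see the figure below): the final average-pooling layer is moved earlier and the last convolution layer is removed, and the h-swish activation function is introduced.
- PaddleOCR uses MobileNetV3 as the backbone, and paddleocr/configs/rec/rec_icdar15_train.yml uses the MobileNetV3-Large version by default. As with VGG, the MobileNetV3 structure needs some adjustments, mainly to the downsampling.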
yaml
# paddleocr/configs/rec/rec_icdar15_train.yml
Architecture:
  model_type: rec
  algorithm: CRNN # use the CRNN model
  Transform:
  Backbone:
    name: MobileNetV3 # ppocr/modeling/backbones/rec_mobilenet_v3.py
    scale: 0.5 # channel widths are scaled by this factor
    model_name: large # use the large version of MobileNetV3 as the backbone
  Neck:
    name: SequenceEncoder # ppocr/modeling/necks/rnn.py
    encoder_type: rnn
    hidden_size: 96
  Head:
    name: CTCHead # ppocr/modeling/heads/rec_ctc_head.py
    fc_decay: 0
Loss:
  name: CTCLoss # ppocr/losses/rec_ctc_loss.py, loss computation
PostProcess:
  name: CTCLabelDecode # ppocr/postprocess/rec_postprocess.py, post-processing
Metric:
  name: RecMetric # ppocr/metrics/rec_metric.py, evaluation metric
  main_indicator: acc
python
# ppocr/modeling/backbones/rec_mobilenet_v3.py
from paddle import nn
from ppocr.modeling.backbones.det_mobilenet_v3 import (
ResidualUnit,
ConvBNLayer,
make_divisible,
)
__all__ = ["MobileNetV3"]
class MobileNetV3(nn.Layer):
def __init__(
self,
in_channels=3,
model_name="small",
scale=0.5,
large_stride=None,
small_stride=None,
disable_se=False,
**kwargs,
):
super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if small_stride is None:
small_stride = [2, 2, 2, 2]
if large_stride is None:
large_stride = [1, 2, 2, 2]
assert isinstance(
large_stride, list
), "large_stride type must " "be list but got {}".format(type(large_stride))
assert isinstance(
small_stride, list
), "small_stride type must " "be list but got {}".format(type(small_stride))
assert (
len(large_stride) == 4
), "large_stride length must be " "4 but got {}".format(len(large_stride))
assert (
len(small_stride) == 4
), "small_stride length must be " "4 but got {}".format(len(small_stride))
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", large_stride[0]], # no downsampling
[3, 64, 24, False, "relu", (large_stride[1], 1)], # step 2: downsample height by 2x, width unchanged
[3, 72, 24, False, "relu", 1],
[5, 72, 40, True, "relu", (large_stride[2], 1)], # step 3: downsample height by 2x, width unchanged
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1],
[3, 240, 80, False, "hardswish", 1],
[3, 200, 80, False, "hardswish", 1],
[3, 184, 80, False, "hardswish", 1],
[3, 184, 80, False, "hardswish", 1],
[3, 480, 112, True, "hardswish", 1],
[3, 672, 112, True, "hardswish", 1],
[5, 672, 160, True, "hardswish", (large_stride[3], 1)], # step 4: downsample height by 2x, width unchanged
[5, 960, 160, True, "hardswish", 1],
[5, 960, 160, True, "hardswish", 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", (small_stride[0], 1)],
[3, 72, 24, False, "relu", (small_stride[1], 1)],
[3, 88, 24, False, "relu", 1],
[5, 96, 40, True, "hardswish", (small_stride[2], 1)],
[5, 240, 40, True, "hardswish", 1],
[5, 240, 40, True, "hardswish", 1],
[5, 120, 48, True, "hardswish", 1],
[5, 144, 48, True, "hardswish", 1],
[5, 288, 96, True, "hardswish", (small_stride[3], 1)],
[5, 576, 96, True, "hardswish", 1],
[5, 576, 96, True, "hardswish", 1],
]
cls_ch_squeeze = 576
else:
raise NotImplementedError(
"mode[" + model_name + "_model] is not implemented!"
)
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert (
scale in supported_scale
), "supported scales are {} but input scale is {}".format(
supported_scale, scale
)
inplanes = 16
# conv1
self.conv1 = ConvBNLayer(
in_channels=in_channels,
out_channels=make_divisible(inplanes * scale),
kernel_size=3,
stride=2, # step 1: downsample both height and width by 2x
padding=1,
groups=1,
if_act=True,
act="hardswish",
)
i = 0
block_list = []
inplanes = make_divisible(inplanes * scale)
for k, exp, c, se, nl, s in cfg:
se = se and not self.disable_se
block_list.append(
ResidualUnit(
in_channels=inplanes,
mid_channels=make_divisible(scale * exp),
out_channels=make_divisible(scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=nl,
)
)
inplanes = make_divisible(scale * c)
i += 1
self.blocks = nn.Sequential(*block_list)
self.conv2 = ConvBNLayer(
in_channels=inplanes,
out_channels=make_divisible(scale * cls_ch_squeeze),
kernel_size=1,
stride=1,
padding=0,
groups=1,
if_act=True,
act="hardswish",
)
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) # step 5: downsample both height and width by 2x
self.out_channels = make_divisible(scale * cls_ch_squeeze)
def forward(self, x):
x = self.conv1(x)
x = self.blocks(x)
x = self.conv2(x)
x = self.pool(x)
return x
if __name__ == '__main__':
    import paddle

    x = paddle.rand((1, 3, 32, 100))
    net = MobileNetV3(in_channels=3, model_name='large', scale=0.5)
    # Expected output shape: [1, 480, 1, 25]
    print(net(x).shape)
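The five downsampling steps marked in the comments above can be traced with plain arithmetic. This is a sketch under one assumption: every strided convolution uses an odd kernel k with padding (k - 1) // 2, so its output size is (size - 1) // stride + 1. The function names are illustrative and do not exist in PaddleOCR.

```python
def strided_out(size, stride):
    """Output size of a conv with odd kernel k, padding (k - 1) // 2, and the given stride."""
    return (size - 1) // stride + 1


def rec_mobilenetv3_large_shape(height, width, large_stride=(1, 2, 2, 2)):
    """Trace (height, width) through the recognition variant of MobileNetV3-Large."""
    # step 1: conv1 has stride 2 on both axes
    height, width = strided_out(height, 2), strided_out(width, 2)
    # The first configured stride is a plain int, so it applies to both axes
    height = strided_out(height, large_stride[0])
    width = strided_out(width, large_stride[0])
    # steps 2-4: strides of the form (s, 1) shrink the height only
    for stride in large_stride[1:]:
        height = strided_out(height, stride)
    # step 5: 2x2 max pooling with stride 2 and no padding
    height = (height - 2) // 2 + 1
    width = (width - 2) // 2 + 1
    return height, width


if __name__ == "__main__":
    print(rec_mobilenetv3_large_shape(32, 100))  # (1, 25)
```

The height is reduced by a factor of 32 while the width is only reduced by 4, so a 32 × 100 input becomes a 1 × 25 feature map: 25 time steps for the sequence model.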
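2.1.2 neck
The neck converts the visual feature map output by the backbone into a sequence of 1-D feature vectors, feeds that sequence to the LSTM network, and outputs sequence features.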
python
# ppocr/modeling/necks/rnn.py
class SequenceEncoder(nn.Layer):
def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels
......
def forward(self, x):
if self.encoder_type != "svtr":
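# 1. The neck converts the backbone's visual feature map into a sequence of 1-D vectors
# (bs, channels, H, W) -> (bs, H * W, channels), i.e. (bs, seq_len, embedding_dim)
# after the reshape x = (bs, 25, 480), where 480 = channels(960) * scale(0.5)
x = self.encoder_reshape(x)
# 2. After the conversion, feed the sequence to the LSTM network to get sequence features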
# self.encoder
# EncoderWithRNN(
# (lstm): LSTM(480, 96, num_layers=2
# (0): BiRNN(
# (cell_fw): LSTMCell(480, 96)
# (cell_bw): LSTMCell(480, 96)
# )
# (1): BiRNN(
# (cell_fw): LSTMCell(192, 96)
# (cell_bw): LSTMCell(192, 96)
# )
# )
# )
if not self.only_reshape:
x = self.encoder(x) # [bs, 25, 96 * 2]
return x
else:
x = self.encoder(x)
x = self.encoder_reshape(x)
return x
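The reshape performed by the neck can be illustrated without any framework. The sketch below works on nested Python lists and assumes, as the comments above indicate, that the feature map height is 1; `im2seq` is a hypothetical stand-in for the reshape step, not PaddleOCR's actual `Im2Seq` class.

```python
def im2seq(feature_map):
    """Turn a (batch, channels, 1, width) nested list into (batch, width, channels)."""
    sequences = []
    for sample in feature_map:
        # Each channel must hold exactly one row, i.e. the height is 1
        assert all(len(channel) == 1 for channel in sample), "height must be 1"
        rows = [channel[0] for channel in sample]  # squeeze height: (channels, width)
        # Transpose so each time step (image column) holds one value per channel
        sequences.append([list(step) for step in zip(*rows)])
    return sequences


if __name__ == "__main__":
    # batch = 1, channels = 2, height = 1, width = 3
    feature_map = [[[[1, 2, 3]], [[4, 5, 6]]]]
    print(im2seq(feature_map))  # [[[1, 4], [2, 5], [3, 6]]]
```

Each image column becomes one time step whose feature vector collects that column's value from every channel, which is the (bs, seq_len, embedding_dim) layout the LSTM consumes.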
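2.1.3 head
The prediction head consists of a fully connected layer and a softmax, and computes the label probability distribution at each time step of the sequence features.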
python
class CTCHead(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
fc_decay=0.0004,
mid_channels=None,
return_feats=False,
**kwargs,
):
super(CTCHead, self).__init__()
if mid_channels is None:
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels
)
# Linear(in_features=192, out_features=37, dtype=float32)
self.fc = nn.Linear(
in_channels, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
)
else:
......
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats
def forward(self, x, targets=None):
if self.mid_channels is None:
# (bs, 25, 192) -> predicts (bs, 25, 37)
predicts = self.fc(x)
else:
x = self.fc1(x)
predicts = self.fc2(x)
if self.return_feats:
result = (x, predicts)
else:
result = predicts
if not self.training: # at inference time, apply softmax to get per-time-step class probabilities
predicts = F.softmax(predicts, axis=2)
result = predicts
return result
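2.2 Dataset Loading and Model Training
2.2.1 Downloading the dataset
A preprocessed copy of the icdar15 dataset is provided:
Download: https://pan.baidu.com/s/1VP2Y_IhxAUwQABDmbXrgIg (extraction code: ek25)
The dataset should have the following file structure: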
|-train_data
|-ic15_data
|- rec_gt_train.txt
|- train
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
|- rec_gt_test.txt
|- test
|- word_001.png
|- word_002.jpg
|- word_003.jpg
| ...
The txt files have the following content:
" image file name    image annotation "
train/word_1.png Genaxis Theatre
train/word_2.png [06]
...
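As a small illustration of this label format, the sketch below parses one line into an image path and its text. It assumes the two fields are separated by a tab character; `parse_label_line` is a hypothetical helper, not a PaddleOCR function.

```python
def parse_label_line(line, delimiter="\t"):
    """Split one line of a rec_gt_*.txt file into (image path, label text)."""
    # Split only on the first delimiter so the label keeps any later ones
    parts = line.rstrip("\r\n").split(delimiter, 1)
    if len(parts) != 2:
        raise ValueError(f"expected '<path>{delimiter}<label>', got: {line!r}")
    return parts[0], parts[1]


if __name__ == "__main__":
    print(parse_label_line("train/word_1.png\tGenaxis Theatre\n"))
    # ('train/word_1.png', 'Genaxis Theatre')
```

Labels may contain spaces (as in "Genaxis Theatre"), which is why splitting on whitespace would not work here.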
After downloading the dataset, copy paddleocr/configs/rec/rec_icdar15_train.yml to paddleocr\tests\configs and edit the copy:
yaml
......
Train: # update the training-set paths and related settings
  dataset:
    name: SimpleDataSet
    data_dir: D:/python/datas/cv/ic15_data
    label_file_list: ["D:/python/datas/cv/ic15_data/rec_gt_train.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True # whether to shuffle the data
    batch_size_per_card: 256 # batch size
    drop_last: True # whether to drop the last incomplete batch
    num_workers: 0 # number of data-loading worker processes
    use_shared_memory: False
Eval: # update the validation-set paths and related settings
  dataset:
    name: SimpleDataSet
    data_dir: D:/python/datas/cv/ic15_data
    label_file_list: ["D:/python/datas/cv/ic15_data/rec_gt_test.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 0
    use_shared_memory: False
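2.2.2 Model training and prediction
- Instead of the command line, I create a Python file in the paddleocr\tests directory to run training.
- With the Python file below, we can conveniently step through the source code.
- For training and evaluation details, see the official documentation: https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/ppocr/model_train/recognition.md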
python
def train_rec():
    from tools.train import program, set_seed, main

    # Source config file: paddleocr/configs/rec/rec_icdar15_train.yml
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    ############### Override config values (can also be edited in the yml file) ###############
    # Load the pretrained model
    # https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
    # or https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_train.tar
    config["Global"]["pretrained_model"] = r"D:\python\models\layout_ocr\rec_mv3_none_bilstm_ctc_v2.0_train\best_accuracy"
    # Dictionary path (only the 26 lowercase letters + 10 digits are supported here)
    config["Global"]["character_dict_path"] = r"D:\python\py_works\paddleocr\ppocr\utils\ic15_dict.txt"
    # Evaluation frequency
    config["Global"]["eval_batch_step"] = [0, 200]
    # Logging frequency
    config["Global"]["print_batch_step"] = 50
    # Number of training epochs
    config["Global"]["epoch_num"] = 1
    # Random seed
    seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
    set_seed(seed)
    ############### Model training ###############
    main(config, device, logger, vdl_writer, seed)


def infer_rec_img():
    # Load the model we trained ourselves
    from tools.infer_rec import main, program

    config, device, logger, vdl_writer = program.preprocess()
    config["Global"]["use_gpu"] = False
    config["Global"]["infer_img"] = r"D:/python/py_works/paddleocr/tests/rec_img_slow.png"
    config["Global"]["checkpoints"] = r"D:\python\py_works\paddleocr\tests\output\rec\ic15\best_accuracy"
    config["Global"]["character_dict_path"] = r"D:\python\py_works\paddleocr\ppocr\utils\ic15_dict.txt"
    # The add_config parameter was added so this can run from a Python file;
    # it does not exist in the upstream source
    main(add_config=(config, device, logger, vdl_writer))


if __name__ == '__main__':
    # Train the model
    train_rec()
    # Run inference with our own model
    # ppocr INFO: result: slow 0.863487958908081
    infer_rec_img()
2.2.3 Loading the dataset
python
# paddleocr/tools/train.py
def main(config, device, logger, vdl_writer, seed):
# init dist environment
if config["Global"]["distributed"]:
dist.init_parallel_env()
global_config = config["Global"]
# build dataloader
set_signal_handlers()
# 1. Create the dataloaders
train_dataloader = build_dataloader(config, "Train", device, logger, seed)
......
if config["Eval"]:
valid_dataloader = build_dataloader(config, "Eval", device, logger, seed)
else:
valid_dataloader = None
step_pre_epoch = len(train_dataloader)
# 2. Build the post-processing step
# build post process
post_process_class = build_post_process(config["PostProcess"], global_config)
# 3. Build the model
# build model
.....
model = build_model(config["Architecture"])
use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
logger.info("convert_sync_batchnorm")
model = apply_to_static(model, config, logger)
# 4. Build the loss function
# build loss
loss_class = build_loss(config["Loss"])
# 5. Build the optimizer
# build optim
optimizer, lr_scheduler = build_optimizer(
config["Optimizer"],
epochs=config["Global"]["epoch_num"],
step_each_epoch=len(train_dataloader),
model=model,
)
# 6. Build the evaluation metric
# build metric
eval_class = build_metric(config["Metric"])
......
# 7. Load the pretrained model
# load pretrain model
pre_best_model_dict = load_model(
config, model, optimizer, config["Architecture"]["model_type"]
)
if config["Global"]["distributed"]:
model = paddle.DataParallel(model)
# 8. Train the model
# start train
program.train(
config,
train_dataloader,
valid_dataloader,
device,
model,
loss_class,
optimizer,
lr_scheduler,
post_process_class,
eval_class,
pre_best_model_dict,
logger,
step_pre_epoch,
vdl_writer,
scaler,
amp_level,
amp_custom_black_list,
amp_custom_white_list,
amp_dtype,
)
- The dataset is loaded through build_dataloader:
python
# paddleocr/ppocr/data/__init__.py
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = [
"SimpleDataSet", # the dataset named in our config file
"LMDBDataSet",
"PGDataSet",
"PubTabDataSet",
"LMDBDataSetSR",
"LMDBDataSetTableMaster",
"MultiScaleDataSet",
"TextDetDataset",
"TextRecDataset",
"MSTextRecDataset",
"PubTabTableRecDataset",
"KieDataset",
"LaTeXOCRDataSet",
]
module_name = config[mode]["dataset"]["name"]
assert module_name in support_dict, Exception(
"DataSet only support {}".format(support_dict)
)
assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test."
# 1. Create the dataset
dataset = eval(module_name)(config, mode, logger, seed)
......
# 2. Create the data_loader
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=num_workers,
return_list=True,
use_shared_memory=use_shared_memory,
collate_fn=collate_fn,
)
return data_loader
- SimpleDataSet defines the data preprocessing:
- 1. ppocr.data.imaug.operators.DecodeImage decodes the raw image bytes into a numpy array of shape (height H, width W, channels C).
- 2. ppocr.data.imaug.label_ops.CTCLabelEncode encodes the label into character indices, producing data["length"], data["label"], and data["label_ace"].
- 3. ppocr.data.imaug.rec_img_aug.RecResizeImg resizes, normalizes, and pads the image; image shape = (3, 32, 100).
- 4. ppocr.data.imaug.operators.KeepKeys keeps only ['image', 'label', 'length'] in the output list.
python
class SimpleDataSet(Dataset):
......
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
try:
data_line = data_line.decode("utf-8")
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {"img_path": img_path, "label": label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data["img_path"], "rb") as f:
img = f.read()
data["image"] = img
data["ext_data"] = self.get_ext_data()
# Data transforms
# ppocr.data.imaug.operators.DecodeImage 1. decode the raw image bytes into a numpy array (height H, width W, channels C)
# ppocr.data.imaug.label_ops.CTCLabelEncode 2. encode the label into character indices (data["length"] + data["label"] + data["label_ace"])
# ppocr.data.imaug.rec_img_aug.RecResizeImg 3. resize + normalize + pad the image, image shape=(3, 32, 100)
# ppocr.data.imaug.operators.KeepKeys 4. keep only ['image', 'label', 'length'] in the output list
outs = transform(data, self.ops)
except:
......
return outs
def __len__(self):
return len(self.data_idx_order_list)
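The label-encoding step can be sketched in a few lines. This is a simplified illustration, not PaddleOCR's CTCLabelEncode: the 36-character dictionary, the lowercasing, the silent dropping of unknown characters, and the `max_text_length` of 25 are all assumptions made for the example.

```python
import string

# Assumed dictionary: 10 digits followed by 26 lowercase letters
CHARSET = string.digits + string.ascii_lowercase
# Index 0 is reserved for the CTC blank, so real characters start at 1
CHAR_TO_INDEX = {ch: i + 1 for i, ch in enumerate(CHARSET)}


def encode_label(text, max_text_length=25):
    """Map a label string to padded character indices plus its true length."""
    indices = [CHAR_TO_INDEX[ch] for ch in text.lower() if ch in CHAR_TO_INDEX]
    if not indices or len(indices) > max_text_length:
        return None  # sample cannot be used for training
    length = len(indices)
    padded = indices + [0] * (max_text_length - length)
    return {"label": padded, "length": length}


if __name__ == "__main__":
    encoded = encode_label("slow")
    print(encoded["label"][:4], encoded["length"])  # [29, 22, 25, 33] 4
```

The stored length matters later: the CTC loss uses it to ignore the zero padding at the end of each label.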
2.2.4 Post-processing
The post-processing step is built through build_post_process:
python
# ppocr/postprocess/rec_postprocess.py, post-processing
class CTCLabelDecode(BaseRecLabelDecode):
"""Convert between text-label and text-index"""
def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
# 1. Get the index and probability of the most likely class at each time step
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
# 2. Decode the indices (convert them to text)
text = self.decode(
preds_idx,
preds_prob,
is_remove_duplicate=True,
return_word_box=return_word_box,
)
if return_word_box:
for rec_idx, rec in enumerate(text):
wh_ratio = kwargs["wh_ratio_list"][rec_idx]
max_wh_ratio = kwargs["max_wh_ratio"]
rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
if label is None:
# 3. At inference time, return text directly: [('decoded text', mean character confidence)]
return text
# 4. During training, the label also needs to be decoded
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ["blank"] + dict_character
return dict_character
- The core function is decode in the BaseRecLabelDecode class:
python
# BaseRecLabelDecode class in ppocr/postprocess/rec_postprocess.py
def decode(
self,
text_index,
text_prob=None,
is_remove_duplicate=False,
return_word_box=False,
):
"""convert text-index into text-label."""
result_list = []
ignored_tokens = self.get_ignored_tokens() # tokens to ignore; [0] is the CTC blank
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
# 1. Collapse consecutive repeats: keep a position only if its index differs from the previous position's index
selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
# 2. Store the decoded characters in char_list
char_list = [
self.character[text_id] for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
# 3. Join char_list into a string
text = "".join(char_list)
if self.reverse: # for arabic rec
text = self.pred_reverse(text)
if return_word_box:
......
else:
# The confidence is the mean confidence of the recognized characters
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
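The decoding logic above can be mirrored in plain Python for a single sample. This is a minimal sketch of greedy CTC decoding with mean confidence; `greedy_ctc_decode` and its arguments are illustrative names, and the sketch omits batching, text reversal, and word boxes.

```python
def greedy_ctc_decode(probs, characters, blank=0):
    """Decode one sample.

    probs[t][k] is the probability of class k at time step t;
    characters[k] is the symbol for class k (characters[blank] is unused).
    Returns (text, mean confidence of the kept characters).
    """
    # Most likely class and its probability at every time step
    indices = [max(range(len(step)), key=step.__getitem__) for step in probs]
    confidences = [max(step) for step in probs]

    chars, kept = [], []
    for t, index in enumerate(indices):
        is_repeat = t > 0 and index == indices[t - 1]
        if is_repeat or index == blank:
            continue  # drop repeated predictions and blanks
        chars.append(characters[index])
        kept.append(confidences[t])

    score = sum(kept) / len(kept) if kept else 0.0
    return "".join(chars), score


if __name__ == "__main__":
    characters = ["blank", "a", "b"]
    probs = [
        [0.10, 0.80, 0.10],  # a
        [0.20, 0.70, 0.10],  # a again: merged with the previous step
        [0.90, 0.05, 0.05],  # blank separates the two a's
        [0.10, 0.60, 0.30],  # a
        [0.10, 0.20, 0.70],  # b
    ]
    text, score = greedy_ctc_decode(probs, characters)
    print(text, round(score, 4))  # aab 0.7
```

Only the time steps that survive the filtering contribute to the score, so long runs of confident blanks do not inflate the reported confidence.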
2.2.5 CTC Loss
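- The CTC loss is built through the build_loss function.
- The CRNN model uses CTC loss as its loss function. PaddlePaddle ships the common loss functions, so we only need to call the built-in implementation.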
python
# paddleocr/ppocr/losses/rec_ctc_loss.py
class CTCLoss(nn.Layer):
def __init__(self, use_focal_loss=False, **kwargs):
super(CTCLoss, self).__init__()
# blank is the CTC placeholder symbol that carries no character
self.loss_func = nn.CTCLoss(blank=0, reduction="none")
self.use_focal_loss = use_focal_loss
def forward(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
# Transpose the head's predictions into time-major order
# (bs, 25, 37) -> (25, bs, 37)
predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape
# [N, N, ..., N]: bs entries in total, each prediction sequence has length N
preds_lengths = paddle.to_tensor(
[N] * B, dtype="int64", place=paddle.CPUPlace()
)
# batch is a list
# batch[0]: bs preprocessed image Tensors
# batch[1]: bs encoded token sequences, i.e. the labels, shape = (bs, seq_len)
# batch[2]: the actual length of each of the bs token sequences (they are padded)
labels = batch[1].astype("int32")
label_lengths = batch[2].astype("int64")
# Compute the loss
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
if self.use_focal_loss:
weight = paddle.exp(-loss)
weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
weight = paddle.square(weight)
loss = paddle.multiply(loss, weight)
loss = loss.mean()
return {"loss": loss}
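The optional focal weighting in the `use_focal_loss` branch can be written out for a single sample. Since the CTC loss is the negative log of the label probability p, `exp(-loss)` recovers p and the weight becomes (1 - p)^2. The sketch below uses only the standard library; `focal_ctc` is an illustrative name.

```python
import math


def focal_ctc(loss):
    """Scale one per-sample CTC loss by (1 - exp(-loss)) ** 2."""
    probability = math.exp(-loss)        # p = exp(-loss), because loss = -log(p)
    weight = (1.0 - probability) ** 2    # small for confident (easy) samples
    return loss * weight


if __name__ == "__main__":
    for p in (0.9, 0.5, 0.1):
        loss = -math.log(p)
        print(f"p={p}: loss={loss:.4f}, focal loss={focal_ctc(loss):.4f}")
```

Samples the model already predicts with high probability are down-weighted strongly, so training focuses on the harder ones.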
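The remaining training details, such as building the optimizer, creating the evaluation metric, loading the pretrained model, and the training loop itself, can be found in the source code and are not repeated here.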