目录
- [OCR 基础理论](#OCR 基础理论)
- 文本检测
- 文本识别
- [端到端 OCR](#端到端 OCR)
- 文档版面分析
- 表格识别
- 文档理解
- [多语言 OCR](#多语言 OCR)
- 评估指标与数据集
- 应用与前沿
1. OCR 基础理论
1.1 什么是 OCR
OCR (Optical Character Recognition) 光学字符识别:
目标: 将图像中的文字转换为可编辑的文本
┌─────────────────────────────────────────────────────────────────┐
│ │
│ 输入图像 输出文本 │
│ ┌──────────────────┐ ┌──────────────────┐ │
│ │ 你好世界 │ OCR │ "你好世界" │ │
│ │ Hello World │ ──────► │ "Hello World" │ │
│ │ 12345 │ │ "12345" │ │
│ └──────────────────┘ └──────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
流程:
1. 文本检测: 定位文字区域
2. 文本识别: 识别文字内容
3. 后处理: 纠错、格式化
1.2 OCR 的挑战
OCR 面临的挑战:
1. 文本检测挑战:
├── 多方向: 水平、竖直、旋转文字
├── 多形状: 弯曲、不规则文字
├── 多尺度: 大字、小字
├── 密集: 文字紧密排列
└── 遮挡: 部分文字被遮挡
2. 文本识别挑战:
├── 字体多样: 印刷体、手写体
├── 背景复杂: 噪声、纹理
├── 模糊: 运动模糊、失焦
├── 光照: 反光、阴影
└── 多语言: 不同语言混合
3. 文档理解挑战:
├── 版面复杂: 多栏、表格、图文混排
├── 语义理解: 理解文字含义
└── 结构提取: 提取文档结构
1.3 OCR 技术演进
OCR 技术演进:
┌─────────────────────────────────────────────────────────────────┐
│ OCR 技术发展时间线 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 传统方法 (2010 之前): │
│ ├── 二值化 + 连通域分析 │
│ ├── 模板匹配 │
│ └── 特征工程 + 分类器 │
│ │
│ 深度学习时代 (2015+): │
│ ├── CNN 特征提取 │
│ ├── RNN 序列建模 │
│ └── CTC/Attention 解码 │
│ │
│ 端到端时代 (2020+): │
│ ├── Transformer 架构 │
│ ├── 预训练模型 │
│ └── 多模态文档理解 │
│ │
└─────────────────────────────────────────────────────────────────┘
2. 文本检测
2.1 文本检测概述
文本检测 (Text Detection):
目标: 定位图像中的文字区域
输出格式:
1. 水平框: [x, y, w, h]
2. 旋转框: [x, y, w, h, θ]
3. 多边形: [(x₁,y₁), (x₂,y₂), ..., (xₙ,yₙ)]
方法分类:
1. 基于回归: 直接回归边界框
2. 基于分割: 像素级预测文字区域
3. 混合方法: 结合回归和分割
2.2 CTPN
论文: "Detecting Text in Natural Image with Connectionist Text Proposal Network"
(Tian et al., 2016)
核心思想:
将文本检测转化为一系列小文本块的检测和连接
流程:
1. 检测小文本块 (宽度固定,高度可变)
2. 使用 RNN 建模文本块之间的关系
3. 连接文本块形成文本行
import torch
import torch.nn as nn
class CTPN(nn.Module):
"""
CTPN (Connectionist Text Proposal Network)
核心思想:
检测小文本块,然后连接成文本行
理论:
文本 = 一系列小文本块的序列
每个文本块宽度固定,高度自适应
使用 RNN 建模文本块之间的关系
"""
def __init__(self):
super().__init__()
# VGG 特征提取
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(),
nn.Conv2d(256, 256, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2)
)
# RNN (建模文本块序列)
self.rnn = nn.LSTM(256 * 3, 128, bidirectional=True, batch_first=True)
# 预测头
self.cls_head = nn.Linear(256, 2) # 文本/非文本
self.reg_head = nn.Linear(256, 3) # 偏移量 (dy, dh, side_refinement)
def forward(self, x):
"""
x: [B, C, H, W]
"""
# 特征提取
features = self.backbone(x) # [B, 256, H/8, W/8]
# 按列提取特征 (每个位置是一个文本块候选)
B, C, H, W = features.shape
features = features.permute(0, 3, 2, 1) # [B, W, H, C]
features = features.reshape(B * W, H, C) # [B*W, H, C]
# RNN 处理
rnn_out, _ = self.rnn(features)
# 预测
cls = self.cls_head(rnn_out) # [B*W, H, 2]
reg = self.reg_head(rnn_out) # [B*W, H, 3]
return cls, reg
"""
CTPN 的理论:
1. 小文本块检测:
每个小块宽度固定 (如 16 像素)
高度自适应 (垂直回归)
2. RNN 建模:
文本块从左到右排列
RNN 捕获文本块之间的关系
3. 连接策略:
相邻文本块的置信度和距离
连接形成文本行
优势: 擅长检测水平文本
局限: 难以检测弯曲文本
"""
2.3 EAST
论文: "EAST: An Efficient and Accurate Scene Text Detector" (Zhou et al., 2017)
核心思想:
单阶段检测器,直接预测文本框
输出:
1. 分数图: 每个位置是否是文本
2. 几何图: 文本框的几何信息 (距离到四边的距离)
优势:
- 速度快 (单阶段)
- 支持多方向文本
class EAST(nn.Module):
"""
EAST (Efficient and Accurate Scene Text Detector)
单阶段文本检测器
理论:
直接预测文本框的几何信息
分数图 + 几何图
支持多方向文本
"""
def __init__(self):
super().__init__()
# 特征提取 (U-Net 结构)
self.encoder = ResNetEncoder()
self.decoder = UNetDecoder()
# 预测头
self.score_head = nn.Sequential(
nn.Conv2d(32, 1, 1),
nn.Sigmoid()
)
self.geo_head = nn.Sequential(
nn.Conv2d(32, 4, 1), # 距离到四边的距离
nn.Sigmoid()
)
self.angle_head = nn.Sequential(
nn.Conv2d(32, 1, 1),
nn.Tanh()
)
def forward(self, x):
"""
x: [B, C, H, W]
返回:
score: [B, 1, H/4, W/4] 文本分数
geo: [B, 4, H/4, W/4] 几何信息
angle: [B, 1, H/4, W/4] 角度
"""
features = self.encoder(x)
features = self.decoder(features)
score = self.score_head(features)
geo = self.geo_head(features)
angle = self.angle_head(features)
return score, geo, angle
"""
EAST 的几何表示:
对于每个像素,预测到文本框四边的距离:
d_top, d_right, d_bottom, d_left
加上旋转角度 θ
可以重建旋转矩形:
4 个顶点坐标
优势:
支持任意方向的文本
计算高效
"""
2.4 DBNet
论文: "Real-time Scene Text Detection with Differentiable Binarization"
(Liao et al., 2020)
核心创新:
可微分二值化 (Differentiable Binarization)
传统方法: 固定阈值二值化
DBNet: 自适应阈值,端到端学习
class DBNet(nn.Module):
"""
DBNet (Differentiable Binarization Network)
核心创新:
可微分二值化
自适应阈值
理论:
传统: 固定阈值二值化 → 不够灵活
DBNet: 学习自适应阈值 → 更准确
"""
def __init__(self):
super().__init__()
# 特征提取
self.backbone = ResNetEncoder()
# 特征融合 (FPN)
self.fpn = FeaturePyramidNetwork()
# 预测头
self.prob_head = nn.Conv2d(256, 1, 3, padding=1) # 概率图
self.thresh_head = nn.Conv2d(256, 1, 3, padding=1) # 阈值图
def forward(self, x):
# 特征提取
features = self.backbone(x)
# 特征融合
fused = self.fpn(features)
# 预测
prob = torch.sigmoid(self.prob_head(fused))
thresh = torch.sigmoid(self.thresh_head(fused))
# 可微分二值化
binary = self.differentiable_binarize(prob, thresh)
return prob, thresh, binary
def differentiable_binarize(self, prob, thresh, k=50):
"""
可微分二值化
Binarize(p, t) = 1 / (1 + exp(-k * (p - t)))
当 k → ∞ 时,退化为标准二值化
但 k 有限时,可以计算梯度
"""
return torch.sigmoid(k * (prob - thresh))
"""
DBNet 的优势:
1. 自适应阈值:
每个像素有独立的阈值
适应不同的文本和背景
2. 端到端训练:
二值化操作可微分
可以端到端优化
3. 实时性:
单阶段检测
速度快
"""
3. 文本识别
3.1 文本识别概述
文本识别 (Text Recognition):
目标: 识别裁剪后的文本图像中的文字
输入: 文本区域图像 [H, W]
输出: 文本字符串 "Hello"
方法:
1. CTC-based: CTC 解码
2. Attention-based: 注意力解码
3. Seq2Seq: 序列到序列
3.2 CRNN
论文: "An End-to-End Trainable Neural Network for Image-based Sequence Recognition
and Its Application to Scene Text Recognition" (Shi et al., 2017)
核心思想:
CNN 提取特征 + RNN 建模序列 + CTC 解码
┌─────────────────────────────────────────────────────────────┐
│ │
│ 输入图像 → CNN 特征 → 特征序列 → RNN → CTC 解码 → 文本 │
│ │
│ [H, W] [C, H', W'] [T, D] [T, num_classes] │
│ │
└─────────────────────────────────────────────────────────────┘
class CRNN(nn.Module):
"""
CRNN (Convolutional Recurrent Neural Network)
CNN + RNN + CTC
理论:
CNN: 提取图像特征
RNN: 建模序列依赖
CTC: 解码为文本序列
"""
def __init__(self, num_classes, hidden_size=256):
super().__init__()
# CNN 特征提取
self.cnn = nn.Sequential(
nn.Conv2d(1, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(),
nn.Conv2d(256, 256, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d((2, 1), (2, 1)),
nn.Conv2d(256, 512, 3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Conv2d(512, 512, 3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
nn.MaxPool2d((2, 1), (2, 1)),
nn.Conv2d(512, 512, 2, padding=0),
nn.BatchNorm2d(512),
nn.ReLU()
)
# RNN 序列建模
self.rnn = nn.LSTM(512, hidden_size, bidirectional=True, num_layers=2, batch_first=True)
# 输出层
self.fc = nn.Linear(hidden_size * 2, num_classes)
def forward(self, x):
"""
x: [B, 1, H, W] 文本图像
返回: [B, T, num_classes] 每个时间步的类别概率
"""
# CNN 特征提取
conv = self.cnn(x) # [B, 512, 1, W']
# 重塑为序列
B, C, H, W = conv.shape
conv = conv.squeeze(2) # [B, 512, W']
conv = conv.permute(0, 2, 1) # [B, W', 512]
# RNN 序列建模
rnn_out, _ = self.rnn(conv) # [B, W', hidden*2]
# 输出
output = self.fc(rnn_out) # [B, W', num_classes]
return output
"""
CRNN 的 CTC 解码:
CTC (Connectionist Temporal Classification):
处理输入和输出不对齐的问题
CTC 空间:
输出序列长度 T 可能大于实际文本长度
引入空白符 blank
解码:
greedy: 每步取最大概率
beam search: 保留多个候选
例:
输出: [b, H, H, e, l, l, l, o, b]
解码: "Hello" (去重 + 去空白)
"""
3.3 Attention-based 识别
基于注意力的文本识别:
理论:
使用注意力机制自动对齐
每步关注图像的不同区域
优势:
- 不需要预定义对齐
- 可以处理不规则文本
- 更灵活
class AttentionRecognizer(nn.Module):
"""
基于注意力的文本识别器
理论:
使用注意力机制自动对齐
每步关注图像的不同区域
decoder_t = Attention(query_t, keys, values)
query_t = Decoder(hidden_t-1, prev_char)
"""
def __init__(self, num_classes, hidden_dim=256):
super().__init__()
# CNN 编码器
self.encoder = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU()
)
# 注意力解码器
self.embedding = nn.Embedding(num_classes, hidden_dim)
self.rnn = nn.LSTMCell(hidden_dim + 256, hidden_dim)
self.attention = AttentionModule(256, hidden_dim)
self.classifier = nn.Linear(hidden_dim, num_classes)
def forward(self, image, target=None, max_len=25):
"""
image: [B, C, H, W]
target: [B, T] (训练时使用 teacher forcing)
"""
# 编码
features = self.encoder(image) # [B, 256, H', W']
B, C, H, W = features.shape
features = features.view(B, C, -1).permute(0, 2, 1) # [B, H'*W', 256]
# 解码
hidden = (torch.zeros(B, 256), torch.zeros(B, 256))
input_char = torch.zeros(B, dtype=torch.long)
outputs = []
for t in range(max_len):
# 嵌入
emb = self.embedding(input_char) # [B, hidden_dim]
# 注意力
context = self.attention(features, hidden[0]) # [B, 256]
# RNN
rnn_input = torch.cat([emb, context], dim=1)
hidden = self.rnn(rnn_input, hidden)
# 分类
output = self.classifier(hidden[0])
outputs.append(output)
# 下一个输入
if target is not None:
input_char = target[:, t] # Teacher forcing
else:
input_char = output.argmax(dim=1)
return torch.stack(outputs, dim=1)
class AttentionModule(nn.Module):
"""注意力模块"""
def __init__(self, encoder_dim, decoder_dim):
super().__init__()
self.query = nn.Linear(decoder_dim, encoder_dim)
self.key = nn.Linear(encoder_dim, encoder_dim)
self.value = nn.Linear(encoder_dim, encoder_dim)
self.scale = encoder_dim ** 0.5
def forward(self, features, hidden):
"""
features: [B, N, encoder_dim]
hidden: [B, decoder_dim]
"""
Q = self.query(hidden).unsqueeze(1) # [B, 1, encoder_dim]
K = self.key(features) # [B, N, encoder_dim]
V = self.value(features) # [B, N, encoder_dim]
attn = torch.bmm(Q, K.transpose(1, 2)) / self.scale # [B, 1, N]
attn = torch.softmax(attn, dim=-1)
context = torch.bmm(attn, V).squeeze(1) # [B, encoder_dim]
return context
"""
Attention vs CTC:
CTC:
假设输出独立
解码简单
适合规则文本
Attention:
自动对齐
可以处理不规则文本
更灵活
实践中,两者常结合使用
"""
基于 Transformer 的文本识别:
理论:
使用 Transformer 替代 RNN
并行计算,更高效
自注意力建模全局依赖
class TransformerRecognizer(nn.Module):
"""
Transformer 文本识别器
理论:
CNN 提取视觉特征
Transformer 解码为文本序列
优势:
- 并行计算
- 全局依赖
- 更强的建模能力
"""
def __init__(self, num_classes, d_model=256, nhead=8, num_layers=6):
super().__init__()
# CNN 编码器
self.cnn = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(128, d_model, 3, padding=1),
nn.ReLU()
)
# 位置编码
self.pos_encoding = PositionalEncoding(d_model)
# Transformer 解码器
decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, d_model * 4)
self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
# 输出层
self.fc = nn.Linear(d_model, num_classes)
# 嵌入
self.embedding = nn.Embedding(num_classes, d_model)
def forward(self, image, target=None):
"""
image: [B, C, H, W]
target: [B, T] (训练时)
"""
# CNN 特征
features = self.cnn(image) # [B, d_model, H', W']
B, C, H, W = features.shape
features = features.view(B, C, -1).permute(2, 0, 1) # [N, B, d_model]
# 位置编码
features = self.pos_encoding(features)
# 解码
if target is not None:
# 训练时
tgt = self.embedding(target).permute(1, 0, 2) # [T, B, d_model]
tgt = self.pos_encoding(tgt)
# 因果掩码
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.shape[0])
output = self.decoder(tgt, features, tgt_mask=tgt_mask)
output = self.fc(output).permute(1, 0, 2) # [B, T, num_classes]
else:
# 推理时 (自回归)
output = self.autoregressive_decode(features)
return output
"""
Transformer 识别的优势:
1. 并行计算:
训练时可以并行解码
2. 全局依赖:
自注意力建模全局关系
3. 更强建模:
比 RNN 更强的序列建模能力
"""
4. 端到端 OCR
4.1 端到端 OCR 概述
端到端 OCR:
目标: 同时检测和识别文本
输入: 完整图像
输出: 文本位置 + 文本内容
方法:
1. 两阶段: 先检测,再识别
2. 单阶段: 同时检测和识别
4.2 Mask TextSpotter
Mask TextSpotter:
基于实例分割的端到端 OCR
流程:
1. 检测文本区域 (Mask R-CNN)
2. 提取文本特征
3. 识别文本内容
class MaskTextSpotter(nn.Module):
"""
Mask TextSpotter
基于实例分割的端到端 OCR
理论:
使用 Mask R-CNN 检测文本
同时进行文本识别
"""
def __init__(self, num_classes, num_chars):
super().__init__()
# Mask R-CNN 检测
self.detector = MaskRCNN(num_classes)
# 文本识别器
self.recognizer = AttentionRecognizer(num_chars)
def forward(self, images):
# 检测
detections, masks = self.detector(images)
# 识别每个检测到的文本
texts = []
for det in detections:
# 裁剪文本区域
text_image = self.crop_text(images, det)
# 识别
text = self.recognizer(text_image)
texts.append(text)
return detections, texts
"""
端到端 OCR 的优势:
1. 联合优化:
检测和识别联合训练
相互促进
2. 信息共享:
检测特征帮助识别
识别反馈帮助检测
"""
5. 文档版面分析
5.1 版面分析概述
文档版面分析 (Document Layout Analysis):
目标: 识别文档中的不同区域
区域类型:
- 文本块
- 标题
- 图片
- 表格
- 列表
- 页眉/页脚
方法:
1. 目标检测: 检测不同区域
2. 语义分割: 像素级分类
3. 实例分割: 区分不同实例
5.2 LayoutLM
论文: "LayoutLM: Pre-training of Text and Layout for Document Image Understanding"
(Xu et al., 2020)
核心创新:
联合建模文本、布局和图像
输入:
- 文本: OCR 识别的文字
- 布局: 文字的位置坐标
- 图像: 文档图像特征
预训练:
- 掩码语言模型 (MLM)
- 掩码图像模型 (MIM)
class LayoutLM(nn.Module):
"""
LayoutLM
联合建模文本、布局和图像
理论:
文本: 语义信息
布局: 空间位置
图像: 视觉特征
三者融合 → 文档理解
"""
def __init__(self, vocab_size, max_position_embeddings=512):
super().__init__()
# 文本嵌入
self.text_embedding = nn.Embedding(vocab_size, 768)
# 位置嵌入 (2D 布局)
self.x_position_embedding = nn.Embedding(max_position_embeddings, 768)
self.y_position_embedding = nn.Embedding(max_position_embeddings, 768)
self.h_position_embedding = nn.Embedding(max_position_embeddings, 768)
self.w_position_embedding = nn.Embedding(max_position_embeddings, 768)
# 图像嵌入
self.image_embedding = nn.Linear(768, 768)
# Transformer
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(768, 12, 3072),
num_layers=12
)
def forward(self, text_ids, bbox, image_features):
"""
text_ids: [B, N] 文本 token IDs
bbox: [B, N, 4] 边界框坐标 (x, y, x+w, y+h)
image_features: [B, 768] 图像特征
"""
# 文本嵌入
text_emb = self.text_embedding(text_ids)
# 布局嵌入
x_emb = self.x_position_embedding(bbox[:, :, 0])
y_emb = self.y_position_embedding(bbox[:, :, 1])
w_emb = self.w_position_embedding(bbox[:, :, 2] - bbox[:, :, 0])
h_emb = self.h_position_embedding(bbox[:, :, 3] - bbox[:, :, 1])
layout_emb = x_emb + y_emb + w_emb + h_emb
# 图像嵌入
image_emb = self.image_embedding(image_features).unsqueeze(1)
# 融合
embeddings = text_emb + layout_emb
embeddings = torch.cat([image_emb, embeddings], dim=1)
# Transformer
output = self.transformer(embeddings)
return output
"""
LayoutLM 的理论:
1. 多模态融合:
文本 + 布局 + 图像
2. 2D 位置编码:
使用 (x, y, w, h) 编码
保留空间信息
3. 预训练任务:
MLM: 预测被掩码的文字
MIM: 预测被掩码的图像区域
"""
6. 表格识别
6.1 表格识别概述
表格识别 (Table Recognition):
目标: 从文档图像中提取表格结构和内容
子任务:
1. 表格检测: 定位表格区域
2. 结构识别: 识别行、列、单元格
3. 内容提取: 提取单元格内的文字
挑战:
- 复杂表格结构
- 合并单元格
- 不规则表格
6.2 表格结构识别
class TableStructureRecognizer(nn.Module):
"""
表格结构识别器
识别表格的行、列、单元格结构
理论:
将表格结构识别转化为:
1. 行检测
2. 列检测
3. 单元格分割
"""
def __init__(self):
super().__init__()
# 特征提取
self.backbone = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2)
)
# 行检测
self.row_head = nn.Sequential(
nn.Conv2d(128, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 1, 1),
nn.Sigmoid()
)
# 列检测
self.col_head = nn.Sequential(
nn.Conv2d(128, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 1, 1),
nn.Sigmoid()
)
def forward(self, image):
"""
image: [B, C, H, W]
"""
features = self.backbone(image)
rows = self.row_head(features)
cols = self.col_head(features)
return rows, cols
"""
表格识别的方法:
1. 基于规则:
检测横线和竖线
简单但不鲁棒
2. 基于学习:
训练模型检测行列
更鲁棒
3. 端到端:
直接输出 HTML/LaTeX
TableFormer 等
"""
7. 文档理解
7.1 文档理解概述
文档理解 (Document Understanding):
目标: 理解文档的语义内容
任务:
1. 文档分类: 文档类型分类
2. 信息提取: 提取关键信息
3. 问答: 回答关于文档的问题
4. 摘要: 生成文档摘要
挑战:
- 多模态: 文本 + 图像 + 布局
- 长文档: 大量文本
- 复杂结构: 表格、列表、嵌套
论文: "DocFormer: Multi-modal Transformer for Document Understanding"
(Appalaraju et al., 2021)
核心创新:
多模态 Transformer
联合文本、视觉和空间特征
7.3 文档问答
class DocumentQA(nn.Module):
"""
文档问答模型
输入: 文档图像 + 问题
输出: 答案
理论:
理解文档内容
回答关于文档的问题
"""
def __init__(self, vocab_size):
super().__init__()
# 视觉编码器
self.visual_encoder = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(64, 256)
)
# 文本编码器
self.text_encoder = nn.Embedding(vocab_size, 256)
# Transformer
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(256, 8, 1024),
num_layers=6
)
# 答案预测
self.answer_head = nn.Linear(256, vocab_size)
def forward(self, image, question):
"""
image: [B, C, H, W]
question: [B, Q] 问题 token IDs
"""
# 视觉特征
visual_feat = self.visual_encoder(image).unsqueeze(1) # [B, 1, 256]
# 文本特征
text_feat = self.text_encoder(question) # [B, Q, 256]
# 融合
combined = torch.cat([visual_feat, text_feat], dim=1)
# Transformer
output = self.transformer(combined)
# 答案预测
answer = self.answer_head(output[:, 0]) # 使用 [CLS] token
return answer
"""
文档问答的挑战:
1. 多模态理解:
需要同时理解文本和图像
2. 长距离依赖:
答案可能在文档的不同部分
3. 精确定位:
需要精确定位答案位置
"""
8. 多语言 OCR
8.1 多语言挑战
多语言 OCR 的挑战:
1. 字符集多样:
拉丁字母、中文、阿拉伯文、日文等
字符数量差异大
2. 书写方向:
水平: 英文、中文
竖直: 传统中文
右到左: 阿拉伯文
3. 字符复杂度:
拉丁字母: 简单
中文: 复杂 (数千字符)
4. 连笔:
阿拉伯文、印地文等有连笔
8.2 多语言解决方案
class MultilingualOCR(nn.Module):
"""
多语言 OCR
理论:
使用统一模型处理多种语言
方法:
1. 共享特征提取器
2. 语言特定的解码器
3. 多任务学习
"""
def __init__(self, num_languages, num_chars_per_lang):
super().__init__()
# 共享特征提取器
self.shared_encoder = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2)
)
# 语言特定的解码器
self.decoders = nn.ModuleList([
nn.LSTM(128, 256, bidirectional=True, batch_first=True)
for _ in range(num_languages)
])
# 语言特定的分类器
self.classifiers = nn.ModuleList([
nn.Linear(512, num_chars)
for num_chars in num_chars_per_lang
])
def forward(self, image, language_id):
# 共享特征
features = self.shared_encoder(image)
# 语言特定解码
B, C, H, W = features.shape
features = features.view(B, C, -1).permute(0, 2, 1)
decoder = self.decoders[language_id]
output, _ = decoder(features)
classifier = self.classifiers[language_id]
output = classifier(output)
return output
"""
多语言 OCR 的策略:
1. 统一模型:
一个模型处理所有语言
简单但可能牺牲性能
2. 语言检测 + 专用模型:
先检测语言
再用专用模型识别
3. 混合方法:
共享特征 + 语言特定解码
"""
9. 评估指标与数据集
9.1 评估指标
┌─────────────────────────────────────────────────────────────────────┐
│ OCR 评估指标 │
├─────────────────┬───────────────────────────────────────────────────┤
│ 任务 │ 指标 │
├─────────────────┼───────────────────────────────────────────────────┤
│ 文本检测 │ IoU (交并比) │
│ │ Precision, Recall, F1 │
│ │ DetE (检测精度) │
├─────────────────┼───────────────────────────────────────────────────┤
│ 文本识别 │ 字符准确率 (Character Accuracy) │
│ │ 单词准确率 (Word Accuracy) │
│ │ 编辑距离 (Edit Distance) │
├─────────────────┼───────────────────────────────────────────────────┤
│ 端到端 OCR │ 端到端 F1 │
│ │ 完全匹配率 (Exact Match) │
├─────────────────┼───────────────────────────────────────────────────┤
│ 表格识别 │ TEDS (Tree-Edit-Distance-based Similarity) │
│ │ IoU (单元格) │
├─────────────────┼───────────────────────────────────────────────────┤
│ 文档理解 │ 准确率 (分类/问答) │
│ │ F1 (信息提取) │
└─────────────────┴───────────────────────────────────────────────────┘
class OCRMetrics:
"""OCR 评估指标"""
@staticmethod
def character_accuracy(pred, gt):
"""
字符准确率
正确识别的字符数 / 总字符数
"""
correct = sum(p == g for p, g in zip(pred, gt))
return correct / max(len(gt), 1)
@staticmethod
def word_accuracy(pred, gt):
"""
单词准确率
完全匹配的单词数 / 总单词数
"""
pred_words = pred.split()
gt_words = gt.split()
correct = sum(p == g for p, g in zip(pred_words, gt_words))
return correct / max(len(gt_words), 1)
@staticmethod
def edit_distance(pred, gt):
"""
编辑距离 (Levenshtein Distance)
将 pred 转换为 gt 所需的最少编辑操作数
"""
m, n = len(pred), len(gt)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m + 1):
dp[i][0] = i
for j in range(n + 1):
dp[0][j] = j
for i in range(1, m + 1):
for j in range(1, n + 1):
if pred[i-1] == gt[j-1]:
dp[i][j] = dp[i-1][j-1]
else:
dp[i][j] = min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1]) + 1
return dp[m][n]
@staticmethod
def detection_iou(pred_box, gt_box):
"""
检测 IoU
"""
x1 = max(pred_box[0], gt_box[0])
y1 = max(pred_box[1], gt_box[1])
x2 = min(pred_box[2], gt_box[2])
y2 = min(pred_box[3], gt_box[3])
intersection = max(0, x2 - x1) * max(0, y2 - y1)
area_pred = (pred_box[2] - pred_box[0]) * (pred_box[3] - pred_box[1])
area_gt = (gt_box[2] - gt_box[0]) * (gt_box[3] - gt_box[1])
union = area_pred + area_gt - intersection
return intersection / union if union > 0 else 0
"""
评估指标的选择:
检测: IoU (阈值通常 0.5, 0.75)
识别: 字符准确率 (最常用)
端到端: 端到端 F1
"""
9.2 数据集
┌─────────────────────────────────────────────────────────────────────┐
│ OCR 数据集 │
├─────────────────┬───────────────────────────────────────────────────┤
│ 数据集 │ 说明 │
├─────────────────┼───────────────────────────────────────────────────┤
│ ICDAR 2015 │ 场景文本检测+识别,多方向 │
│ ICDAR 2019 │ 文档分析,表格识别 │
│ Total-Text │ 弯曲文本检测+识别 │
│ CTW │ 中文街景文本 │
│ RCTW │ 中文场景文本 │
│ PubTabNet │ 表格识别 │
│ DocVQA │ 文档问答 │
│ FUNSD │ 表单理解 │
│ XFUND │ 多语言文档理解 │
└─────────────────┴───────────────────────────────────────────────────┘
10. 应用与前沿
10.1 应用领域
OCR 与文档智能的应用:
1. 金融:
票据识别
身份证识别
银行卡识别
2. 医疗:
病历识别
处方识别
3. 政务:
证件识别
档案数字化
4. 教育:
试卷批改
手写识别
5. 物流:
快递单识别
地址识别
6. 保险:
保单识别
理赔单据
10.2 前沿方向
OCR 与文档智能的前沿:
1. 大模型文档理解:
使用 LLM 理解文档
多模态文档模型
2. 端到端系统:
检测+识别+理解一体化
3. 少样本 OCR:
新字体/新语言的快速适应
4. 手写识别:
复杂手写体识别
5. 古籍识别:
历史文献数字化
6. 实时 OCR:
移动端实时识别
附录
A. 发展时间线
2015 ──┬── CRNN (CNN+RNN+CTC)
│
2016 ──┼── CTPN (文本检测)
│
2017 ──┼── EAST (单阶段检测)
│
2019 ──┼── Transformer 文本识别
│
2020 ──┼── LayoutLM (文档理解)
│
2021 ──┼── DocFormer (多模态文档)
│
2022+ ──┴── 大模型文档理解
B. 核心公式速查
| 公式 |
含义 |
| CTC: P(y |
x) = Σ_π P(π |
| IoU = |A∩B| / |A∪B| |
交并比 |
| Edit(pred, gt) |
编辑距离 |
| DB: B(p,t) = σ(k(p-t)) |
可微分二值化 |