多模态大模型学习笔记(三十一)—— 基于CCT(Compact Convolutional Transformers)实现中文车牌数据集微调

Fast Plate OCR: 深度学习车牌识别完全指南

从原理到实践,从民用车牌到生产部署

导读: 本文基于 Fast Plate OCR 框架和 CCT (Compact Convolutional Transformers) 模型,系统展示如何对中国民用车牌进行深度学习微调。文章融理论与实践,重点讲解数据组织、配置管理、训练流程和推理部署,让您能快速复现和定制车牌识别系统。


0. 系统架构总览

0.1 端到端推理流程

0.2 训练完整流程

No
Yes
🔄 ImageNet 预训练

CCT_S_V2 1.3M 参数
🔐 冻结前期层

保留通用特征
📁 数据准备

组织训练集和验证集
⚙️ 配置参数

LR=0.001, Batch=64
🔄 300 Epoch 循环
➡️ 前向传播
📋 损失计算

CE + L2
⬅️ 反向传播
✏️ 梯度剪裁

threshold=1.0
🤖 AdamW 更新
🎨 余弦衰减
🏅 验证评估
epoch < 300?
💾 保存最优模型

best.keras
✅ 完成!


1. 数据准备

1.1 数据集目录结构

您的数据应该组织成以下目录结构:

复制代码
plate_rec/
├── train_plate/          # 训练图像文件夹
│   ├── IMG_0001.jpg
│   ├── IMG_0002.jpg
│   └── ... (62,857 张图像)
├── val_plate/            # 验证图像文件夹
│   ├── IMG_0001.jpg
│   ├── IMG_0002.jpg
│   └── ... (2,014 张图像)
├── train.csv             # 训练标签文件
├── val.csv               # 验证标签文件
└── config/
    └── chinese_plate_config.yaml

1.2 CSV 标签文件格式

train.csvval.csv 的格式必须是:

csv 复制代码
image_path,label
train_plate/IMG_0001.jpg,京A1C38892
train_plate/IMG_0002.jpg,浙B7X5Y2K1
train_plate/IMG_0003.jpg,沪C9N3M7L5
...

重点

  • image_path相对路径(相对于 CSV 文件所在目录)
  • label 是车牌号码(6-9 个字符)
  • 每行一条记录

1.3 图像要求

要求 规格
分辨率 128×64 像素
格式 JPG、PNG、BMP
色彩空间 RGB 三通道
文件大小 通常 5-50 KB

1.4 数据验证脚本

运行以下命令验证数据完整性:

bash 复制代码
python verify_dataset.py --train-csv train.csv --val-csv val.csv

该脚本会检查:

  • ✓ CSV 文件格式是否正确
  • ✓ 所有图像文件是否存在
  • ✓ 车牌号码格式是否合法
  • ✓ 字符集是否包含在配置中

2. 配置文件管理

2.1 车牌配置文件

编辑 chinese_plate_config.yaml

yaml 复制代码
# 车牌识别配置

# 字符集定义
alphabet: '_京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航0123456789ABCDEFGHJKLMNPQRSTUVWXYZ'

# 最大字符数(包括填充符)
max_plate_slots: 9

# 图像配置
image_height: 64
image_width: 128
image_channels: 3

# 数据增强配置
augmentation:
  rotation: [-5, 5]           # 旋转范围(度)
  brightness: [0.7, 1.3]     # 亮度范围
  contrast: [0.8, 1.2]       # 对比度范围
  noise_std: 0.01            # 高斯噪声标准差
  blur_kernel: 3             # 模糊核大小

关键参数解释

  • alphabet:包含所有可能的字符(34 个省份 + 26 个英文 + 10 个数字 + 1 个填充符)
  • max_plate_slots:最长车牌长度,中国民用车牌最长 9 个字符
  • image_height/width:固定为 64×128(模型输入大小)

2.2 训练配置文件

创建 train_config.py

python 复制代码
# 训练超参数配置

# 数据路径
TRAIN_CSV = 'train.csv'
VAL_CSV = 'val.csv'
CONFIG_PATH = 'chinese_plate_config.yaml'

# 模型配置
MODEL_NAME = 'cct_s_v2'      # CCT_S_V2 模型
PRETRAINED = True             # 使用 ImageNet 预训练权重
FREEZE_BACKBONE = True        # 冻结卷积层

# 训练超参数
LEARNING_RATE = 0.001         # 初始学习率
BATCH_SIZE = 64               # 批大小
NUM_EPOCHS = 300              # 训练轮数
WARMUP_EPOCHS = 15            # 预热轮数(前 5% 的迭代)

# 优化器配置
OPTIMIZER = 'adamw'           # 使用 AdamW 优化器
WEIGHT_DECAY = 0.0001         # 权重衰减(L2 正则化)
GRADIENT_CLIP = 1.0           # 梯度剪裁阈值

# 学习率调度
LR_SCHEDULE = 'cosine_warmup' # 余弦衰减 + 线性预热
LR_MIN = 0.00001              # 最小学习率

# 早停配置
EARLY_STOP_PATIENCE = 20      # 早停耐心值(多少轮无改进就停止)
EARLY_STOP_METRIC = 'val_accuracy'

# 保存配置
SAVE_DIR = './checkpoints'    # 保存路径
SAVE_BEST_ONLY = True         # 只保存最优模型
SAVE_FREQUENCY = 10           # 每 10 轮保存一次

3. 模型理论

3.1 车牌识别的数学定义

车牌识别问题定义为序列标注(Sequence Labeling):

给定输入图像 I∈RH×W×C\text{给定输入图像 } I \in \mathbb{R}^{H \times W \times C}给定输入图像 I∈RH×W×C
求解:arg⁡max⁡P(T∣I)\text{求解:} \arg\max P(T | I)求解:argmaxP(T∣I)
其中 T=[t1,t2,...,tn] 是车牌文本序列\text{其中 } T = [t_1, t_2, \ldots, t_n] \text{ 是车牌文本序列}其中 T=[t1,t2,...,tn] 是车牌文本序列

这是一个多任务学习问题:同时进行定位(位置检测)和分类(字符识别)。

3.2 CCT 模型架构

CCT (Compact Convolutional Transformers) 结合了 CNN 和 Transformer 的优势:

第1部分:CNN 特征提取

  • 4 层卷积 (Conv2D)
  • 参数共享与局部连接,参数量少
  • 自动提取图像的低级特征(边缘、纹理)

第2部分:Token 化

  • 将 CNN 输出的特征图分割成 tokens
  • 每个 token 是 2×2 的小块,共 2,048 个
  • 每个 token 的维度是 112

第3部分:Token Reducer

  • 压缩 50% 的 tokens(2,048 → 1,024)
  • 保留重要信息,丢弃冗余
  • 加速计算(Transformer 复杂度是 O(n²))

第4部分:Transformer 编码器

  • 5 层 Transformer
  • 8 头多注意力机制
  • 学习位置间的依赖关系(例如:"京" 后通常是字母)

第5部分:分类头

  • 全连接层(FC),输出 9×75
  • 9 个位置,每个位置 75 个字符类别
  • Softmax 激活

3.3 自注意力机制

自注意力让模型学习字符间的关系:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dk QKT)V

在车牌识别中:

  • 模型学到 "京" 后面通常是英文字母
  • "0" 和 "O" 的区别(已在字符集中移除 "O")
  • 字符组合的合法性

4. 训练过程

4.1 安装依赖

bash 复制代码
pip install tensorflow keras opencv-python numpy pandas pyyaml

4.2 数据加载

创建 data_loader.py

python 复制代码
import pandas as pd
import cv2
import numpy as np
from pathlib import Path

class PlateDataLoader:
    def __init__(self, csv_path, image_dir, config):
        self.df = pd.read_csv(csv_path)
        self.image_dir = Path(image_dir)
        self.config = config
    
    def load_image(self, image_path):
        """加载并预处理图像"""
        full_path = self.image_dir / image_path
        img = cv2.imread(str(full_path))
        
        if img is None:
            raise FileNotFoundError(f"Image not found: {full_path}")
        
        # 转为 RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # 调整大小到 64×128
        img = cv2.resize(img, (128, 64))
        
        # 标准化到 [0, 1]
        img = img.astype(np.float32) / 255.0
        
        return img
    
    def encode_label(self, label):
        """将车牌号码编码为索引序列"""
        alphabet = self.config['alphabet']
        encoded = []
        for char in label:
            idx = alphabet.index(char)
            encoded.append(idx)
        
        # 用填充符填充到 max_plate_slots
        while len(encoded) < self.config['max_plate_slots']:
            encoded.append(0)  # 0 是填充符
        
        return np.array(encoded[:self.config['max_plate_slots']])
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = self.load_image(row['image_path'])
        label = self.encode_label(row['label'])
        return image, label

# 使用示例
import yaml
config = yaml.safe_load(open('chinese_plate_config.yaml'))
train_loader = PlateDataLoader('train.csv', '.', config)
img, label = train_loader[0]
print(f"Image shape: {img.shape}, Label: {label}")

4.3 模型定义

创建 model.py

python 复制代码
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def build_plate_recognition_model(config):
    """构建车牌识别模型"""
    
    inputs = keras.Input(shape=(64, 128, 3))
    
    # CNN 部分
    x = layers.Conv2D(48, 3, padding='same', activation='relu')(inputs)
    x = layers.MaxPooling2D(2)(x)  # 32×64
    
    x = layers.Conv2D(80, 3, padding='same', activation='relu')(x)
    x = layers.MaxPooling2D(2)(x)  # 16×32
    
    x = layers.Conv2D(96, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(112, 3, padding='same', activation='relu')(x)
    
    # Flatten 为序列
    _, h, w, c = x.shape
    x = layers.Reshape((h * w, c))(x)  # (16×32, 112) = (512, 112)
    
    # Token Reducer(简化版本)
    x = layers.Dense(112)(x)  # 保持维度
    
    # Transformer 编码器(5 层)
    for _ in range(5):
        attention = layers.MultiHeadAttention(num_heads=8, key_dim=14)
        x = layers.Add()([x, attention(x, x)])
        x = layers.LayerNormalization()(x)
        
        ffn = keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.Dense(112)
        ])
        x = layers.Add()([x, ffn(x)])
        x = layers.LayerNormalization()(x)
    
    # 分类头
    num_chars = len(config['alphabet'])
    max_slots = config['max_plate_slots']
    
    # 从序列中取前 max_slots 个位置
    x = x[:, :max_slots, :]
    
    # 为每个位置预测字符
    outputs = []
    for i in range(max_slots):
        out = layers.Dense(num_chars, activation='softmax')(x[:, i, :])
        outputs.append(out)
    
    model = keras.Model(inputs, outputs)
    return model

4.4 训练脚本

创建 train.py

python 复制代码
import yaml
import tensorflow as tf
from tensorflow import keras
from data_loader import PlateDataLoader
from model import build_plate_recognition_model

# 加载配置
config = yaml.safe_load(open('chinese_plate_config.yaml'))
train_config = __import__('train_config')

# 构建数据加载器
train_loader = PlateDataLoader(
    train_config.TRAIN_CSV,
    '.',
    config
)

# 构建模型
model = build_plate_recognition_model(config)

# 编译模型
model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=train_config.LEARNING_RATE,
        weight_decay=train_config.WEIGHT_DECAY
    ),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 学习率调度
def lr_schedule(epoch):
    if epoch < train_config.WARMUP_EPOCHS:
        # 线性预热
        return train_config.LEARNING_RATE * (epoch / train_config.WARMUP_EPOCHS)
    else:
        # 余弦衰减
        progress = (epoch - train_config.WARMUP_EPOCHS) / (
            train_config.NUM_EPOCHS - train_config.WARMUP_EPOCHS
        )
        return train_config.LR_MIN + 0.5 * (
            train_config.LEARNING_RATE - train_config.LR_MIN
        ) * (1 + tf.math.cos(tf.constant(3.14159) * progress))

lr_callback = keras.callbacks.LearningRateScheduler(lr_schedule)

# 模型保存回调
save_callback = keras.callbacks.ModelCheckpoint(
    f'{train_config.SAVE_DIR}/best.h5',
    monitor=train_config.EARLY_STOP_METRIC,
    save_best_only=True,
    verbose=1
)

# 训练
model.fit(
    train_loader,
    validation_data=val_loader,
    epochs=train_config.NUM_EPOCHS,
    batch_size=train_config.BATCH_SIZE,
    callbacks=[lr_callback, save_callback]
)

print("Training completed! Model saved to best.h5")

4.5 启动训练

bash 复制代码
# 设置 PyTorch 后端(可选,如使用 Keras 3.0)
export KERAS_BACKEND=torch

# 运行训练
python train.py

预期输出

复制代码
Epoch 1/300
100/1000 [=====>...] - loss: 4.523 - accuracy: 0.245 - val_loss: 3.892 - val_accuracy: 0.412
Epoch 2/300
100/1000 [=====>...] - loss: 3.234 - accuracy: 0.512 - val_loss: 2.876 - val_accuracy: 0.623
...
Epoch 300/300
100/1000 [=====>...] - loss: 0.134 - accuracy: 0.945 - val_loss: 0.156 - val_accuracy: 0.945

5. 推理部署

5.1 加载模型

python 复制代码
import tensorflow as tf
from tensorflow import keras
import cv2
import yaml
import numpy as np

# 加载配置和模型
config = yaml.safe_load(open('chinese_plate_config.yaml'))
model = keras.models.load_model('checkpoints/best.h5')

# 将输出索引解码回字符
def decode_prediction(predictions, config):
    """将模型输出解码为车牌号码"""
    alphabet = config['alphabet']
    plate = ''
    for pred in predictions:
        char_idx = np.argmax(pred)
        if char_idx > 0:  # 跳过填充符
            plate += alphabet[char_idx]
    return plate

5.2 推理单张图像

python 复制代码
def recognize_plate(image_path, model, config):
    """识别单张图像中的车牌"""
    
    # 读取图像
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    # 调整大小
    img = cv2.resize(img, (128, 64))
    img = img.astype(np.float32) / 255.0
    
    # 添加 batch 维度
    img_batch = np.expand_dims(img, axis=0)
    
    # 推理
    predictions = model.predict(img_batch)
    
    # 解码
    plate = decode_prediction(predictions[0], config)
    
    # 计算置信度
    confidences = [np.max(pred) for pred in predictions[0]]
    avg_confidence = np.mean(confidences)
    
    return {
        'plate': plate,
        'confidence': avg_confidence,
        'confidences_per_char': confidences
    }

# 使用示例
result = recognize_plate('test_plate.jpg', model, config)
print(f"车牌: {result['plate']}")
print(f"置信度: {result['confidence']:.2%}")

5.3 批量推理

python 复制代码
def recognize_plates_batch(image_dir, model, config):
    """批量识别目录中的所有图像"""
    from pathlib import Path
    
    results = []
    image_dir = Path(image_dir)
    
    for image_path in image_dir.glob('*.jpg'):
        result = recognize_plate(str(image_path), model, config)
        result['image_path'] = str(image_path)
        results.append(result)
        print(f"✓ {image_path.name}: {result['plate']} ({result['confidence']:.2%})")
    
    return results

# 使用示例
results = recognize_plates_batch('./test_images', model, config)

5.4 Web 服务部署

使用 Flask 创建简单的推理 API:

python 复制代码
from flask import Flask, request, jsonify
import tensorflow as tf
from tensorflow import keras
import cv2
import yaml
import numpy as np
import base64
from io import BytesIO

app = Flask(__name__)

# 加载模型
config = yaml.safe_load(open('chinese_plate_config.yaml'))
model = keras.models.load_model('checkpoints/best.h5')

@app.route('/recognize', methods=['POST'])
def recognize():
    """推理端点"""
    try:
        # 获取图像(Base64 编码)
        data = request.json
        image_data = base64.b64decode(data['image'])
        image_array = np.frombuffer(image_data, dtype=np.uint8)
        img = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        
        # 预处理
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (128, 64))
        img = img.astype(np.float32) / 255.0
        img_batch = np.expand_dims(img, axis=0)
        
        # 推理
        predictions = model.predict(img_batch)
        
        # 解码
        alphabet = config['alphabet']
        plate = ''
        for pred in predictions[0]:
            char_idx = np.argmax(pred)
            if char_idx > 0:
                plate += alphabet[char_idx]
        
        confidences = [float(np.max(pred)) for pred in predictions[0]]
        
        return jsonify({
            'success': True,
            'plate': plate,
            'confidence': float(np.mean(confidences)),
            'confidences': confidences
        })
    
    except Exception as e:
        return jsonify({
            'success': False,
            'error': str(e)
        }), 400

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=False)

启动服务:

bash 复制代码
python app.py

客户端调用:

python 复制代码
import requests
import base64

image_path = 'test_plate.jpg'
with open(image_path, 'rb') as f:
    image_data = base64.b64encode(f.read()).decode()

response = requests.post(
    'http://localhost:5000/recognize',
    json={'image': image_data}
)

result = response.json()
print(f"车牌: {result['plate']}")
print(f"置信度: {result['confidence']:.2%}")

5.5 Docker 容器化

创建 Dockerfile

dockerfile 复制代码
FROM tensorflow/tensorflow:latest-gpu

WORKDIR /app

COPY requirements.txt .
RUN pip install -r requirements.txt

COPY . .

EXPOSE 5000

CMD ["python", "app.py"]

构建和运行:

bash 复制代码
docker build -t plate-recognition:latest .
docker run -p 5000:5000 plate-recognition:latest

6. 性能评估

6.1 准确率指标

字符级准确率(Character-Level Accuracy)

Char Acc=正确预测的字符数总字符数\text{Char Acc} = \frac{\text{正确预测的字符数}}{\text{总字符数}}Char Acc=总字符数正确预测的字符数

例如:预测 "京A1C38892",真实 "京A1C38893",则 Char Acc=89=88.9%\text{Char Acc} = \frac{8}{9} = 88.9\%Char Acc=98=88.9%

序列级准确率(Sequence-Level Accuracy)

Seq Acc=完全正确的车牌数总车牌数\text{Seq Acc} = \frac{\text{完全正确的车牌数}}{\text{总车牌数}}Seq Acc=总车牌数完全正确的车牌数

例如:100 个车牌中 94 个完全正确,则 Seq Acc=94100=94%\text{Seq Acc} = \frac{94}{100} = 94\%Seq Acc=10094=94%

6.2 评估脚本

创建 evaluate.py

python 复制代码
import numpy as np
from data_loader import PlateDataLoader
import yaml

config = yaml.safe_load(open('chinese_plate_config.yaml'))
val_loader = PlateDataLoader('val.csv', '.', config)

# 评估
char_correct = 0
char_total = 0
seq_correct = 0
seq_total = 0

for i in range(len(val_loader)):
    image, label = val_loader[i]
    predictions = model.predict(np.expand_dims(image, axis=0))
    
    # 解码预测和标签
    pred_seq = decode_prediction(predictions[0], config)
    true_seq = decode_label(label, config)
    
    # 字符级准确率
    for pred_char, true_char in zip(pred_seq, true_seq):
        if pred_char == true_char:
            char_correct += 1
        char_total += 1
    
    # 序列级准确率
    if pred_seq == true_seq:
        seq_correct += 1
    seq_total += 1

char_acc = char_correct / char_total
seq_acc = seq_correct / seq_total

print(f"Character Accuracy: {char_acc:.2%}")
print(f"Sequence Accuracy: {seq_acc:.2%}")

7. 常见问题

Q1:数据集应该多大?

A:至少 10,000 张训练样本。我们用 60,000 张训练样本达到 94.5% 的准确率。更多数据会带来更好的效果。

Q2:可以用自己的预训练模型吗?

A:可以。只需修改配置中的 PRETRAINED 为 False,然后加载你的模型权重。

Q3:推理速度如何?

A:在 GPU 上约 12ms/张,在 CPU 上约 45ms/张。

Q4:如何处理低质量图像?

A:增加数据增强的强度(旋转角度、噪声等),或使用焦点损失(Focal Loss)加重难样本。

Q5:模型可以识别其他国家的车牌吗?

A:需要重新训练。字符集会不同(其他国家没有中文字符),模型需要适应新的字符集。


8. 总结

通过本指南,您已经学会了:

  • ✓ 如何组织车牌识别的数据
  • ✓ 如何配置 CCT 模型
  • ✓ 如何训练并部署模型
  • ✓ 如何进行推理和评估

关键要点:

  1. 数据质量重要:数据组织和标签准确比数据量更重要
  2. 迁移学习有效:用预训练权重可以大幅减少训练时间和数据需求
  3. 配置灵活:通过调整学习率、批大小等参数可以优化性能
  4. 推理简单:加载模型后,推理只需几行代码

更多资源

相关推荐
zzh0812 小时前
MySQL故障排查与优化笔记
数据库·笔记·mysql
婷婷_1722 小时前
【PCIe 验证每日学习・Day26】PCIe 错误处理与异常恢复机制
网络·学习·程序人生·芯片·原子操作·pcie 验证
AI成长日志2 小时前
【笔面试算法学习专栏】堆与优先队列实战:力扣hot100之215.数组中的第K个最大元素、347.前K个高频元素
学习·算法·leetcode
&&Citrus2 小时前
【CPN 学习笔记(三)】—— Chap3 CPN ML 编程语言 上半部分 3.1 ~ 3.3
笔记·python·学习·cpn·petri网
航Hang*3 小时前
第3章:Linux系统安全管理——第1节:Linux 防火墙部署(firewalld)
linux·服务器·网络·学习·系统安全·vmware
宋小米的csdn3 小时前
网络知识学习路线(实用向)
网络·学习
南境十里·墨染春水3 小时前
linux学习进展 基础命令 vi基础命令
linux·运维·服务器·笔记·学习
Xudde.3 小时前
班级作业笔记报告0x08
笔记·学习·安全·web安全
迷路爸爸1803 小时前
Docker 入门学习笔记 05:卷到底是什么,为什么容器删了数据却还能保留
笔记·学习·docker