目录
- [一、 任务概述](#一、 任务概述)
- [二、 理论基础与网络架构设计](#二、 理论基础与网络架构设计)
-
- [2.1 EfficientNet-B3 的架构特点](#2.1 EfficientNet-B3 的架构特点)
- [2.2 广义平均池化 (GeM Pooling)](#2.2 广义平均池化 (GeM Pooling))
- [2.3 网络结构定义](#2.3 网络结构定义)
- [三、 训练数据准备](#三、 训练数据准备)
-
- [3.1 图像预处理适配](#3.1 图像预处理适配)
- [3.2 数据集类更新](#3.2 数据集类更新)
- [四、 微调训练流程实现](#四、 微调训练流程实现)
- [五、 系统集成与部署](#五、 系统集成与部署)
-
- [5.1 更新 `feature_extractor.py`](#5.1 更新
feature_extractor.py) - [5.2 数据库重建与特征重算](#5.2 数据库重建与特征重算)
- [5.1 更新 `feature_extractor.py`](#5.1 更新
- [六、 总结](#六、 总结)
一、 任务概述
在全球证件智能识别系统的持续迭代中,证件版式检索模块的性能需要在"识别准确率"与"推理效率"之间寻找最佳平衡点。在前序的实践中,MobileNetV3虽然速度极快,但在处理未见样本(Zero-shot)及复杂版式时特征区分度不足。
为构建一个兼顾高精度与高效率的检索底座,本篇博客将对检索模块的特征提取网络进行升级。技术方案将由MobileNetV3切换至EfficientNet-B3,并结合广义平均池化(GeM Pooling)与度量学习(Metric Learning)进行全流程微调。
EfficientNet-B3是EfficientNet家族中的"中坚力量"。其输入分辨率标准为300x300,参数量约为47M,在ImageNet等基准测试中的表现依然远超MobileNet系列。配合GeM池化层,该模型能够有效捕捉证件中的微缩文字布局与防伪纹理特征,同时保持较快的推理速度,非常适合车管所窗口等对实时性有一定要求的业务场景。
本篇博客将详细阐述基于EfficientNet-B3的网络结构改造、GeM池化层的集成、适配300x300分辨率的数据预处理流程,以及完整的微调训练脚本。
二、 理论基础与网络架构设计
2.1 EfficientNet-B3 的架构特点
EfficientNet-B3通过复合缩放系数对网络进行了优化。与本项目前序尝试的模型对比如下:
- MobileNetV3 Large: 输入224x224,特征较弱,极速。
- EfficientNet-B3: 输入300x300,特征强,速度中等(平衡之选)。
B3的最后一次卷积层输出通道数为 1536(B5为2048,B1为1280),这一数值将直接影响后续投影头(Projection Head)的设计。
2.2 广义平均池化 (GeM Pooling)
为解决全局平均池化(GAP)导致的空间信息丢失问题,继续沿用广义平均池化(GeM)。其公式为:
f = ( 1 ∣ X ∣ ∑ x ∈ X x p ) 1 p \textbf{f} = \left( \frac{1}{|\mathcal{X}|} \sum_{x \in \mathcal{X}} x^p \right)^{\frac{1}{p}} f=(∣X∣1x∈X∑xp)p1
通过训练学习参数 p p p,网络能够自适应地关注证件图像中的显著区域(如Logo、印章),而非平均化所有背景信息。
2.3 网络结构定义
新建文件 model_efficientnet_b3.py,定义包含GeM池化层的EfficientNet-B3特征提取网络。
代码清单:model_efficientnet_b3.py
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class GeM(nn.Module):
"""
广义平均池化层 (Generalized Mean Pooling)
参数:
p (float): 初始p值,通常设置在3.0左右
eps (float): 数值稳定性常数
"""
def __init__(self, p=3.0, eps=1e-6):
super(GeM, self).__init__()
# 将p定义为可学习的参数
self.p = nn.Parameter(torch.ones(1) * p)
self.eps = eps
def forward(self, x):
return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
def __repr__(self):
return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
class EfficientNetB3Embedding(nn.Module):
"""
基于EfficientNet-B3的特征提取网络
结构: Backbone(EfficientNet-B3) -> GeM Pooling -> Flatten -> Linear(Projection)
"""
def __init__(self, embedding_dim=1024, pretrained=True):
super(EfficientNetB3Embedding, self).__init__()
# 1. 加载预训练的EfficientNet-B3
weights = models.EfficientNet_B3_Weights.DEFAULT if pretrained else None
base_model = models.efficientnet_b3(weights=weights)
# 2. 提取特征层
# EfficientNet的features部分输出最后的卷积特征图
self.features = base_model.features
# 3. 引入GeM池化层
self.pool = GeM()
# 4. 获取Backbone输出通道数
# EfficientNet-B3 最后一层卷积输出通道数为 1536
out_channels = 1536
# 5. 定义Flatten层
self.flatten = nn.Flatten()
# 6. 定义投影头 (Projection Head)
# 将高维特征映射到指定的Embedding维度
self.fc = nn.Sequential(
nn.Linear(out_channels, out_channels),
nn.BatchNorm1d(out_channels),
nn.ReLU(),
nn.Linear(out_channels, embedding_dim)
)
def forward(self, x):
# 输入: [Batch, 3, 300, 300]
# 提取特征图: [Batch, 1536, H, W]
x = self.features(x)
# GeM池化: [Batch, 1536, 1, 1]
x = self.pool(x)
# 展平: [Batch, 1536]
x = self.flatten(x)
# 线性投影: [Batch, embedding_dim]
x = self.fc(x)
# L2归一化
x = F.normalize(x, p=2, dim=1)
return x
三、 训练数据准备
3.1 图像预处理适配
EfficientNet-B3的标准输入尺寸为 300x300。数据预处理模块需将图像Resize至该尺寸。
3.2 数据集类更新
修改 TripletDataset 类中的尺寸配置。
代码清单:dataset_loader.py
python
import random
from pathlib import Path
from PIL import Image, ImageDraw
import torch
from torch.utils.data import Dataset
from torchvision import transforms
# --- 数据增强:随机划痕 (保持不变) ---
class RandomScratches:
def __init__(self, num_scratches_range=(1, 5), p=0.5):
self.num_scratches_range = num_scratches_range
self.p = p
def __call__(self, img):
if random.random() > self.p:
return img
img_draw = img.copy()
draw = ImageDraw.Draw(img_draw)
width, height = img.size
num_scratches = random.randint(*self.num_scratches_range)
for _ in range(num_scratches):
x1, y1 = random.randint(0, width), random.randint(0, height)
x2, y2 = random.randint(0, width), random.randint(0, height)
line_width = random.randint(1, 2)
line_color = random.randint(50, 200)
draw.line([(x1, y1), (x2, y2)], fill=(line_color, line_color, line_color), width=line_width)
return img_draw
class TripletDataset(Dataset):
"""
三元组数据集加载器
适配 EfficientNet-B3 输入尺寸 (300x300)
"""
def __init__(self, image_dir, image_type_suffix, is_train=True):
self.image_dir = Path(image_dir)
self.image_type_suffix = image_type_suffix
self.is_train = is_train
# ImageNet 均值和方差
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
# EfficientNet-B3 输入尺寸
input_size = 300
if self.is_train:
self.transform = transforms.Compose([
transforms.Resize((320, 320)), # 先放大稍多一点
transforms.RandomCrop((input_size, input_size)), # 随机裁剪
transforms.RandomHorizontalFlip(p=0.3),
transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
RandomScratches(p=0.3),
transforms.ToTensor(),
normalize
])
else:
self.transform = transforms.Compose([
transforms.Resize((input_size, input_size)),
transforms.ToTensor(),
normalize
])
self.samples = []
self.class_to_images = {}
self.country_to_classes = {}
# 建立索引逻辑
self._build_index()
def _build_index(self):
print(f"正在扫描 '{self.image_type_suffix}' 类型数据...")
for country_dir in self.image_dir.iterdir():
if not country_dir.is_dir(): continue
country_name = country_dir.name
for state_dir in country_dir.iterdir():
if not state_dir.is_dir(): continue
for template_dir in state_dir.iterdir():
if not template_dir.is_dir(): continue
# 查找对应后缀的图像
image_files = list(template_dir.glob(f"*{self.image_type_suffix}"))
if image_files:
class_id = f"{country_name}_{state_dir.name}_{template_dir.name}"
if country_name not in self.country_to_classes:
self.country_to_classes[country_name] = []
if class_id not in self.country_to_classes[country_name]:
self.country_to_classes[country_name].append(class_id)
self.class_to_images[class_id] = image_files
for img_path in image_files:
self.samples.append((class_id, img_path, country_name))
print(f"索引建立完成,共找到 {len(self.samples)} 个样本。")
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
# 1. 获取 Anchor
anchor_class, anchor_path, anchor_country = self.samples[index]
anchor_img = self._load_image(anchor_path)
# 2. 选取 Positive (同类样本)
possible_positives = self.class_to_images[anchor_class]
if len(possible_positives) < 2:
# 只有一张图时,强增强作为正样本
positive_img = self.transform(Image.open(anchor_path).convert('L').convert('RGB'))
else:
positive_path = random.choice(possible_positives)
while positive_path == anchor_path and len(possible_positives) > 1:
positive_path = random.choice(possible_positives)
positive_img = self._load_image(positive_path)
# 3. 选取 Negative (异类样本)
possible_neg_classes = [c for c in self.country_to_classes.get(anchor_country, []) if c != anchor_class]
if not possible_neg_classes:
all_classes = list(self.class_to_images.keys())
possible_neg_classes = [c for c in all_classes if c != anchor_class]
negative_class = random.choice(possible_neg_classes)
negative_path = random.choice(self.class_to_images[negative_class])
negative_img = self._load_image(negative_path)
return anchor_img, positive_img, negative_img
def _load_image(self, path):
# 保持RGB输入
return self.transform(Image.open(path).convert('L').convert('RGB'))
四、 微调训练流程实现
编写训练脚本 train_efficientnet_b3.py。
代码清单:train_efficientnet_b3.py
python
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from model_efficientnet_b3 import EfficientNetB3Embedding
from dataset_loader import TripletDataset
def train():
# 1. 配置参数
# EfficientNet-B3 显存占用适中
BATCH_SIZE = 12
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
EMBEDDING_DIM = 512
MARGIN = 1.0
SAVE_PATH = "efficientnet_b3_gem_finetuned.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"训练设备: {device}")
# 2. 准备数据集
print("正在加载数据集...")
dataset_front = TripletDataset(image_dir="samples", image_type_suffix="_front_white.jpg", is_train=True)
dataset_back = TripletDataset(image_dir="samples", image_type_suffix="_back_white.jpg", is_train=True)
if len(dataset_front) == 0 and len(dataset_back) == 0:
print("错误:未找到有效样本,请检查samples目录结构。")
return
full_dataset = ConcatDataset([dataset_front, dataset_back])
dataloader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
print(f"数据集加载完成,总样本数(三元组): {len(full_dataset)}")
# 3. 初始化模型
print("正在初始化 EfficientNet-B3 模型...")
model = EfficientNetB3Embedding(embedding_dim=EMBEDDING_DIM, pretrained=True).to(device)
# 4. 定义损失函数和优化器
criterion = nn.TripletMarginLoss(margin=MARGIN, p=2)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)
# 5. 训练循环
best_loss = float('inf')
model.train()
print("开始训练...")
for epoch in range(NUM_EPOCHS):
running_loss = 0.0
for i, (anchor, positive, negative) in enumerate(dataloader):
anchor = anchor.to(device)
positive = positive.to(device)
negative = negative.to(device)
# 前向传播
emb_a = model(anchor)
emb_p = model(positive)
emb_n = model(negative)
# 计算损失
loss = criterion(emb_a, emb_p, emb_n)
# 反向传播与优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if (i + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")
# 计算Epoch平均损失
epoch_loss = running_loss / len(dataloader)
current_lr = optimizer.param_groups[0]['lr']
print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] 完成, 平均Loss: {epoch_loss:.4f}, 当前LR: {current_lr:.6f}")
# 更新学习率
scheduler.step()
# 保存最佳模型
if epoch_loss < best_loss:
best_loss = epoch_loss
torch.save(model.state_dict(), SAVE_PATH)
print(f"--> 模型性能提升,已保存至 {SAVE_PATH}")
print("训练结束。")
if __name__ == "__main__":
train()
执行训练:
在项目根目录运行以下命令:
bash
python train_efficientnet_b3.py
五、 系统集成与部署
模型训练完成后,需更新后端的特征提取模块,并重建特征数据库。
5.1 更新 feature_extractor.py
修改特征提取器以适配EfficientNet-B3。
代码清单:feature_extractor.py
python
import io
import pickle
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
# 导入B3模型结构
from model_efficientnet_b3 import EfficientNetB3Embedding
class ImageFeatureExtractor:
def __init__(self, model_path="efficientnet_b3_gem_finetuned.pth"):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化模型,Embedding维度为512
# pretrained=False 因为我们将加载自己的微调权重
self.model = EfficientNetB3Embedding(embedding_dim=512, pretrained=False).to(self.device)
# 加载微调权重
if os.path.exists(model_path):
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
print(f"成功加载 EfficientNet-B3 微调模型: {model_path}")
else:
print(f"警告: 未找到权重文件 {model_path},使用随机初始化权重(仅用于测试)")
self.model.eval()
# 预处理:适配EfficientNet-B3的300x300输入
self.preprocess = transforms.Compose([
transforms.Resize((300, 300)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def extract_features(self, image_bytes: bytes) -> bytes:
"""
提取特征并返回序列化后的字节流
"""
try:
# 统一转为RGB
image = Image.open(io.BytesIO(image_bytes)).convert("L").convert("RGB")
input_tensor = self.preprocess(image)
input_batch = input_tensor.unsqueeze(0).to(self.device)
with torch.no_grad():
output_features = self.model(input_batch)
# 转换为Numpy并序列化
feature_np = output_features.cpu().numpy().flatten()
return pickle.dumps(feature_np)
except Exception as e:
print(f"特征提取失败: {e}")
return b''
5.2 数据库重建与特征重算
EfficientNet-B3生成的特征空间与原模型完全不同。必须执行全量数据库重建。
首先修改init_db.py文件:
python
# 注释掉原来的代码
# extractor = ImageFeatureExtractor(model_path="mobilenetv3_finetuned.pth")
# 改成下面的代码
extractor = ImageFeatureExtractor(model_path="./efficientnet_b3_gem_finetuned.pth")
然后依次执行如下命令:
bash
# delete old db file
rm card_db.sqlite
# delete old alembic versions
rm -rf alembic/versions/*
# generate new alembic revision
alembic revision --autogenerate -m "add_layout_schema"
# add import sqlmodel to the new revision file then apply the migration
alembic upgrade head
# init db
python init_db.py
六、 总结
本篇博客详细记录了将全球证件识别系统的检索模块升级至EfficientNet-B3的完整工程实践。相比于MobileNetV3,EfficientNet-B3提供了更强的特征表达能力。结合GeM池化与度量学习,系统现在能够生成区分度极高的"证件指纹",有效解决了未见样本检索不准的问题,为后续的SIFT精排提供了高质量的候选集。这一升级标志着系统在算法层面达到了性能与效率的最佳平衡。