项目实践19—全球证件智能识别系统(优化检索算法:从MobileNet转EfficientNet)

目录

  • [一、 任务概述](#一、 任务概述)
  • [二、 理论基础与网络架构设计](#二、 理论基础与网络架构设计)
    • [2.1 EfficientNet-B3 的架构特点](#2.1 EfficientNet-B3 的架构特点)
    • [2.2 广义平均池化 (GeM Pooling)](#2.2 广义平均池化 (GeM Pooling))
    • [2.3 网络结构定义](#2.3 网络结构定义)
  • [三、 训练数据准备](#三、 训练数据准备)
    • [3.1 图像预处理适配](#3.1 图像预处理适配)
    • [3.2 数据集类更新](#3.2 数据集类更新)
  • [四、 微调训练流程实现](#四、 微调训练流程实现)
  • [五、 系统集成与部署](#五、 系统集成与部署)
    • [5.1 更新 `feature_extractor.py`](#5.1 更新 feature_extractor.py)
    • [5.2 数据库重建与特征重算](#5.2 数据库重建与特征重算)
  • [六、 总结](#六、 总结)

一、 任务概述

在全球证件智能识别系统的持续迭代中,证件版式检索模块的性能需要在"识别准确率"与"推理效率"之间寻找最佳平衡点。在前序的实践中,MobileNetV3虽然速度极快,但在处理未见样本(Zero-shot)及复杂版式时特征区分度不足。

为构建一个兼顾高精度与高效率的检索底座,本篇博客将对检索模块的特征提取网络进行升级。技术方案将由MobileNetV3切换至EfficientNet-B3,并结合广义平均池化(GeM Pooling)与度量学习(Metric Learning)进行全流程微调。

EfficientNet-B3是EfficientNet家族中的"中坚力量"。其输入分辨率标准为300x300,参数量约为47M,在ImageNet等基准测试中的表现依然远超MobileNet系列。配合GeM池化层,该模型能够有效捕捉证件中的微缩文字布局与防伪纹理特征,同时保持较快的推理速度,非常适合车管所窗口等对实时性有一定要求的业务场景。

本篇博客将详细阐述基于EfficientNet-B3的网络结构改造、GeM池化层的集成、适配300x300分辨率的数据预处理流程,以及完整的微调训练脚本。

二、 理论基础与网络架构设计

2.1 EfficientNet-B3 的架构特点

EfficientNet-B3通过复合缩放系数对网络进行了优化。与本项目前序尝试的模型对比如下:

  • MobileNetV3 Large: 输入224x224,特征较弱,极速。
  • EfficientNet-B3: 输入300x300,特征强,速度中等(平衡之选)。

B3的最后一次卷积层输出通道数为 1536(B5为2048,B1为1280),这一数值将直接影响后续投影头(Projection Head)的设计。

2.2 广义平均池化 (GeM Pooling)

为解决全局平均池化(GAP)导致的空间信息丢失问题,继续沿用广义平均池化(GeM)。其公式为:
f = ( 1 ∣ X ∣ ∑ x ∈ X x p ) 1 p \textbf{f} = \left( \frac{1}{|\mathcal{X}|} \sum_{x \in \mathcal{X}} x^p \right)^{\frac{1}{p}} f=(∣X∣1x∈X∑xp)p1

通过训练学习参数 p p p,网络能够自适应地关注证件图像中的显著区域(如Logo、印章),而非平均化所有背景信息。

2.3 网络结构定义

新建文件 model_efficientnet_b3.py,定义包含GeM池化层的EfficientNet-B3特征提取网络。

代码清单:model_efficientnet_b3.py

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class GeM(nn.Module):
    """
    广义平均池化层 (Generalized Mean Pooling)
    参数:
        p (float): 初始p值,通常设置在3.0左右
        eps (float): 数值稳定性常数
    """
    def __init__(self, p=3.0, eps=1e-6):
        super(GeM, self).__init__()
        # 将p定义为可学习的参数
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'

class EfficientNetB3Embedding(nn.Module):
    """
    基于EfficientNet-B3的特征提取网络
    结构: Backbone(EfficientNet-B3) -> GeM Pooling -> Flatten -> Linear(Projection)
    """
    def __init__(self, embedding_dim=1024, pretrained=True):
        super(EfficientNetB3Embedding, self).__init__()
        
        # 1. 加载预训练的EfficientNet-B3
        weights = models.EfficientNet_B3_Weights.DEFAULT if pretrained else None
        base_model = models.efficientnet_b3(weights=weights)
        
        # 2. 提取特征层
        # EfficientNet的features部分输出最后的卷积特征图
        self.features = base_model.features
        
        # 3. 引入GeM池化层
        self.pool = GeM()
        
        # 4. 获取Backbone输出通道数
        # EfficientNet-B3 最后一层卷积输出通道数为 1536
        out_channels = 1536
        
        # 5. 定义Flatten层
        self.flatten = nn.Flatten()
        
        # 6. 定义投影头 (Projection Head)
        # 将高维特征映射到指定的Embedding维度
        self.fc = nn.Sequential(
            nn.Linear(out_channels, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, embedding_dim)
        )

    def forward(self, x):
        # 输入: [Batch, 3, 300, 300]
        # 提取特征图: [Batch, 1536, H, W]
        x = self.features(x)
        
        # GeM池化: [Batch, 1536, 1, 1]
        x = self.pool(x)
        
        # 展平: [Batch, 1536]
        x = self.flatten(x)
        
        # 线性投影: [Batch, embedding_dim]
        x = self.fc(x)
        
        # L2归一化
        x = F.normalize(x, p=2, dim=1)
        
        return x

三、 训练数据准备

3.1 图像预处理适配

EfficientNet-B3的标准输入尺寸为 300x300。数据预处理模块需将图像Resize至该尺寸。

3.2 数据集类更新

修改 TripletDataset 类中的尺寸配置。

代码清单:dataset_loader.py

python 复制代码
import random
from pathlib import Path
from PIL import Image, ImageDraw
import torch
from torch.utils.data import Dataset
from torchvision import transforms

# --- 数据增强:随机划痕 (保持不变) ---
class RandomScratches:
    def __init__(self, num_scratches_range=(1, 5), p=0.5):
        self.num_scratches_range = num_scratches_range
        self.p = p

    def __call__(self, img):
        if random.random() > self.p:
            return img
        img_draw = img.copy()
        draw = ImageDraw.Draw(img_draw)
        width, height = img.size
        num_scratches = random.randint(*self.num_scratches_range)
        for _ in range(num_scratches):
            x1, y1 = random.randint(0, width), random.randint(0, height)
            x2, y2 = random.randint(0, width), random.randint(0, height)
            line_width = random.randint(1, 2) 
            line_color = random.randint(50, 200)
            draw.line([(x1, y1), (x2, y2)], fill=(line_color, line_color, line_color), width=line_width)
        return img_draw

class TripletDataset(Dataset):
    """
    三元组数据集加载器
    适配 EfficientNet-B3 输入尺寸 (300x300)
    """
    def __init__(self, image_dir, image_type_suffix, is_train=True):
        self.image_dir = Path(image_dir)
        self.image_type_suffix = image_type_suffix
        self.is_train = is_train
        
        # ImageNet 均值和方差
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        
        # EfficientNet-B3 输入尺寸
        input_size = 300
        
        if self.is_train:
            self.transform = transforms.Compose([
                transforms.Resize((320, 320)), # 先放大稍多一点
                transforms.RandomCrop((input_size, input_size)), # 随机裁剪
                transforms.RandomHorizontalFlip(p=0.3),
                transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                RandomScratches(p=0.3),
                transforms.ToTensor(),
                normalize
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((input_size, input_size)),
                transforms.ToTensor(),
                normalize
            ])

        self.samples = []
        self.class_to_images = {}
        self.country_to_classes = {}

        # 建立索引逻辑
        self._build_index()

    def _build_index(self):
        print(f"正在扫描 '{self.image_type_suffix}' 类型数据...")
        for country_dir in self.image_dir.iterdir():
            if not country_dir.is_dir(): continue
            country_name = country_dir.name
            
            for state_dir in country_dir.iterdir():
                if not state_dir.is_dir(): continue
                
                for template_dir in state_dir.iterdir():
                    if not template_dir.is_dir(): continue
                    
                    # 查找对应后缀的图像
                    image_files = list(template_dir.glob(f"*{self.image_type_suffix}"))
                    if image_files:
                        class_id = f"{country_name}_{state_dir.name}_{template_dir.name}"
                        
                        if country_name not in self.country_to_classes:
                            self.country_to_classes[country_name] = []
                        if class_id not in self.country_to_classes[country_name]:
                            self.country_to_classes[country_name].append(class_id)
                            
                        self.class_to_images[class_id] = image_files
                        for img_path in image_files:
                            self.samples.append((class_id, img_path, country_name))
        
        print(f"索引建立完成,共找到 {len(self.samples)} 个样本。")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # 1. 获取 Anchor
        anchor_class, anchor_path, anchor_country = self.samples[index]
        anchor_img = self._load_image(anchor_path)

        # 2. 选取 Positive (同类样本)
        possible_positives = self.class_to_images[anchor_class]
        if len(possible_positives) < 2:
            # 只有一张图时,强增强作为正样本
            positive_img = self.transform(Image.open(anchor_path).convert('L').convert('RGB'))
        else:
            positive_path = random.choice(possible_positives)
            while positive_path == anchor_path and len(possible_positives) > 1:
                positive_path = random.choice(possible_positives)
            positive_img = self._load_image(positive_path)

        # 3. 选取 Negative (异类样本)
        possible_neg_classes = [c for c in self.country_to_classes.get(anchor_country, []) if c != anchor_class]
        if not possible_neg_classes:
            all_classes = list(self.class_to_images.keys())
            possible_neg_classes = [c for c in all_classes if c != anchor_class]
        
        negative_class = random.choice(possible_neg_classes)
        negative_path = random.choice(self.class_to_images[negative_class])
        negative_img = self._load_image(negative_path)

        return anchor_img, positive_img, negative_img

    def _load_image(self, path):
        # 保持RGB输入
        return self.transform(Image.open(path).convert('L').convert('RGB'))

四、 微调训练流程实现

编写训练脚本 train_efficientnet_b3.py

代码清单:train_efficientnet_b3.py

python 复制代码
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from model_efficientnet_b3 import EfficientNetB3Embedding
from dataset_loader import TripletDataset

def train():
    # 1. 配置参数
    # EfficientNet-B3 显存占用适中
    BATCH_SIZE = 12  
    LEARNING_RATE = 1e-4
    NUM_EPOCHS = 50
    EMBEDDING_DIM = 512
    MARGIN = 1.0
    SAVE_PATH = "efficientnet_b3_gem_finetuned.pth"
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"训练设备: {device}")

    # 2. 准备数据集
    print("正在加载数据集...")
    dataset_front = TripletDataset(image_dir="samples", image_type_suffix="_front_white.jpg", is_train=True)
    dataset_back = TripletDataset(image_dir="samples", image_type_suffix="_back_white.jpg", is_train=True)
    
    if len(dataset_front) == 0 and len(dataset_back) == 0:
        print("错误:未找到有效样本,请检查samples目录结构。")
        return

    full_dataset = ConcatDataset([dataset_front, dataset_back])
    
    dataloader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    
    print(f"数据集加载完成,总样本数(三元组): {len(full_dataset)}")

    # 3. 初始化模型
    print("正在初始化 EfficientNet-B3 模型...")
    model = EfficientNetB3Embedding(embedding_dim=EMBEDDING_DIM, pretrained=True).to(device)
    
    # 4. 定义损失函数和优化器
    criterion = nn.TripletMarginLoss(margin=MARGIN, p=2)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

    # 5. 训练循环
    best_loss = float('inf')
    
    model.train()
    print("开始训练...")
    
    for epoch in range(NUM_EPOCHS):
        running_loss = 0.0
        
        for i, (anchor, positive, negative) in enumerate(dataloader):
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)
            
            # 前向传播
            emb_a = model(anchor)
            emb_p = model(positive)
            emb_n = model(negative)
            
            # 计算损失
            loss = criterion(emb_a, emb_p, emb_n)
            
            # 反向传播与优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            if (i + 1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")

        # 计算Epoch平均损失
        epoch_loss = running_loss / len(dataloader)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] 完成, 平均Loss: {epoch_loss:.4f}, 当前LR: {current_lr:.6f}")
        
        # 更新学习率
        scheduler.step()
        
        # 保存最佳模型
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), SAVE_PATH)
            print(f"--> 模型性能提升,已保存至 {SAVE_PATH}")

    print("训练结束。")

if __name__ == "__main__":
    train()

执行训练:

在项目根目录运行以下命令:

bash 复制代码
python train_efficientnet_b3.py

五、 系统集成与部署

模型训练完成后,需更新后端的特征提取模块,并重建特征数据库。

5.1 更新 feature_extractor.py

修改特征提取器以适配EfficientNet-B3。

代码清单:feature_extractor.py

python 复制代码
import io
import pickle
import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
# 导入B3模型结构
from model_efficientnet_b3 import EfficientNetB3Embedding

class ImageFeatureExtractor:
    def __init__(self, model_path="efficientnet_b3_gem_finetuned.pth"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # 初始化模型,Embedding维度为512
        # pretrained=False 因为我们将加载自己的微调权重
        self.model = EfficientNetB3Embedding(embedding_dim=512, pretrained=False).to(self.device)
        
        # 加载微调权重
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            print(f"成功加载 EfficientNet-B3 微调模型: {model_path}")
        else:
            print(f"警告: 未找到权重文件 {model_path},使用随机初始化权重(仅用于测试)")
            
        self.model.eval()

        # 预处理:适配EfficientNet-B3的300x300输入
        self.preprocess = transforms.Compose([
            transforms.Resize((300, 300)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225])
        ])

    def extract_features(self, image_bytes: bytes) -> bytes:
        """
        提取特征并返回序列化后的字节流
        """
        try:
            # 统一转为RGB
            image = Image.open(io.BytesIO(image_bytes)).convert("L").convert("RGB")
            
            input_tensor = self.preprocess(image)
            input_batch = input_tensor.unsqueeze(0).to(self.device)

            with torch.no_grad():
                output_features = self.model(input_batch)

            # 转换为Numpy并序列化
            feature_np = output_features.cpu().numpy().flatten()
            return pickle.dumps(feature_np)
            
        except Exception as e:
            print(f"特征提取失败: {e}")
            return b''

5.2 数据库重建与特征重算

EfficientNet-B3生成的特征空间与原模型完全不同。必须执行全量数据库重建。

首先修改init_db.py文件:

python 复制代码
# 注释掉原来的代码
# extractor = ImageFeatureExtractor(model_path="mobilenetv3_finetuned.pth")
# 改成下面的代码
extractor = ImageFeatureExtractor(model_path="./efficientnet_b3_gem_finetuned.pth")

然后依次执行如下命令:

bash 复制代码
# delete old db file
rm card_db.sqlite
# delete old alembic versions
rm -rf alembic/versions/*
# generate new alembic revision
alembic revision --autogenerate -m "add_layout_schema"
# add import sqlmodel to the new revision file then apply the migration
alembic upgrade head
# init db
python init_db.py

六、 总结

本篇博客详细记录了将全球证件识别系统的检索模块升级至EfficientNet-B3的完整工程实践。相比于MobileNetV3,EfficientNet-B3提供了更强的特征表达能力。结合GeM池化与度量学习,系统现在能够生成区分度极高的"证件指纹",有效解决了未见样本检索不准的问题,为后续的SIFT精排提供了高质量的候选集。这一升级标志着系统在算法层面达到了性能与效率的最佳平衡。

相关推荐
沉默-_-1 小时前
力扣hot100滑动窗口(C++)
数据结构·c++·学习·算法·滑动窗口
feifeigo1232 小时前
基于EM算法的混合Copula MATLAB实现
开发语言·算法·matlab
漫随流水2 小时前
leetcode回溯算法(78.子集)
数据结构·算法·leetcode·回溯算法
IT猿手2 小时前
六种智能优化算法(NOA、MA、PSO、GA、ZOA、SWO)求解23个基准测试函数(含参考文献及MATLAB代码)
开发语言·算法·matlab·无人机·无人机路径规划·最新多目标优化算法
We་ct2 小时前
LeetCode 151. 反转字符串中的单词:两种解法深度剖析
前端·算法·leetcode·typescript
芜湖xin3 小时前
【题解-Acwing】AcWing 5579. 增加模数(TLE)
算法·快速幂
清酒难咽3 小时前
算法案例之分治法
c++·经验分享·算法
wen__xvn3 小时前
代码随想录算法训练营DAY25第七章 回溯算法 part04
算法·leetcode·深度优先
亲爱的非洲野猪3 小时前
动态规划进阶:序列DP深度解析
算法·动态规划