基于Claude Code实现MobileNetV3训练记录

MobileNetV3于2019年问世,是目前移动端实时视觉任务的首选,追求极致性价比。在同精度下,V3 比 V2 快约 20%~30%,参数量更小。MobileNetV1V2的原理介绍不再赘述。

代码链接

MobileNetV3 相比 V2 的主要改进

1. 基本模块结构的改进 MobileNetV3 对 V2 的逆向残差模块进行了重构,具体体现在两点:

  • 引入 SE 模块:在模块中集成了 SE(Squeeze-and-Excitation)模块,增强了特征表达能力。值得注意的是,SE 模块中的 sigmoid 函数被替换为计算更高效的 hard sigmoid 函数。

输入 (C×H×W)

Global AvgPool → (C×1×1) [Squeeze: 压缩空间信息]

FC1 → C/r [降维,r=4或8]

ReLU

FC2 → C [升维]

HardSwish

Sigmoid [Excitation: 生成通道权重]

乘以输入 (C×H×W) [通道重标定]

  • 使用新激活函数:采用了新的激活函数 H-swish。相比于 V2 使用的激活函数,H-swish 的图像与 swish 接近,有助于提高准确率,且计算更简单,能避免因设备拟合 sigmoid 函数方法不同带来的误差。

HardSwish(x) = x * ReLU6(x + 3) / 6

2. 网络结构的优化

  • 采用网络搜索算法:MobileNetV3 的结构并非完全人工设计,而是首先利用自动网络搜索算法得到基础结构,然后针对特定需求进行了人工调整。

  • 针对性结构精简:针对网络搜索算法得到的结构中"开始部分"和"尾部"计算耗时久的问题进行了优化,进一步提升了网络的运行效率。

  • 注意:官方mobilenetv3是 NAS 搜索结果的最终实现,也就是已经固定好的架构配置。这些配置是 NAS 搜索出来 + 人工微调后的最终产物。

    NAS 搜索本身是在模型设计阶段进行的,不是在训练时使用。要实现完整的 NAS 需要额外的框架和大量计算资源。我们使用时直接用最终架构,训练自己的数据集即可

MobileNetV3-Small 数据流:

═══════════════════════════════════════════════════════════

输入图像 (3×224×224)

┌──────────────────────────────────────────────────────┐

│ Conv3×3 + BN + ReLU6 │

│ 3 → 16, stride=2 │

│ 输出: 16×112×112 │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ 11 × bneck 块倒残差 │

│ (通道数: 16→24→40→48→96) │

│ 输出: 96×7×7 │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ Conv1×1 + BN + HardSwish │

│ 96 → 576 │

│ 输出: 576×7×7 │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ AdaptiveAvgPool2d(1) │

│ 输出: 576×1×1 │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ Flatten → FC(576→1024) → HardSwish → Dropout │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ FC(1024→102) │

│ 输出: 102 (花卉类别概率) │

└──────────────────────────────────────────────────────┘

本期主要实现基于ClaudeCode快速搭建MobileNetV3,并进行图像分类任务的训练和测试

需求文档撰写

主要是确认项目的功能和环境配置、以及目录架构等

数据集下载链接(支持自动下载)

ClaudeCode生成

MobileNetV3/

├── README.md ✅ 项目说明文档

├── config.py ✅ 配置文件 (model_type='small')

├── requirements.txt ✅ Python依赖 (Python 3.8兼容)

├── train.py ✅ 训练脚本

├── test.py ✅ 测试/推理脚本

├── checkpoints/ ✅ 检查点保存目录

├── utils/

│ ├── init.py ✅

│ ├── dataset.py ✅ Flowers102数据加载

│ ├── metrics.py ✅ 评估指标

│ ├── visualizer.py ✅ 可视化工具

│ └── flower_names.py ✅ 102类花卉名称

└── (运行时生成)

├── data/ # Flowers102数据集缓存

├── logs/ # TensorBoard日志

└── results/ # 测试结果和可视化

快速使用

1. 安装依赖

pip install -r requirements.txt

2. 开始训练 (使用MobileNetV3-Small)

python train.py

3. 测试评估

python test.py --checkpoint checkpoints/best_model.pth --visualize

4. 单张图像推理

python test.py --checkpoint checkpoints/best_model.pth --image flower.jpg --topk 5

#通过修改config配置文件即可修改训练模型的各种参数

#特殊训练场景

场景1:训练中意外中断(如断电、程序崩溃)

原始训练命令

python train.py --model small --epochs 50

训练到第23轮中断了,从 epoch_20.pth 恢复

python train.py --checkpoint checkpoints/epoch_20.pth

场景2:手动暂停后继续

训练到某个阶段后想继续训练更多轮次

python train.py --checkpoint checkpoints/best_model.pth --epochs 100

代码解读

主要了解代码如何加载数据集,训练模型和测试即可

数据集处理dataset.py

  1. 数据集加载 (torchvision.datasets.Flowers102)

train_dataset = torchvision.datasets.Flowers102(

root=data_root, # './data'

split='train', # 'train', 'val', 'test'

download=download, # 是否自动下载

transform=train_transform # 数据变换

)

  1. 数据变换 (Transform)

训练集变换(数据增强)

transforms.Compose([

transforms.RandomResizedCrop(224), # 随机裁剪+缩放

transforms.RandomHorizontalFlip(0.5), # 50%概率水平翻转

transforms.RandomRotation(15), # 随机旋转±15度

transforms.ColorJitter(...), # 颜色抖动

transforms.ToTensor(), # PIL → Tensor (0-1)

transforms.Normalize(mean, std) # ImageNet标准化

])

  1. 数据打包 (DataLoader)

train_loader = DataLoader(

train_dataset,

batch_size=32, # 每批32张图

shuffle=True, # 打乱顺序(仅训练集)

num_workers=0, # Windows设为0

pin_memory=True, # 加速GPU传输

drop_last=True # 丢弃最后不完整batch

)

训练脚本train.py

def set_seed(seed) - 保证可复现性

def get_model(model_type, num_classes, pretrained, dropout_rate):

1. 加载 torchvision 官方模型

2. 修改最后一层分类头 (ImageNet 1000类 → Flowers102 102类)

3. 可选添加 Dropout

def get_optimizer(model, config):

根据 config.OPTIMIZER 选择优化器:

- adamw: AdamW (推荐)

- adam: Adam

- sgd: SGD with momentum

def get_scheduler(optimizer, config):

三种学习率调度策略:

- cosine: 余弦退火 (推荐,平滑下降)

- step: 阶梯下降

- plateau: 根据指标自动调整

def train_epoch(model, train_loader, criterion, optimizer, device, epoch, config, writer):

model.train() # 设置为训练模式 (启用 Dropout, BatchNorm 更新)

def main(args):

1. 解析命令行参数,覆盖 config

2. set_seed(cfg.SEED)

3. 加载数据: get_dataloaders(cfg)

4. 创建模型: get_model(...)

5. 创建优化器: get_optimizer(...)

6. 创建调度器: get_scheduler(...)

7. 训练循环:

for epoch in range(start_epoch, EPOCHS + 1):

train_loss, train_acc = train_epoch(...)

val_loss, val_acc = validate(...)

save_checkpoint(...) # 保存模型

if 早停条件: break

8. 绘制训练历史曲线

完整训练流程:

main() 启动

├──→ set_seed(42) # 固定随机性

├──→ get_dataloaders() # 加载数据

│ └→ train_loader, val_loader, test_loader

├──→ get_model() # 创建模型

│ └→ model (mobilenet_v3_small)

├──→ get_optimizer() # 创建优化器

│ └→ optimizer (AdamW)

└──→ 训练循环 (50 epochs)

├──→ Epoch 1: train → val → save

├──→ Epoch 2: train → val → save

├──→ ...

└──→ Epoch 50: train → val → save

└→ plot_training_history() # 绘制曲线

训练结果

MobileNetV3-small-30轮

MobileNetV3-large-30轮

测试脚本test.py

函数结构

test.py

├── get_model() # 加载模型 + 检查点权重

├── test() # 批量测试评估

├── predict_single_image() # 单张图像推理

├── print_test_results() # 打印测试结果

├── print_prediction_results() # 打印预测结果

└── main() # 主函数,处理参数

基础测试

python test.py --checkpoint checkpoints/best_model.pth

测试 + 可视化(生成预测图和混淆矩阵)

python test.py --checkpoint checkpoints/best_model.pth --visualize

测试 + 显示每类准确率

python test.py --checkpoint checkpoints/best_model.pth --per-class

测试 + 全部功能

python test.py --checkpoint checkpoints/best_model.pth --visualize --per-class

单图推理模式

功能说明

对单张图像进行预测,返回 Top-K 个最可能的类别及其概率。

使用方法

预测单张图像(默认 Top-5)

python test.py --checkpoint checkpoints/best_model.pth --image flower.jpg

Top-3 预测

python test.py --checkpoint checkpoints/best_model.pth --image flower.jpg --topk 3

指定模型类型

python test.py --checkpoint checkpoints/best_model.pth --image flower.jpg --model large

相关推荐
Loo国昌1 小时前
【AI应用开发实战】05_GraphRAG:知识图谱增强检索实战
人工智能·后端·python·语言模型·自然语言处理·金融·知识图谱
Dr.AE1 小时前
金蝶AI星辰 产品分析报告
大数据·人工智能
LaughingZhu2 小时前
Product Hunt 每日热榜 | 2026-02-22
人工智能·经验分享·深度学习·神经网络·产品运营
数据智能老司机2 小时前
打造 ML/AI 系统的内部开发者平台(IDP)——设计可靠的机器学习(ML)系统
人工智能·llm·aiops
上进小菜猪2 小时前
基于 YOLOv8 的面向矿井场景的煤炭图像智能检测系统 [目标检测完整源码](YOLOv8 + PyQt5 实战)
人工智能
~央千澈~2 小时前
08实战处理AI音乐技术详解第三阶段:时间人性化(Timing Humanization)·卓伊凡
人工智能
xwz小王子2 小时前
Nature Electronics:基于单尖峰编码的人机界面端到端忆阻硬件系统
人工智能·忆阻
后台技术汇2 小时前
读书笔记:《以日为鉴》-- 从日本失落的三十年看中国互联网与AI产业的未来
人工智能
Ray Liang2 小时前
Opus现实打脸GLM5“教课书“式架构
人工智能·mindx