基于Claude Code实现MobileNetV3训练记录

MobileNetV3于2019年问世,是目前移动端实时视觉任务的首选,追求极致性价比。在同精度下,V3 比 V2 快约 20%~30%,参数量更小。MobileNetV1V2的原理介绍不再赘述。

代码链接

MobileNetV3 相比 V2 的主要改进

1. 基本模块结构的改进 MobileNetV3 对 V2 的逆向残差模块进行了重构,具体体现在两点:

  • 引入 SE 模块:在模块中集成了 SE(Squeeze-and-Excitation)模块,增强了特征表达能力。值得注意的是,SE 模块中的 sigmoid 函数被替换为计算更高效的 hard sigmoid 函数。

输入 (C×H×W)

Global AvgPool → (C×1×1) [Squeeze: 压缩空间信息]

FC1 → C/r [降维,r=4或8]

ReLU

FC2 → C [升维]

HardSwish

Sigmoid [Excitation: 生成通道权重]

乘以输入 (C×H×W) [通道重标定]

  • 使用新激活函数:采用了新的激活函数 H-swish。相比于 V2 使用的激活函数,H-swish 的图像与 swish 接近,有助于提高准确率,且计算更简单,能避免因设备拟合 sigmoid 函数方法不同带来的误差。

HardSwish(x) = x * ReLU6(x + 3) / 6

2. 网络结构的优化

  • 采用网络搜索算法:MobileNetV3 的结构并非完全人工设计,而是首先利用自动网络搜索算法得到基础结构,然后针对特定需求进行了人工调整。

  • 针对性结构精简:针对网络搜索算法得到的结构中"开始部分"和"尾部"计算耗时久的问题进行了优化,进一步提升了网络的运行效率。

  • 注意:官方mobilenetv3是 NAS 搜索结果的最终实现,也就是已经固定好的架构配置。这些配置是 NAS 搜索出来 + 人工微调后的最终产物。

    NAS 搜索本身是在模型设计阶段进行的,不是在训练时使用。要实现完整的 NAS 需要额外的框架和大量计算资源。我们使用时直接用最终架构,训练自己的数据集即可

MobileNetV3-Small 数据流:

═══════════════════════════════════════════════════════════

输入图像 (3×224×224)

┌──────────────────────────────────────────────────────┐

│ Conv3×3 + BN + ReLU6 │

│ 3 → 16, stride=2 │

│ 输出: 16×112×112 │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ 11 × bneck 块倒残差 │

│ (通道数: 16→24→40→48→96) │

│ 输出: 96×7×7 │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ Conv1×1 + BN + HardSwish │

│ 96 → 576 │

│ 输出: 576×7×7 │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ AdaptiveAvgPool2d(1) │

│ 输出: 576×1×1 │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ Flatten → FC(576→1024) → HardSwish → Dropout │

└──────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────┐

│ FC(1024→102) │

│ 输出: 102 (花卉类别概率) │

└──────────────────────────────────────────────────────┘

本期主要实现基于ClaudeCode快速搭建MobileNetV3,并进行图像分类任务的训练和测试

需求文档撰写

主要是确认项目的功能和环境配置、以及目录架构等

数据集下载链接(支持自动下载)

ClaudeCode生成

MobileNetV3/

├── README.md ✅ 项目说明文档

├── config.py ✅ 配置文件 (model_type='small')

├── requirements.txt ✅ Python依赖 (Python 3.8兼容)

├── train.py ✅ 训练脚本

├── test.py ✅ 测试/推理脚本

├── checkpoints/ ✅ 检查点保存目录

├── utils/

│ ├── init.py ✅

│ ├── dataset.py ✅ Flowers102数据加载

│ ├── metrics.py ✅ 评估指标

│ ├── visualizer.py ✅ 可视化工具

│ └── flower_names.py ✅ 102类花卉名称

└── (运行时生成)

├── data/ # Flowers102数据集缓存

├── logs/ # TensorBoard日志

└── results/ # 测试结果和可视化

快速使用

1. 安装依赖

pip install -r requirements.txt

2. 开始训练 (使用MobileNetV3-Small)

python train.py

3. 测试评估

python test.py --checkpoint checkpoints/best_model.pth --visualize

4. 单张图像推理

python test.py --checkpoint checkpoints/best_model.pth --image flower.jpg --topk 5

#通过修改config配置文件即可修改训练模型的各种参数

#特殊训练场景

场景1:训练中意外中断(如断电、程序崩溃)

原始训练命令

python train.py --model small --epochs 50

训练到第23轮中断了,从 epoch_20.pth 恢复

python train.py --checkpoint checkpoints/epoch_20.pth

场景2:手动暂停后继续

训练到某个阶段后想继续训练更多轮次

python train.py --checkpoint checkpoints/best_model.pth --epochs 100

代码解读

主要了解代码如何加载数据集,训练模型和测试即可

数据集处理dataset.py

  1. 数据集加载 (torchvision.datasets.Flowers102)

train_dataset = torchvision.datasets.Flowers102(

root=data_root, # './data'

split='train', # 'train', 'val', 'test'

download=download, # 是否自动下载

transform=train_transform # 数据变换

)

  1. 数据变换 (Transform)

训练集变换(数据增强)

transforms.Compose([

transforms.RandomResizedCrop(224), # 随机裁剪+缩放

transforms.RandomHorizontalFlip(0.5), # 50%概率水平翻转

transforms.RandomRotation(15), # 随机旋转±15度

transforms.ColorJitter(...), # 颜色抖动

transforms.ToTensor(), # PIL → Tensor (0-1)

transforms.Normalize(mean, std) # ImageNet标准化

])

  1. 数据打包 (DataLoader)

train_loader = DataLoader(

train_dataset,

batch_size=32, # 每批32张图

shuffle=True, # 打乱顺序(仅训练集)

num_workers=0, # Windows设为0

pin_memory=True, # 加速GPU传输

drop_last=True # 丢弃最后不完整batch

)

训练脚本train.py

def set_seed(seed) - 保证可复现性

def get_model(model_type, num_classes, pretrained, dropout_rate):

1. 加载 torchvision 官方模型

2. 修改最后一层分类头 (ImageNet 1000类 → Flowers102 102类)

3. 可选添加 Dropout

def get_optimizer(model, config):

根据 config.OPTIMIZER 选择优化器:

- adamw: AdamW (推荐)

- adam: Adam

- sgd: SGD with momentum

def get_scheduler(optimizer, config):

三种学习率调度策略:

- cosine: 余弦退火 (推荐,平滑下降)

- step: 阶梯下降

- plateau: 根据指标自动调整

def train_epoch(model, train_loader, criterion, optimizer, device, epoch, config, writer):

model.train() # 设置为训练模式 (启用 Dropout, BatchNorm 更新)

def main(args):

1. 解析命令行参数,覆盖 config

2. set_seed(cfg.SEED)

3. 加载数据: get_dataloaders(cfg)

4. 创建模型: get_model(...)

5. 创建优化器: get_optimizer(...)

6. 创建调度器: get_scheduler(...)

7. 训练循环:

for epoch in range(start_epoch, EPOCHS + 1):

train_loss, train_acc = train_epoch(...)

val_loss, val_acc = validate(...)

save_checkpoint(...) # 保存模型

if 早停条件: break

8. 绘制训练历史曲线

完整训练流程:

main() 启动

├──→ set_seed(42) # 固定随机性

├──→ get_dataloaders() # 加载数据

│ └→ train_loader, val_loader, test_loader

├──→ get_model() # 创建模型

│ └→ model (mobilenet_v3_small)

├──→ get_optimizer() # 创建优化器

│ └→ optimizer (AdamW)

└──→ 训练循环 (50 epochs)

├──→ Epoch 1: train → val → save

├──→ Epoch 2: train → val → save

├──→ ...

└──→ Epoch 50: train → val → save

└→ plot_training_history() # 绘制曲线

训练结果

MobileNetV3-small-30轮

MobileNetV3-large-30轮

测试脚本test.py

函数结构

test.py

├── get_model() # 加载模型 + 检查点权重

├── test() # 批量测试评估

├── predict_single_image() # 单张图像推理

├── print_test_results() # 打印测试结果

├── print_prediction_results() # 打印预测结果

└── main() # 主函数,处理参数

基础测试

python test.py --checkpoint checkpoints/best_model.pth

测试 + 可视化(生成预测图和混淆矩阵)

python test.py --checkpoint checkpoints/best_model.pth --visualize

测试 + 显示每类准确率

python test.py --checkpoint checkpoints/best_model.pth --per-class

测试 + 全部功能

python test.py --checkpoint checkpoints/best_model.pth --visualize --per-class

单图推理模式

功能说明

对单张图像进行预测,返回 Top-K 个最可能的类别及其概率。

使用方法

预测单张图像(默认 Top-5)

python test.py --checkpoint checkpoints/best_model.pth --image flower.jpg

Top-3 预测

python test.py --checkpoint checkpoints/best_model.pth --image flower.jpg --topk 3

指定模型类型

python test.py --checkpoint checkpoints/best_model.pth --image flower.jpg --model large

相关推荐
China_Yanhy9 小时前
动手学大模型第一篇学习总结
人工智能
空间机器人9 小时前
自动驾驶 ADAS 器件选型:算力只是门票,系统才是生死线
人工智能·机器学习·自动驾驶
C+++Python9 小时前
提示词、Agent、MCP、Skill 到底是什么?
人工智能
小松要进步9 小时前
机器学习1
人工智能·机器学习
泰恒9 小时前
openclaw近期怎么样了?
人工智能·深度学习·机器学习
KaneLogger10 小时前
从传统笔记到 LLM 驱动的结构化 Wiki
人工智能·程序员·架构
tinygone10 小时前
OpenClaw之Memory配置成本地模式,Ubuntu+CUDA+cuDNN+llama.cpp
人工智能·ubuntu·llama
正在走向自律10 小时前
第二章-AIGC入门-AIGC工具全解析:技术控的效率神器,DeepSeek国产大模型的骄傲(8/36)
人工智能·chatgpt·aigc·可灵·deepseek·即梦·阿里通义千问
轩轩分享AI10 小时前
DeepSeek、Kimi、笔灵谁最好用?5款网文作者亲测的AI写作神器横评
人工智能·ai·ai写作·小说写作·小说·小说干货
Aevget10 小时前
基于嵌入向量的智能检索!HOOPS AI 解锁 CAD 零件相似性搜索新方式
人工智能·hoops·cad·hoops ai·cad数据格式