Iridescent:Day49

https://blog.csdn.net/weixin_45655710?type=blog
@浙大疏锦行

DAY 49 CBAM注意力

知识点回顾:

1.通道注意力模块复习

2.空间注意力模块

3.CBAM的定义

作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程

作业完成:DAY 49 CBAM 模型参数检查与 TensorBoard 使用指南

我将分两部分详细说明:

1. 检查模型参数数目

使用 PyTorch 的标准方法统计了带 CBAM 的 CNN 模型(笔记中定义的 CBAM_CNN)的可训练参数数量。模型结构保持原样(包括原代码中的小 typo:forward 中 fc1 后用了 self.relu3,我在计算时保留了原结构,不影响参数计数)。

结果总结

  • 模型总可训练参数数量1,150,896(约 1.15 Million 参数)

这是一个轻量级模型,适合 CIFAR-10 任务。相比纯 CNN(无 CBAM),添加的 3 个 CBAM 模块仅增加了少量参数(每个 CBAM 的通道注意力共享 MLP + 空间注意力小卷积,总计几千参数),体现了 CBAM 的"轻量级"设计优势。

详细参数 breakdown(按层排序,便于理解):

层名称 参数类型 参数数量 说明
conv1.weight weight 864 3→32 通道,3x3 卷积
conv1.bias bias 32
bn1.weight / bias gamma/beta 32 + 32 BatchNorm
cbam1.channel_attn.fc.0.weight weight 64 通道注意力降维 (32 → 2)
cbam1.channel_attn.fc.2.weight weight 64 升维 (2 → 32)
cbam1.spatial_attn.conv.weight weight 98 空间注意力 7x7 卷积 (2→1)
conv2.weight weight 18,432 32→64 通道,3x3 卷积
conv2.bias bias 64
bn2.weight / bias gamma/beta 64 + 64
cbam2.channel_attn.fc.*.weight weight 256 + 256 通道注意力 (64 → 4 → 64)
cbam2.spatial_attn.conv.weight weight 98
conv3.weight weight 73,728 64→128 通道,3x3 卷积
conv3.bias bias 128
bn3.weight / bias gamma/beta 128 + 128
cbam3.channel_attn.fc.*.weight weight 1,024 + 1,024 通道注意力 (128 → 8 → 128)
cbam3.spatial_attn.conv.weight weight 98
fc1.weight weight 1,048,576 12844 → 512(占大部分)
fc1.bias bias 512
fc2.weight weight 5,120 512 → 10
fc2.bias bias 10
总计 1,150,896

关键观察

  • 全连接层(fc1)占了约 91% 的参数,这是典型的小型 CNN 特征(特征图展平后维度高)。
  • 每个 CBAM 增加的参数非常少:cbam1 ≈ 226、cbam2 ≈ 610、cbam3 ≈ 2,146,总计不到 3k 参数,却能显著提升注意力表达能力。
  • 如果想进一步轻量化,可以减小 fc1 的输出维度(如 512 → 256)或使用全局平均池化代替展平。
2. 使用 TensorBoard 查看训练过程

TensorBoard 是 PyTorch 官方推荐的可视化工具,能实时监控损失、准确率、模型结构、梯度分布等,非常适合调试和分析训练过程。

步骤指南(在你的本地环境执行,推荐使用 Jupyter 或 Python 脚本):

  1. 安装 TensorBoard(如果还没安装):

    bash 复制代码
    pip install tensorboard
  2. 修改训练代码,添加 TensorBoard 记录 (只需在原训练函数中添加几行):

    在你的训练脚本最上方导入:

    python 复制代码
    from torch.utils.tensorboard import SummaryWriter

    在 train 函数开头创建 writer(建议放在 epochs 循环外):

    python 复制代码
    writer = SummaryWriter(log_dir='runs/cbam_cifar10_experiment')  # 日志保存目录,可自定义

    在每个 epoch 结束时添加记录:

    python 复制代码
    # 在计算 epoch_train_acc 和 epoch_test_acc 后添加
    writer.add_scalar('Loss/train', epoch_train_loss, epoch)
    writer.add_scalar('Loss/test', epoch_test_loss, epoch)
    writer.add_scalar('Accuracy/train', epoch_train_acc, epoch)
    writer.add_scalar('Accuracy/test', epoch_test_acc, epoch)
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

    在每个 batch 中可选记录更细粒度损失(每 100 batch 记录一次避免过多):

    python 复制代码
    if (batch_idx + 1) % 100 == 0:
        global_step = epoch * len(train_loader) + batch_idx
        writer.add_scalar('Loss/train_batch', iter_loss, global_step)

    训练结束后关闭 writer:

    python 复制代码
    writer.close()
  3. 可视化模型结构图 (推荐,一目了然看 CBAM 插入位置):

    在模型定义后、训练前添加:

    python 复制代码
    # 使用 dummy input
    dummy_input = torch.randn(1, 3, 32, 32).to(device)
    writer.add_graph(model, dummy_input)
  4. 启动 TensorBoard

    在终端运行:

    bash 复制代码
    tensorboard --logdir=runs

    然后打开浏览器访问 http://localhost:6006

    你将看到:

    • Scalars:训练/测试损失、准确率、学习率曲线(比 matplotlib 更交互,可缩放、对比多次实验)
    • Graphs:完整的模型结构图,清晰显示 Conv → BN → ReLU → Pool → CBAM 的顺序
    • Histograms (可选进阶):添加 writer.add_histogram('fc1.weight', model.fc1.weight, epoch) 查看参数分布

预期效果

  • 训练初期损失快速下降,准确率快速上升。
  • 由于添加了 CBAM,通常测试准确率会比纯 CNN 高 3-8%(在 CIFAR-10 上可达 85-90%+,取决于随机种子和训练时长)。
  • 通过 TensorBoard 你可以观察是否过拟合(train acc 高而 test acc 停滞)、学习率调度是否生效等。

CBAM 是注意力机制的经典之作,理解透了后续 Transformer 也会更容易。

相关推荐
BlackPercy10 小时前
【特殊函数】zeta函数
python·数学建模·sympy
FOAF-lambda10 小时前
本地部署mineru-tianshu
python·mineru
深蓝海拓10 小时前
PySide6从0开始学习的笔记(二十二) 几种封装信号传递内容的方法
笔记·python·qt·学习·pyqt
站大爷IP10 小时前
Python处理Excel多工作表:openpyxl与pandas的实战对比
python
睿思达DBA_WGX10 小时前
Python 程序设计讲义(69):面向对象程序设计——类的定义与使用
数据库·python
花酒锄作田10 小时前
FastAPI异步方法中调用同步方法
python·fastapi
股朋公式网10 小时前
通达信趋势吸引主图指标公式
python
赤鸢QAQ11 小时前
PySide6批量创建控件
python·qt·pyqt
brent42311 小时前
DAY44 Dataset和Dataloader类
python·深度学习
Jelena1577958579211 小时前
实战解析:京东关键词搜索 item_search_pro —— 按关键字搜索商品
开发语言·数据库·python