Iridescent:Day49

https://blog.csdn.net/weixin_45655710?type=blog
@浙大疏锦行

DAY 49 CBAM注意力

知识点回顾:

1.通道注意力模块复习

2.空间注意力模块

3.CBAM的定义

作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程

作业完成:DAY 49 CBAM 模型参数检查与 TensorBoard 使用指南

我将分两部分详细说明:

1. 检查模型参数数目

使用 PyTorch 的标准方法统计了带 CBAM 的 CNN 模型(笔记中定义的 CBAM_CNN)的可训练参数数量。模型结构保持原样(包括原代码中的小 typo:forward 中 fc1 后用了 self.relu3,我在计算时保留了原结构,不影响参数计数)。

结果总结

  • 模型总可训练参数数量1,150,896(约 1.15 Million 参数)

这是一个轻量级模型,适合 CIFAR-10 任务。相比纯 CNN(无 CBAM),添加的 3 个 CBAM 模块仅增加了少量参数(每个 CBAM 的通道注意力共享 MLP + 空间注意力小卷积,总计几千参数),体现了 CBAM 的"轻量级"设计优势。

详细参数 breakdown(按层排序,便于理解):

层名称 参数类型 参数数量 说明
conv1.weight weight 864 3→32 通道,3x3 卷积
conv1.bias bias 32
bn1.weight / bias gamma/beta 32 + 32 BatchNorm
cbam1.channel_attn.fc.0.weight weight 64 通道注意力降维 (32 → 2)
cbam1.channel_attn.fc.2.weight weight 64 升维 (2 → 32)
cbam1.spatial_attn.conv.weight weight 98 空间注意力 7x7 卷积 (2→1)
conv2.weight weight 18,432 32→64 通道,3x3 卷积
conv2.bias bias 64
bn2.weight / bias gamma/beta 64 + 64
cbam2.channel_attn.fc.*.weight weight 256 + 256 通道注意力 (64 → 4 → 64)
cbam2.spatial_attn.conv.weight weight 98
conv3.weight weight 73,728 64→128 通道,3x3 卷积
conv3.bias bias 128
bn3.weight / bias gamma/beta 128 + 128
cbam3.channel_attn.fc.*.weight weight 1,024 + 1,024 通道注意力 (128 → 8 → 128)
cbam3.spatial_attn.conv.weight weight 98
fc1.weight weight 1,048,576 12844 → 512(占大部分)
fc1.bias bias 512
fc2.weight weight 5,120 512 → 10
fc2.bias bias 10
总计 1,150,896

关键观察

  • 全连接层(fc1)占了约 91% 的参数,这是典型的小型 CNN 特征(特征图展平后维度高)。
  • 每个 CBAM 增加的参数非常少:cbam1 ≈ 226、cbam2 ≈ 610、cbam3 ≈ 2,146,总计不到 3k 参数,却能显著提升注意力表达能力。
  • 如果想进一步轻量化,可以减小 fc1 的输出维度(如 512 → 256)或使用全局平均池化代替展平。
2. 使用 TensorBoard 查看训练过程

TensorBoard 是 PyTorch 官方推荐的可视化工具,能实时监控损失、准确率、模型结构、梯度分布等,非常适合调试和分析训练过程。

步骤指南(在你的本地环境执行,推荐使用 Jupyter 或 Python 脚本):

  1. 安装 TensorBoard(如果还没安装):

    bash 复制代码
    pip install tensorboard
  2. 修改训练代码,添加 TensorBoard 记录 (只需在原训练函数中添加几行):

    在你的训练脚本最上方导入:

    python 复制代码
    from torch.utils.tensorboard import SummaryWriter

    在 train 函数开头创建 writer(建议放在 epochs 循环外):

    python 复制代码
    writer = SummaryWriter(log_dir='runs/cbam_cifar10_experiment')  # 日志保存目录,可自定义

    在每个 epoch 结束时添加记录:

    python 复制代码
    # 在计算 epoch_train_acc 和 epoch_test_acc 后添加
    writer.add_scalar('Loss/train', epoch_train_loss, epoch)
    writer.add_scalar('Loss/test', epoch_test_loss, epoch)
    writer.add_scalar('Accuracy/train', epoch_train_acc, epoch)
    writer.add_scalar('Accuracy/test', epoch_test_acc, epoch)
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

    在每个 batch 中可选记录更细粒度损失(每 100 batch 记录一次避免过多):

    python 复制代码
    if (batch_idx + 1) % 100 == 0:
        global_step = epoch * len(train_loader) + batch_idx
        writer.add_scalar('Loss/train_batch', iter_loss, global_step)

    训练结束后关闭 writer:

    python 复制代码
    writer.close()
  3. 可视化模型结构图 (推荐,一目了然看 CBAM 插入位置):

    在模型定义后、训练前添加:

    python 复制代码
    # 使用 dummy input
    dummy_input = torch.randn(1, 3, 32, 32).to(device)
    writer.add_graph(model, dummy_input)
  4. 启动 TensorBoard

    在终端运行:

    bash 复制代码
    tensorboard --logdir=runs

    然后打开浏览器访问 http://localhost:6006

    你将看到:

    • Scalars:训练/测试损失、准确率、学习率曲线(比 matplotlib 更交互,可缩放、对比多次实验)
    • Graphs:完整的模型结构图,清晰显示 Conv → BN → ReLU → Pool → CBAM 的顺序
    • Histograms (可选进阶):添加 writer.add_histogram('fc1.weight', model.fc1.weight, epoch) 查看参数分布

预期效果

  • 训练初期损失快速下降,准确率快速上升。
  • 由于添加了 CBAM,通常测试准确率会比纯 CNN 高 3-8%(在 CIFAR-10 上可达 85-90%+,取决于随机种子和训练时长)。
  • 通过 TensorBoard 你可以观察是否过拟合(train acc 高而 test acc 停滞)、学习率调度是否生效等。

CBAM 是注意力机制的经典之作,理解透了后续 Transformer 也会更容易。

相关推荐
用户8356290780519 小时前
Python 实现 PDF 文件加密与解密方法
后端·python
用户8356290780519 小时前
使用 Python 冻结与拆分 Excel 窗格教程
后端·python
你好潘先生17 小时前
别再记命令了,用 yeero do 说句人话就能跑脚本,而且不烧 token
服务器·python·命令行
Agent_大师18 小时前
WebSocket 行情重连成功,K线缺口不会自动消失
python
荣码18 小时前
LLM结构化输出:让AI返回JSON而不是废话,我踩了4个坑
java·python
copyer_xyf18 小时前
FastAPI 如何连接 MySQL
后端·python
apocelipes1 天前
常用编程语言和库的正则表达式性能对比
c语言·c++·python·性能优化·golang·开发工具和环境
用户8356290780511 天前
使用 Python 在 PDF 中创建与管理书签
后端·python
MeixianAgent2 天前
Python 回测数据入口怎么验?历史 K 线入库前先做 5 个检查
后端·python
咕白m6252 天前
用 Python 实现一键批量查找与替换 Excel 数据
后端·python