Iridescent:Day49

https://blog.csdn.net/weixin_45655710?type=blog
@浙大疏锦行

DAY 49 CBAM注意力

知识点回顾:

1.通道注意力模块复习

2.空间注意力模块

3.CBAM的定义

作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程

作业完成:DAY 49 CBAM 模型参数检查与 TensorBoard 使用指南

我将分两部分详细说明:

1. 检查模型参数数目

使用 PyTorch 的标准方法统计了带 CBAM 的 CNN 模型(笔记中定义的 CBAM_CNN)的可训练参数数量。模型结构保持原样(包括原代码中的小 typo:forward 中 fc1 后用了 self.relu3,我在计算时保留了原结构,不影响参数计数)。

结果总结

  • 模型总可训练参数数量1,150,896(约 1.15 Million 参数)

这是一个轻量级模型,适合 CIFAR-10 任务。相比纯 CNN(无 CBAM),添加的 3 个 CBAM 模块仅增加了少量参数(每个 CBAM 的通道注意力共享 MLP + 空间注意力小卷积,总计几千参数),体现了 CBAM 的"轻量级"设计优势。

详细参数 breakdown(按层排序,便于理解):

层名称 参数类型 参数数量 说明
conv1.weight weight 864 3→32 通道,3x3 卷积
conv1.bias bias 32
bn1.weight / bias gamma/beta 32 + 32 BatchNorm
cbam1.channel_attn.fc.0.weight weight 64 通道注意力降维 (32 → 2)
cbam1.channel_attn.fc.2.weight weight 64 升维 (2 → 32)
cbam1.spatial_attn.conv.weight weight 98 空间注意力 7x7 卷积 (2→1)
conv2.weight weight 18,432 32→64 通道,3x3 卷积
conv2.bias bias 64
bn2.weight / bias gamma/beta 64 + 64
cbam2.channel_attn.fc.*.weight weight 256 + 256 通道注意力 (64 → 4 → 64)
cbam2.spatial_attn.conv.weight weight 98
conv3.weight weight 73,728 64→128 通道,3x3 卷积
conv3.bias bias 128
bn3.weight / bias gamma/beta 128 + 128
cbam3.channel_attn.fc.*.weight weight 1,024 + 1,024 通道注意力 (128 → 8 → 128)
cbam3.spatial_attn.conv.weight weight 98
fc1.weight weight 1,048,576 12844 → 512(占大部分)
fc1.bias bias 512
fc2.weight weight 5,120 512 → 10
fc2.bias bias 10
总计 1,150,896

关键观察

  • 全连接层(fc1)占了约 91% 的参数,这是典型的小型 CNN 特征(特征图展平后维度高)。
  • 每个 CBAM 增加的参数非常少:cbam1 ≈ 226、cbam2 ≈ 610、cbam3 ≈ 2,146,总计不到 3k 参数,却能显著提升注意力表达能力。
  • 如果想进一步轻量化,可以减小 fc1 的输出维度(如 512 → 256)或使用全局平均池化代替展平。
2. 使用 TensorBoard 查看训练过程

TensorBoard 是 PyTorch 官方推荐的可视化工具,能实时监控损失、准确率、模型结构、梯度分布等,非常适合调试和分析训练过程。

步骤指南(在你的本地环境执行,推荐使用 Jupyter 或 Python 脚本):

  1. 安装 TensorBoard(如果还没安装):

    bash 复制代码
    pip install tensorboard
  2. 修改训练代码,添加 TensorBoard 记录 (只需在原训练函数中添加几行):

    在你的训练脚本最上方导入:

    python 复制代码
    from torch.utils.tensorboard import SummaryWriter

    在 train 函数开头创建 writer(建议放在 epochs 循环外):

    python 复制代码
    writer = SummaryWriter(log_dir='runs/cbam_cifar10_experiment')  # 日志保存目录,可自定义

    在每个 epoch 结束时添加记录:

    python 复制代码
    # 在计算 epoch_train_acc 和 epoch_test_acc 后添加
    writer.add_scalar('Loss/train', epoch_train_loss, epoch)
    writer.add_scalar('Loss/test', epoch_test_loss, epoch)
    writer.add_scalar('Accuracy/train', epoch_train_acc, epoch)
    writer.add_scalar('Accuracy/test', epoch_test_acc, epoch)
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

    在每个 batch 中可选记录更细粒度损失(每 100 batch 记录一次避免过多):

    python 复制代码
    if (batch_idx + 1) % 100 == 0:
        global_step = epoch * len(train_loader) + batch_idx
        writer.add_scalar('Loss/train_batch', iter_loss, global_step)

    训练结束后关闭 writer:

    python 复制代码
    writer.close()
  3. 可视化模型结构图 (推荐,一目了然看 CBAM 插入位置):

    在模型定义后、训练前添加:

    python 复制代码
    # 使用 dummy input
    dummy_input = torch.randn(1, 3, 32, 32).to(device)
    writer.add_graph(model, dummy_input)
  4. 启动 TensorBoard

    在终端运行:

    bash 复制代码
    tensorboard --logdir=runs

    然后打开浏览器访问 http://localhost:6006

    你将看到:

    • Scalars:训练/测试损失、准确率、学习率曲线(比 matplotlib 更交互,可缩放、对比多次实验)
    • Graphs:完整的模型结构图,清晰显示 Conv → BN → ReLU → Pool → CBAM 的顺序
    • Histograms (可选进阶):添加 writer.add_histogram('fc1.weight', model.fc1.weight, epoch) 查看参数分布

预期效果

  • 训练初期损失快速下降,准确率快速上升。
  • 由于添加了 CBAM,通常测试准确率会比纯 CNN 高 3-8%(在 CIFAR-10 上可达 85-90%+,取决于随机种子和训练时长)。
  • 通过 TensorBoard 你可以观察是否过拟合(train acc 高而 test acc 停滞)、学习率调度是否生效等。

CBAM 是注意力机制的经典之作,理解透了后续 Transformer 也会更容易。

相关推荐
玄同76521 小时前
从 0 到 1:用 Python 开发 MCP 工具,让 AI 智能体拥有 “超能力”
开发语言·人工智能·python·agent·ai编程·mcp·trae
小瑞瑞acd21 小时前
【小瑞瑞精讲】卷积神经网络(CNN):从入门到精通,计算机如何“看”懂世界?
人工智能·python·深度学习·神经网络·机器学习
火车叼位1 天前
也许你不需要创建.venv, 此规范使python脚本自备依赖
python
火车叼位1 天前
脚本伪装:让 Python 与 Node.js 像原生 Shell 命令一样运行
运维·javascript·python
孤狼warrior1 天前
YOLO目标检测 一千字解析yolo最初的摸样 模型下载,数据集构建及模型训练代码
人工智能·python·深度学习·算法·yolo·目标检测·目标跟踪
Katecat996631 天前
YOLO11分割算法实现甲状腺超声病灶自动检测与定位_DWR方法应用
python
玩大数据的龙威1 天前
农经权二轮延包—各种地块示意图
python·arcgis
ZH15455891311 天前
Flutter for OpenHarmony Python学习助手实战:数据库操作与管理的实现
python·学习·flutter
belldeep1 天前
python:用 Flask 3 , mistune 2 和 mermaid.min.js 10.9 来实现 Markdown 中 mermaid 图表的渲染
javascript·python·flask
喵手1 天前
Python爬虫实战:电商价格监控系统 - 从定时任务到历史趋势分析的完整实战(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·电商价格监控系统·从定时任务到历史趋势分析·采集结果sqlite存储