Iridescent:Day49

https://blog.csdn.net/weixin_45655710?type=blog
@浙大疏锦行

DAY 49 CBAM注意力

知识点回顾:

1.通道注意力模块复习

2.空间注意力模块

3.CBAM的定义

作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程

作业完成:DAY 49 CBAM 模型参数检查与 TensorBoard 使用指南

我将分两部分详细说明:

1. 检查模型参数数目

使用 PyTorch 的标准方法统计了带 CBAM 的 CNN 模型(笔记中定义的 CBAM_CNN)的可训练参数数量。模型结构保持原样(包括原代码中的小 typo:forward 中 fc1 后用了 self.relu3,我在计算时保留了原结构,不影响参数计数)。

结果总结

  • 模型总可训练参数数量1,150,896(约 1.15 Million 参数)

这是一个轻量级模型,适合 CIFAR-10 任务。相比纯 CNN(无 CBAM),添加的 3 个 CBAM 模块仅增加了少量参数(每个 CBAM 的通道注意力共享 MLP + 空间注意力小卷积,总计几千参数),体现了 CBAM 的"轻量级"设计优势。

详细参数 breakdown(按层排序,便于理解):

层名称 参数类型 参数数量 说明
conv1.weight weight 864 3→32 通道,3x3 卷积
conv1.bias bias 32
bn1.weight / bias gamma/beta 32 + 32 BatchNorm
cbam1.channel_attn.fc.0.weight weight 64 通道注意力降维 (32 → 2)
cbam1.channel_attn.fc.2.weight weight 64 升维 (2 → 32)
cbam1.spatial_attn.conv.weight weight 98 空间注意力 7x7 卷积 (2→1)
conv2.weight weight 18,432 32→64 通道,3x3 卷积
conv2.bias bias 64
bn2.weight / bias gamma/beta 64 + 64
cbam2.channel_attn.fc.*.weight weight 256 + 256 通道注意力 (64 → 4 → 64)
cbam2.spatial_attn.conv.weight weight 98
conv3.weight weight 73,728 64→128 通道,3x3 卷积
conv3.bias bias 128
bn3.weight / bias gamma/beta 128 + 128
cbam3.channel_attn.fc.*.weight weight 1,024 + 1,024 通道注意力 (128 → 8 → 128)
cbam3.spatial_attn.conv.weight weight 98
fc1.weight weight 1,048,576 12844 → 512(占大部分)
fc1.bias bias 512
fc2.weight weight 5,120 512 → 10
fc2.bias bias 10
总计 1,150,896

关键观察

  • 全连接层(fc1)占了约 91% 的参数,这是典型的小型 CNN 特征(特征图展平后维度高)。
  • 每个 CBAM 增加的参数非常少:cbam1 ≈ 226、cbam2 ≈ 610、cbam3 ≈ 2,146,总计不到 3k 参数,却能显著提升注意力表达能力。
  • 如果想进一步轻量化,可以减小 fc1 的输出维度(如 512 → 256)或使用全局平均池化代替展平。
2. 使用 TensorBoard 查看训练过程

TensorBoard 是 PyTorch 官方推荐的可视化工具,能实时监控损失、准确率、模型结构、梯度分布等,非常适合调试和分析训练过程。

步骤指南(在你的本地环境执行,推荐使用 Jupyter 或 Python 脚本):

  1. 安装 TensorBoard(如果还没安装):

    bash 复制代码
    pip install tensorboard
  2. 修改训练代码,添加 TensorBoard 记录 (只需在原训练函数中添加几行):

    在你的训练脚本最上方导入:

    python 复制代码
    from torch.utils.tensorboard import SummaryWriter

    在 train 函数开头创建 writer(建议放在 epochs 循环外):

    python 复制代码
    writer = SummaryWriter(log_dir='runs/cbam_cifar10_experiment')  # 日志保存目录,可自定义

    在每个 epoch 结束时添加记录:

    python 复制代码
    # 在计算 epoch_train_acc 和 epoch_test_acc 后添加
    writer.add_scalar('Loss/train', epoch_train_loss, epoch)
    writer.add_scalar('Loss/test', epoch_test_loss, epoch)
    writer.add_scalar('Accuracy/train', epoch_train_acc, epoch)
    writer.add_scalar('Accuracy/test', epoch_test_acc, epoch)
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

    在每个 batch 中可选记录更细粒度损失(每 100 batch 记录一次避免过多):

    python 复制代码
    if (batch_idx + 1) % 100 == 0:
        global_step = epoch * len(train_loader) + batch_idx
        writer.add_scalar('Loss/train_batch', iter_loss, global_step)

    训练结束后关闭 writer:

    python 复制代码
    writer.close()
  3. 可视化模型结构图 (推荐,一目了然看 CBAM 插入位置):

    在模型定义后、训练前添加:

    python 复制代码
    # 使用 dummy input
    dummy_input = torch.randn(1, 3, 32, 32).to(device)
    writer.add_graph(model, dummy_input)
  4. 启动 TensorBoard

    在终端运行:

    bash 复制代码
    tensorboard --logdir=runs

    然后打开浏览器访问 http://localhost:6006

    你将看到:

    • Scalars:训练/测试损失、准确率、学习率曲线(比 matplotlib 更交互,可缩放、对比多次实验)
    • Graphs:完整的模型结构图,清晰显示 Conv → BN → ReLU → Pool → CBAM 的顺序
    • Histograms (可选进阶):添加 writer.add_histogram('fc1.weight', model.fc1.weight, epoch) 查看参数分布

预期效果

  • 训练初期损失快速下降,准确率快速上升。
  • 由于添加了 CBAM,通常测试准确率会比纯 CNN 高 3-8%(在 CIFAR-10 上可达 85-90%+,取决于随机种子和训练时长)。
  • 通过 TensorBoard 你可以观察是否过拟合(train acc 高而 test acc 停滞)、学习率调度是否生效等。

CBAM 是注意力机制的经典之作,理解透了后续 Transformer 也会更容易。

相关推荐
码路飞1 小时前
写了个 AI 聊天页面,被 5 种流式格式折腾了一整天 😭
javascript·python
曲幽3 小时前
FastAPI压力测试实战:Locust模拟真实用户并发及优化建议
python·fastapi·web·locust·asyncio·test·uvicorn·workers
敏编程8 小时前
一天一个Python库:jsonschema - JSON 数据验证利器
python
前端付豪8 小时前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
databook8 小时前
ManimCE v0.20.1 发布:LaTeX 渲染修复与动画稳定性提升
python·动效
花酒锄作田21 小时前
使用 pkgutil 实现动态插件系统
python
前端付豪1 天前
LangChain链 写一篇完美推文?用SequencialChain链接不同的组件
人工智能·python·langchain
曲幽1 天前
FastAPI实战:打造本地文生图接口,ollama+diffusers让AI绘画更听话
python·fastapi·web·cors·diffusers·lcm·ollama·dreamshaper8·txt2img
老赵全栈实战1 天前
Pydantic配置管理最佳实践(一)
python
阿尔的代码屋1 天前
[大模型实战 07] 基于 LlamaIndex ReAct 框架手搓全自动博客监控 Agent
人工智能·python