Day56 PythonStudy

@浙大疏锦行

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 数据预处理(与原代码一致)
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# 加载数据集(与原代码一致)
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
python 复制代码
import torch
import torch.nn as nn

# 定义通道注意力
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio=16):
        """
        通道注意力机制初始化
        参数:
            in_channels: 输入特征图的通道数
            ratio: 降维比例,用于减少参数量,默认为16
        """
        super().__init__()
        # 全局平均池化,将每个通道的特征图压缩为1x1,保留通道间的平均值信息
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 全局最大池化,将每个通道的特征图压缩为1x1,保留通道间的最显著特征
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # 共享全连接层,用于学习通道间的关系
        # 先降维(除以ratio),再通过ReLU激活,最后升维回原始通道数
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),  # 降维层
            nn.ReLU(),  # 非线性激活函数
            nn.Linear(in_channels // ratio, in_channels, bias=False)   # 升维层
        )
        # Sigmoid函数将输出映射到0-1之间,作为各通道的权重
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        前向传播函数
        参数:
            x: 输入特征图,形状为 [batch_size, channels, height, width]
        返回:
            调整后的特征图,通道权重已应用
        """
        # 获取输入特征图的维度信息,这是一种元组的解包写法
        b, c, h, w = x.shape
        # 对平均池化结果进行处理:展平后通过全连接网络
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        # 对最大池化结果进行处理:展平后通过全连接网络
        max_out = self.fc(self.max_pool(x).view(b, c))
        # 将平均池化和最大池化的结果相加并通过sigmoid函数得到通道权重
        attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        # 将注意力权重与原始特征相乘,增强重要通道,抑制不重要通道
        return x * attention #这个运算是pytorch的广播机制
## 空间注意力模块
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 通道维度池化
        avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化:(B,1,H,W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化:(B,1,H,W)
        pool_out = torch.cat([avg_out, max_out], dim=1)  # 拼接:(B,2,H,W)
        attention = self.conv(pool_out)  # 卷积提取空间特征
        return x * self.sigmoid(attention)  # 特征与空间权重相乘
## CBAM模块
class CBAM(nn.Module):
    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attn = ChannelAttention(in_channels, ratio)
        self.spatial_attn = SpatialAttention(kernel_size)

    def forward(self, x):
        x = self.channel_attn(x)
        x = self.spatial_attn(x)
        return x
python 复制代码
# 定义带有CBAM的CNN模型
class CBAM_CNN(nn.Module):
    def __init__(self):
        super(CBAM_CNN, self).__init__()
        
        # ---------------------- 第一个卷积块(带CBAM) ----------------------
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32) # 批归一化
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.cbam1 = CBAM(in_channels=32)  # 在第一个卷积块后添加CBAM
        
        # ---------------------- 第二个卷积块(带CBAM) ----------------------
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.cbam2 = CBAM(in_channels=64)  # 在第二个卷积块后添加CBAM
        
        # ---------------------- 第三个卷积块(带CBAM) ----------------------
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.cbam3 = CBAM(in_channels=128)  # 在第三个卷积块后添加CBAM
        
        # ---------------------- 全连接层 ----------------------
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        # 第一个卷积块
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.cbam1(x)  # 应用CBAM
        
        # 第二个卷积块
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.cbam2(x)  # 应用CBAM
        
        # 第三个卷积块
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)
        x = self.cbam3(x)  # 应用CBAM
        
        # 全连接层
        x = x.view(-1, 128 * 4 * 4)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# 初始化模型并移至设备
model = CBAM_CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
# 统计参数的核心函数
def count_parameters(model):
    """
    统计模型的参数数量
    返回:可训练参数数量,总参数数量
    """
    # 可训练参数(requires_grad=True)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 总参数(包括冻结的参数)
    total_params = sum(p.numel() for p in model.parameters())
    
    return trainable_params, total_params

# 调用函数并输出结果
trainable_params, total_params = count_parameters(model)

# 格式化输出,让数字更易读
print(f"可训练参数数量: {trainable_params:,}")
print(f"总参数数量: {total_params:,}")

# 额外:打印每层的参数详情(可选)
print("\n各层参数详情:")
for name, param in model.named_parameters():
    print(f"{name}: {param.numel():,} 个参数 (可训练: {param.requires_grad})")

可训练参数数量: 1,150,896

总参数数量: 1,150,896

各层参数详情:

conv1.weight: 864 个参数 (可训练: True)

conv1.bias: 32 个参数 (可训练: True)

bn1.weight: 32 个参数 (可训练: True)

bn1.bias: 32 个参数 (可训练: True)

cbam1.channel_attn.fc.0.weight: 64 个参数 (可训练: True)

cbam1.channel_attn.fc.2.weight: 64 个参数 (可训练: True)

cbam1.spatial_attn.conv.weight: 98 个参数 (可训练: True)

conv2.weight: 18,432 个参数 (可训练: True)

conv2.bias: 64 个参数 (可训练: True)

bn2.weight: 64 个参数 (可训练: True)

bn2.bias: 64 个参数 (可训练: True)

cbam2.channel_attn.fc.0.weight: 256 个参数 (可训练: True)

cbam2.channel_attn.fc.2.weight: 256 个参数 (可训练: True)

cbam2.spatial_attn.conv.weight: 98 个参数 (可训练: True)

conv3.weight: 73,728 个参数 (可训练: True)

conv3.bias: 128 个参数 (可训练: True)

bn3.weight: 128 个参数 (可训练: True)

bn3.bias: 128 个参数 (可训练: True)

cbam3.channel_attn.fc.0.weight: 1,024 个参数 (可训练: True)

cbam3.channel_attn.fc.2.weight: 1,024 个参数 (可训练: True)

cbam3.spatial_attn.conv.weight: 98 个参数 (可训练: True)

fc1.weight: 1,048,576 个参数 (可训练: True)

fc1.bias: 512 个参数 (可训练: True)

fc2.weight: 5,120 个参数 (可训练: True)

fc2.bias: 10 个参数 (可训练: True)

相关推荐
Coder_Boy_1 小时前
技术发展的核心规律是「加法打底,减法优化,重构平衡」
人工智能·spring boot·spring·重构
会飞的老朱3 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º4 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
Codebee6 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º7 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys7 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_56787 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子7 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能8 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_160144878 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能