python打卡day52

神经网络调参指南

知识点回顾:

  1. 随机种子
  2. 内参的初始化
  3. 神经网络调参指南
    1. 参数的分类
    2. 调参的顺序
    3. 各部分参数的调整心得

参数可视化

python 复制代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

# 设置设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 定义极简CNN模型(仅1个卷积层+1个全连接层)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        
        # 卷积层:输入3通道,输出16通道,卷积核3x3
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        
        # 池化层:2x2窗口,尺寸减半
        self.pool = nn.MaxPool2d(kernel_size=2)
        
        # 全连接层:展平后连接到10个输出(对应10个类别)
        # 输入尺寸:16通道 × 16x16特征图 = 16×16×16=4096
        self.fc = nn.Linear(16 * 16 * 16, 10)
    
    def forward(self, x):
        # 卷积+池化
        x = self.pool(self.conv1(x))  # 输出尺寸: [batch, 16, 16, 16]
        
        # 展平
        x = x.view(-1, 16 * 16 * 16)  # 展平为: [batch, 4096]
        
        # 全连接
        x = self.fc(x)  # 输出尺寸: [batch, 10]
        
        return x

# 初始化模型
model = SimpleCNN()
model = model.to(device)

# 查看模型结构
print(model)

# 查看初始权重统计信息
def print_weight_stats(model):
    # 卷积层
    conv_weights = model.conv1.weight.data
    print("\n卷积层 权重统计:")
    print(f"  均值: {conv_weights.mean().item():.6f}")
    print(f"  标准差: {conv_weights.std().item():.6f}")
    print(f"  理论标准差 (Kaiming): {np.sqrt(2/3):.6f}")  # 输入通道数为3
    
    # 全连接层
    fc_weights = model.fc.weight.data
    print("\n全连接层 权重统计:")
    print(f"  均值: {fc_weights.mean().item():.6f}")
    print(f"  标准差: {fc_weights.std().item():.6f}")
    print(f"  理论标准差 (Kaiming): {np.sqrt(2/(16*16*16)):.6f}")

# 改进的可视化权重分布函数
def visualize_weights(model, layer_name, weights, save_path=None):
    plt.figure(figsize=(12, 5))
    
    # 权重直方图
    plt.subplot(1, 2, 1)
    plt.hist(weights.cpu().numpy().flatten(), bins=50)
    plt.title(f'{layer_name} 权重分布')
    plt.xlabel('权重值')
    plt.ylabel('频次')
    
    # 权重热图
    plt.subplot(1, 2, 2)
    if len(weights.shape) == 4:  # 卷积层权重 [out_channels, in_channels, kernel_size, kernel_size]
        # 只显示第一个输入通道的前10个滤波器
        w = weights[:10, 0].cpu().numpy()
        plt.imshow(w.reshape(-1, weights.shape[2]), cmap='viridis')
    else:  # 全连接层权重 [out_features, in_features]
        # 只显示前10个神经元的权重,重塑为更合理的矩形
        w = weights[:10].cpu().numpy()
        
        # 计算更合理的二维形状(尝试接近正方形)
        n_features = w.shape[1]
        side_length = int(np.sqrt(n_features))
        
        # 如果不能完美整除,添加零填充使能重塑
        if n_features % side_length != 0:
            new_size = (side_length + 1) * side_length
            w_padded = np.zeros((w.shape[0], new_size))
            w_padded[:, :n_features] = w
            w = w_padded
        
        # 重塑并显示
        plt.imshow(w.reshape(w.shape[0] * side_length, -1), cmap='viridis')
    
    plt.colorbar()
    plt.title(f'{layer_name} 权重热图')
    
    plt.tight_layout()
    if save_path:
        plt.savefig(f'{save_path}_{layer_name}.png')
    plt.show()

# 打印权重统计
print_weight_stats(model)

# 可视化各层权重
visualize_weights(model, "Conv1", model.conv1.weight.data, "initial_weights")
visualize_weights(model, "FC", model.fc.weight.data, "initial_weights")

# 可视化偏置
plt.figure(figsize=(12, 5))

# 卷积层偏置
conv_bias = model.conv1.bias.data
plt.subplot(1, 2, 1)
plt.bar(range(len(conv_bias)), conv_bias.cpu().numpy())
plt.title('卷积层 偏置')

# 全连接层偏置
fc_bias = model.fc.bias.data
plt.subplot(1, 2, 2)
plt.bar(range(len(fc_bias)), fc_bias.cpu().numpy())
plt.title('全连接层 偏置')

plt.tight_layout()
plt.savefig('biases_initial.png')
plt.show()

print("\n偏置统计:")
print(f"卷积层偏置 均值: {conv_bias.mean().item():.6f}")
print(f"卷积层偏置 标准差: {conv_bias.std().item():.6f}")
print(f"全连接层偏置 均值: {fc_bias.mean().item():.6f}")
print(f"全连接层偏置 标准差: {fc_bias.std().item():.6f}")

指南

  1. 参数初始化----有预训练的参数直接起飞

  2. batchsize---测试下能允许的最高值

  3. epoch---这个不必多说,默认都是训练到收敛位置,可以采取早停策略

  4. 学习率与调度器----收益最高,因为鞍点太多了,模型越复杂鞍点越多

  5. 模型结构----消融实验或者对照试验

  6. 损失函数---选择比较少,试出来一个即可,高手可以自己构建

  7. 激活函数---选择同样较少

  8. 正则化参数---主要是droupout,等到过拟合了用,上述所有步骤都为了让模型过拟合

相关推荐
渣渣苏6 分钟前
Langchain实战快速入门
人工智能·python·langchain
3GPP仿真实验室8 分钟前
【Matlab源码】6G候选波形:OFDM-IM 增强仿真平台 DM、CI
开发语言·matlab·ci/cd
devmoon12 分钟前
在 Polkadot 上部署独立区块链Paseo 测试网实战部署指南
开发语言·安全·区块链·polkadot·erc-20·测试网·独立链
lili-felicity12 分钟前
CANN流水线并行推理与资源调度优化
开发语言·人工智能
沐知全栈开发13 分钟前
CSS3 边框:全面解析与实战技巧
开发语言
lili-felicity15 分钟前
CANN模型量化详解:从FP32到INT8的精度与性能平衡
人工智能·python
数据知道18 分钟前
PostgreSQL实战:详解如何用Python优雅地从PG中存取处理JSON
python·postgresql·json
island131423 分钟前
CANN GE(图引擎)深度解析:计算图优化管线、内存静态规划与异构 Stream 调度机制
c语言·开发语言·神经网络
曹牧27 分钟前
Spring Boot:如何在Java Controller中处理POST请求?
java·开发语言
浅念-30 分钟前
C++入门(2)
开发语言·c++·经验分享·笔记·学习