全连接神经网络之参数初始化和损失函数(四)

全连接神经网络之参数初始化和损失函数(四)


文章目录

  • 全连接神经网络之参数初始化和损失函数(四)
    • 一、参数初始化
      • [1.1 固定值初始化(仅演示,**权重不要用**)](#1.1 固定值初始化(仅演示,权重不要用))
      • [1.2 随机初始化(打破对称性的起点)](#1.2 随机初始化(打破对称性的起点))
      • [1.3 Xavier(Glorot)初始化 ------ 均衡前向与反向方差](#1.3 Xavier(Glorot)初始化 —— 均衡前向与反向方差)
      • [1.4 He(Kaiming)初始化 ------ 专为 ReLU 优化](#1.4 He(Kaiming)初始化 —— 专为 ReLU 优化)
      • [1.5 高级初始化速览](#1.5 高级初始化速览)
    • 二、损失函数
      • [2.1 回归任务](#2.1 回归任务)
      • [2.2 分类任务](#2.2 分类任务)
        • [2.2.1 多类单标签 ------ CrossEntropyLoss](#2.2.1 多类单标签 —— CrossEntropyLoss)
        • [2.2.2 二分类 / 多标签 ------ BCEWithLogitsLoss](#2.2.2 二分类 / 多标签 —— BCEWithLogitsLoss)
    • 三、总结
    • 四、案例

关键词:参数初始化、Xavier / He 初始值、对称性破坏、损失函数、MAE / MSE / CrossEntropy / BCE



一、参数初始化

1.1 固定值初始化(仅演示,权重不要用

方法 代码 缺陷
全零 nn.init.zeros_(w) 对称性未被破坏,所有神经元等价
全一 nn.init.ones_(w) 同上,且激活后输出恒等
任意常数 nn.init.constant_(w, val) 仍无法打破对称性
python 复制代码
import torch.nn as nn
fc = nn.Linear(4, 3)
nn.init.zeros_(fc.weight)      # 仅偏置可用

1.2 随机初始化(打破对称性的起点)

分布 代码 方差公式 备注
均匀 nn.init.uniform_(w, a, b) ( b − a ) 2 12 \displaystyle \frac{(b-a)^2}{12} 12(b−a)2 需手动调区间
正态 nn.init.normal_(w, mean, std) σ 2 \displaystyle \sigma^2 σ2 std 难校准

局限:未考虑前向/反向方差,深层网络仍需"自适应"方法。


1.3 Xavier(Glorot)初始化 ------ 均衡前向与反向方差

激活函数 数学依据 PyTorch API
Sigmoid / Tanh Var ( W ) = 2 n in + n out \displaystyle \text{Var}(W)=\frac{2}{n_{\text{in}}+n_{\text{out}}} Var(W)=nin+nout2 xavier_uniform_ / xavier_normal_
  • 均匀区间 : [ − 6 n in + n out ,    6 n in + n out ] [-\sqrt{\frac{6}{n_{\text{in}}+n_{\text{out}}}},\; \sqrt{\frac{6}{n_{\text{in}}+n_{\text{out}}}}] [−nin+nout6 ,nin+nout6 ]
  • 正态方差 : 2 n in + n out \frac{2}{n_{\text{in}}+n_{\text{out}}} nin+nout2
python 复制代码
# 示例:对 Tanh 使用 Xavier
fc = nn.Linear(128, 64)
nn.init.xavier_uniform_(fc.weight, gain=nn.init.calculate_gain('tanh'))

1.4 He(Kaiming)初始化 ------ 专为 ReLU 优化

模式 方差公式 场景 API
fan_in 2 n in \frac{2}{n_{\text{in}}} nin2 前向方差稳定 kaiming_normal_(..., mode='fan_in')
fan_out 2 n out \frac{2}{n_{\text{out}}} nout2 反向梯度稳定 kaiming_uniform_(..., mode='fan_out')
python 复制代码
# 示例:ReLU + He
fc = nn.Linear(256, 128)
nn.init.kaiming_normal_(fc.weight, nonlinearity='relu')

1.5 高级初始化速览

方法 一句话说明 代码
orthogonal_ 生成(半)正交矩阵,保持动态等距 nn.init.orthogonal_(w)
sparse_ 指定稀疏度的高斯权重 nn.init.sparse_(w, sparsity=0.9)

二、损失函数

2.1 回归任务

损失 公式 PyTorch
MAE (L1) $\frac{1}{n}\sum y_i-\hat y_i
MSE (L2) 1 n ∑ ( y i − y ^ i ) 2 \frac{1}{n}\sum(y_i-\hat y_i)^2 n1∑(yi−y^i)2 nn.MSELoss()
python 复制代码
pred   = torch.randn(32, 1)
target = torch.randn(32, 1)
print("MSE:", nn.MSELoss()(pred, target).item())
print("MAE:", nn.L1Loss()(pred, target).item())

2.2 分类任务

2.2.1 多类单标签 ------ CrossEntropyLoss
  • 已内置 Softmax不要 再手动 softmax
  • 公式:
    L = − 1 N ∑ i log ⁡ e x i , y i − max ⁡ ( x i ) ∑ j e x i , j − max ⁡ ( x i ) \displaystyle \mathcal{L}=-\frac{1}{N}\sum_{i}\log\frac{e^{x_{i,y_i}-\max(x_i)}}{\sum_j e^{x_{i,j}-\max(x_i)}} L=−N1i∑log∑jexi,j−max(xi)exi,yi−max(xi)
python 复制代码
logits = torch.randn(16, 10)      # (batch, n_classes)
labels = torch.randint(0, 10, (16,))
loss = nn.CrossEntropyLoss()(logits, labels)
print("CrossEntropy:", loss.item())
2.2.2 二分类 / 多标签 ------ BCEWithLogitsLoss
  • 已内置 Sigmoid,推荐一步到位。
  • 公式:
    − 1 N ∑ i [ y i log ⁡ y ^ i + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] \displaystyle -\frac{1}{N}\sum_{i}\left[y_i\log\hat y_i+(1-y_i)\log(1-\hat y_i)\right] −N1i∑[yilogy^i+(1−yi)log(1−y^i)]
python 复制代码
logits  = torch.randn(8, 5)       # 8 个样本,5 个标签
targets = torch.randint(0, 2, (8, 5)).float()
loss = nn.BCEWithLogitsLoss()(logits, targets)
print("Multi-label BCE:", loss.item())

三、总结

任务类型 输出层 初始化 损失 备注
线性回归 无激活 Xavier / He 均可 MSE / MAE 输出无需激活
二分类 1 个神经元 + Sigmoid He BCEWithLogitsLoss 标签 0/1
多类单标签 Softmax He CrossEntropyLoss 无需手动 Softmax
多标签 Sigmoid(每个类) He BCEWithLogitsLoss 标签多热编码

四、案例

完整训练片段:初始化 + 损失

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim

class MLP(nn.Module):
    def __init__(self, in_dim=784, hidden=256, out_dim=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim)
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)

model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 模拟一个 batch
x = torch.randn(64, 784)
y = torch.randint(0, 10, (64,))
out = model(x)
loss = criterion(out, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Step loss:", loss.item())

相关推荐
JJJJ_iii2 分钟前
【机器学习01】监督学习、无监督学习、线性回归、代价函数
人工智能·笔记·python·学习·机器学习·jupyter·线性回归
qq_416276422 小时前
LOFAR物理频谱特征提取及实现
人工智能
Python图像识别3 小时前
71_基于深度学习的布料瑕疵检测识别系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
python·深度学习·yolo
余俊晖3 小时前
如何构造一个文档解析的多模态大模型?MinerU2.5架构、数据、训练方法
人工智能·文档解析
Akamai中国4 小时前
Linebreak赋能实时化企业转型:专业系统集成商携手Akamai以实时智能革新企业运营
人工智能·云计算·云服务
LiJieNiub5 小时前
读懂目标检测:从基础概念到主流算法
人工智能·计算机视觉·目标跟踪
哥布林学者6 小时前
吴恩达深度学习课程一:神经网络和深度学习 第三周:浅层神经网络(二)
深度学习·ai
weixin_519535776 小时前
从ChatGPT到新质生产力:一份数据驱动的AI研究方向指南
人工智能·深度学习·机器学习·ai·chatgpt·数据分析·aigc
爱喝白开水a6 小时前
LangChain 基础系列之 Prompt 工程详解:从设计原理到实战模板_langchain prompt
开发语言·数据库·人工智能·python·langchain·prompt·知识图谱
takashi_void6 小时前
如何在本地部署大语言模型(Windows,Mac,Linux)三系统教程
linux·人工智能·windows·macos·语言模型·nlp