LLM系列:2.pytorch入门:7.深层神经网络第一篇

深层神经网络基础概念

深层神经网络(DNN)通过多个隐藏层实现复杂非线性映射。每一层由线性变换(权重矩阵乘法)和非线性激活函数组成,层间传递梯度通过反向传播算法更新参数。

网络结构设计

输入层维度需与数据特征匹配,隐藏层通常采用逐层降维或等宽设计。输出层维度由任务决定(如分类任务使用类别数)。PyTorch中通过nn.Module定义网络结构:

python 复制代码
import torch.nn as nn

class DNN(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(nn.ReLU())
            prev_dim = dim
        self.net = nn.Sequential(*layers, nn.Linear(prev_dim, output_dim))
    
    def forward(self, x):
        return self.net(x)

激活函数选择

ReLU及其变体(LeakyReLU、PReLU)解决梯度消失问题:

  • ReLU: f(x)=max⁡(0,x)f(x) = \max(0,x)f(x)=max(0,x)
  • LeakyReLU: f(x)=max⁡(0.01x,x)f(x) = \max(0.01x,x)f(x)=max(0.01x,x)
  • Sigmoid/Tanh适用于输出层特定场景

参数初始化方法

Xavier初始化适配全连接层:

python 复制代码
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net.apply(init_weights)

批标准化技术

nn.BatchNorm1d加速深层网络训练:

python 复制代码
self.bn = nn.BatchNorm1d(hidden_dim)
def forward(self, x):
    return self.bn(self.linear(x))

梯度裁剪实现

防止梯度爆炸:

python 复制代码
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

正则化策略

L2正则通过优化器实现:

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)

Dropout层随机失活神经元:

python 复制代码
self.dropout = nn.Dropout(p=0.5)

损失函数配置

分类任务常用交叉熵:

python 复制代码
criterion = nn.CrossEntropyLoss()

回归任务采用MSE:

python 复制代码
criterion = nn.MSELoss()
相关推荐
志摩凛6 小时前
领导亲手打造的“技术屎山连环套”:Figma→React→Vue→MCP调用毒瘤UI库,半成品Design Token让我们生不如死|五一节前的噩梦
人工智能·程序员
zfh200506286 小时前
【保姆级教程】Open Claw 2.6.4 本地部署步骤+常见问题解答
人工智能·open claw·小龙虾·open claw安装
俊哥V6 小时前
每日 AI 研究简报 · 2026-05-01
人工智能·ai
irpywp6 小时前
苦于AI生成的网页千篇一律且粗糙?design-md-chrome :一款网页样式提取插件 ,将任意网站的视觉规范转化为大模型可读的代码指令!
前端·人工智能·chrome·开源·github
victory04316 小时前
论文规划框架和实验设计2
人工智能
聚铭网络6 小时前
【一周安全资讯0425】网安标委技术文件《人工智能应用伦理安全指引》1.0版公开征求意见;Vercel遭第三方OAuth劫持入侵
人工智能·安全
2401_827499996 小时前
机器学习03-线性回归
人工智能·机器学习·线性回归
skilllite作者6 小时前
Warp 终端效能与交互体验全景展示
人工智能·后端·架构·rust
穷人小水滴6 小时前
(AI) 编写简单 MCP 工具 (mcp-run)
人工智能·ai·node.js·agent·mcp