LLM系列:2.pytorch入门:7.深层神经网络第一篇

深层神经网络基础概念

深层神经网络(DNN)通过多个隐藏层实现复杂非线性映射。每一层由线性变换(权重矩阵乘法)和非线性激活函数组成,层间传递梯度通过反向传播算法更新参数。

网络结构设计

输入层维度需与数据特征匹配,隐藏层通常采用逐层降维或等宽设计。输出层维度由任务决定(如分类任务使用类别数)。PyTorch中通过nn.Module定义网络结构:

python 复制代码
import torch.nn as nn

class DNN(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(nn.ReLU())
            prev_dim = dim
        self.net = nn.Sequential(*layers, nn.Linear(prev_dim, output_dim))
    
    def forward(self, x):
        return self.net(x)

激活函数选择

ReLU及其变体(LeakyReLU、PReLU)解决梯度消失问题:

  • ReLU: f(x)=max⁡(0,x)f(x) = \max(0,x)f(x)=max(0,x)
  • LeakyReLU: f(x)=max⁡(0.01x,x)f(x) = \max(0.01x,x)f(x)=max(0.01x,x)
  • Sigmoid/Tanh适用于输出层特定场景

参数初始化方法

Xavier初始化适配全连接层:

python 复制代码
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net.apply(init_weights)

批标准化技术

nn.BatchNorm1d加速深层网络训练:

python 复制代码
self.bn = nn.BatchNorm1d(hidden_dim)
def forward(self, x):
    return self.bn(self.linear(x))

梯度裁剪实现

防止梯度爆炸:

python 复制代码
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

正则化策略

L2正则通过优化器实现:

python 复制代码
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)

Dropout层随机失活神经元:

python 复制代码
self.dropout = nn.Dropout(p=0.5)

损失函数配置

分类任务常用交叉熵:

python 复制代码
criterion = nn.CrossEntropyLoss()

回归任务采用MSE:

python 复制代码
criterion = nn.MSELoss()
相关推荐
冬奇Lab3 分钟前
每日一个开源项目(第132篇):SkillSpector - 安装 AI Agent Skill 之前先扫一遍
人工智能·开源·agent
冬奇Lab5 分钟前
如何让 AI Skill 质量有据可查?Benchmark 驱动的评测体系设计
人工智能·agent
腾科IT教育1 小时前
Spring AI Alibaba 向量(VectorStore)
人工智能·spring·microsoft
IT_陈寒1 小时前
React中useEffect依赖项这个坑我居然踩了三天
前端·人工智能·后端
江畔柳前堤1 小时前
github实战指南02-仓库管理与 Issue
人工智能·深度学习·github·信号处理·caffe·wps·issue
邵宇然2 小时前
内存分配优化:基于 Unsafe 指针与内存对齐的 Rust 区域分配器
人工智能
海兰2 小时前
【游戏】迷雾镇(Mist Town)AI 沙箱游戏详细设计与部署指南(附源代码)
人工智能·游戏
小赖同学啊2 小时前
智能连接器集群化高可用生产方案
linux·运维·人工智能
ZStack开发者社区2 小时前
基于AI Agent的ZCF API文档全链路自动化
运维·人工智能·自动化
沈麽鬼2 小时前
别瞎用AI写代码!90%开发者都搞错了AI编程的底层逻辑
人工智能·ai编程·trae