基于PyTorch的深度学习5—神经网络工具箱

nn.Module是nn的一个核心数据结构,它可以是神经网络的某个层(Layer),也可以是包含多层的神经网络。在实际使用中,最常见的做法是继承nn.Module,生成自己的网络/层。nn中已实现了绝大多数层,包括全连接层、损失层、激活层、卷积层、循环层等,这些层都是nn.Module的子类,能够自动检测到自己的Parameter,并将其作为学习参数,且针对GPU运行进行了cuDNN优化。

nn中的层,一类是继承了nn.Module,其命名一般为nn.Xxx(第一个是大写)​,如nn.Linear、nn.Conv2d、nn.CrossEntropyLoss等。另一类是nn.functional中的函数,其名称一般为nn.funtional.xxx,如nn.funtional.linear、nn.funtional.conv2d、nn.funtional.cross_entropy等。从功能来说两者相当,基于nn.Moudle能实现的层,使用nn.funtional也可实现,反之亦然,而且性能方面两者也没有太大差异。不过在具体使用时,两者还是有区别,主要区别如下:

1)nn.Xxx继承于nn.Module,nn.Xxx需要先实例化并传入参数,然后以函数调用的方式调用实例化的对象并传入输入数据。它能够很好地与nn.Sequential结合使用,而nn.functional.xxx无法与nn.Sequential结合使用。2)nn.Xxx不需要自己定义和管理weight、bias参数;而nn.functional.xxx需要自己定义weight、bias参数,每次调用的时候都需要手动传入weight、bias等参数,不利于代码复用。3)Dropout操作在训练和测试阶段是有区别的,使用nn.Xxx方式定义Dropout,在调用model.eval()之后,自动实现状态的转换,而使用nn.functional.xxx却无此功能。总的来说,两种功能都是相同的,但PyTorch官方推荐:具有学习参数的(例如,conv2d,linear,batch_norm)采用nn.Xxx方式。没有学习参数的(例如,maxpool、loss func、activation func)等根据个人选择使用nn.functional.xxx或者nn.Xxx方式。

PyTorch常用的优化方法都封装在torch.optim里面,其设计很灵活,可以扩展为自定义的优化方法。所有的优化方法都是继承了基类optim.Optimizer,并实现了自己的优化步骤。最常用的优化算法就是梯度下降法及其各种变种,这类优化算法通过使用参数的梯度值更新参数。随机梯度下降法(SGD)就是最普通的优化器,一般SGD并说没有加速效果,包含动量参数Momentum,它是SGD的改良版。

优化器的一般步骤为

(1)建立优化器实例导入optim模块,实例化SGD优化器,这里使用动量参数momentum(该值一般在(0,1)之间),是SGD的改进版,效果一般比不使用动量规则的要好。

(2)向前传播把输入数据传入神经网络Net实例化对象model中,自动执行forward函数,得到out输出值,然后用out与标记label计算损失值loss。

(3)清空梯度缺省情况梯度是累加的,在梯度反向传播前,先需把梯度清零。

(4)反向传播基于损失值,把梯度进行反向传播。

(5)更新参数基于当前梯度(存储在参数的.grad属性中)更新参数。

修改参数的方式可以通过修改参数optimizer.params_groups或新建optimizer。新建optimizer比较简单,optimizer十分轻量级,所以开销很小。但是新的优化器会初始化动量等状态信息,这对于使用动量的优化器(momentum参数的sgd)可能会造成收敛中的震荡。所以,这里直接采用修改参数optimizer.params_groups。

复制代码
for epoch in range(num_epoches):
#动态修改参数学习率
    if epoch%5==0:
        optimizer.param_groups[0]['lr']*=0.1
        print(optimizer.param_groups[0]['lr'])
    for img, label in train_loader:
######
相关推荐
weixin_445238127 小时前
第R8周:RNN实现阿尔兹海默病诊断(pytorch)
人工智能·pytorch·rnn
KingDol_MIni7 小时前
ResNet残差神经网络的模型结构定义(pytorch实现)
人工智能·pytorch·神经网络
缘友一世9 小时前
深度学习系统学习系列【5】之深度学习基础(激活函数&损失函数&超参数)
人工智能·深度学习·学习
COOCC19 小时前
PyTorch 实战:从 0 开始搭建 Transformer
人工智能·pytorch·python·深度学习·算法·机器学习·transformer
闭月之泪舞10 小时前
神经网络—感知器、多层感知器
人工智能·深度学习·神经网络
985小水博一枚呀10 小时前
【EI会议推荐】2025年6月智启未来:通信导航、 机器学习、半导体与AI、数字创新领域国际研讨会总结!
人工智能·python·深度学习·机器学习
卡尔曼的BD SLAMer12 小时前
问题 | 当前计算机视觉迫切解决的问题
图像处理·人工智能·深度学习·计算机视觉·信息与通信
basketball61612 小时前
使用pytorch保存和加载预训练的模型方法
人工智能·pytorch·python
小彭律师12 小时前
基于深度学习的交通标志识别系统
人工智能·深度学习
灏瀚星空13 小时前
PyTorch 入门与核心概念详解:从基础到实战问题解决
人工智能·pytorch·python·深度学习·算法·机器学习