基于pytorch的深度学习基础4——损失函数和优化器

四.损失函数和优化器

4.1 均值初始化

为减轻梯度消失和梯度爆炸,选择合适的权重初值。

十种初始化方法

Initialization Methods

  1. Xavie r均匀分布

  2. Xavie r正态分布

  3. Kaiming正态分布

  4. 均匀分布

  5. 正态分布

  6. 常数分布

  7. 正交矩阵初始化

  8. 单位矩阵初始化

  9. 稀疏矩阵初始化

4.2 损失函数

1、nn.CrossEntropyLoss

nn.CrossEntropyLoss(weight=None,

size_average=None,

ignore_index=-100,

reduce=None,

reduction='mean'')

功能: nn.LogSoftmax ()与nn.NLLLoss ()结合,进行

交叉熵计算

主要参数:

• w eigh t:各类别的loss设置权值

ignore _ind e x:忽略某个类别

redu c tion :计算模式,可为none/sum /m e an

none- 逐个元素计算

sum- 所有元素求和,返回标量

2、 nn.NLLLoss

功能:实现负对数似然函数中的负号功能

主要参数:

• weight:各类别的loss设置权值

• ignore_index:忽略某个类别

•reduction:计算模式,可为none/sum /m e an

none-逐个元素计算

nn.NLLLoss(weight=None,

size_average=None,

ignore_index=-100,

reduce=None,

reduction='mean')sum- 所有元素求和,返回标量

m e an-加权平均,返回标量

3、 nn.BCELoss

nn.BCELoss(weight=None,

size_average=None,

reduce=None,

reduction='mean')

功能:二分类交叉熵

注意事项:输入值取值在[0,1]

主要参数:

• weight:各类别的loss设置权值

• ignore_index:忽略某个类别

• reduction:计算模式,可为none/sum /m e an

none-逐个元素计算

4、 nn.BCEWithLogitsLoss

nn.BCEWithLogitsLoss(weight=None,

size_average=None,

reduce=None, reduction='mean',

pos_weight=None)

功能:结合Sigmoid与二分类交叉熵

注意事项:网络最后不加sigmoid函数

主要参数:

• pos _weight :正样本的权值

• weight:各类别的loss设置权值

•ignore_index:忽略某个类别

•reduction :计算模式,可为none/sum /mean

mean-加权平均,返回标量e aum

  1. nn.L1Loss

  2. nn.MSELoss

  3. nn.SmoothL1Loss

  4. nn.PoissonNLLLoss

  5. nn.KLDivLoss

  6. nn.MarginRankingLoss

  7. nn.MultiLabelMarginLoss

  8. nn.SoftMarginLoss

  9. nn.MultiLabelSoftMarginLoss

  10. nn.MultiMarginLoss

  11. nn.TripletMarginLoss

  12. nn.HingeEmbeddingLoss

  13. nn.CosineEmbeddingLoss

  14. nn.CTCLoss -所有元素求和,返回标量

4.3优化器 Optimizer

pytorch的优化器:管理并更新模型中可学习参数的值,使得模型输出更接近真实标签

导数:函数在指定坐标轴上的变化率

方向导数:指定方向上的变化率

梯度:一个向量,方向为方向导数取得最大值的方向

class Optimizer(object):

def init(self, params, defaults):

self.defaults = defaults

self.state = defaultdict(dict)

self.param_groups = []

param_groups = [{'params':

param_groups}]本属性

• defaults:优化器超参数

• state:参数的缓存,如mom en tum的缓存

• params_groups:管理的参数组

• _step_count:记录更新次数,学习率调整中使用

基本方法

• 1.zero_grad():清空所管理参数的梯度

pytorch特性:张量梯度不自动清零

class Optimizer(object):

def zero_grad(self):

for group in self.param_groups:

for p in group['params']:

if p.grad is not None:

p.grad.detach_()

p.grad.zero_()

  1. step():执行一步更新

  2. add_param_group():添加参数组

class Optimizer(object):

def add_param_group(self, param_group):

for group in self.param_groups:

param_set.update(set(group['params']))

self.param_groups.append(param_group)

4.state_dict():获取优化器当前状态信息字典

• 5.load_state_dict() :加载状态信息字典

class Optimizer(object):

def state_dict(self):

return {

'state': packed_state,

'param_groups': param_groups,

}

def load_state_dict(self, state_dict):

学习率

Learning Rate

梯度下降:

𝒘𝒊+𝟏 = 𝒘𝒊 − 𝒈(𝒘𝒊 )

𝒘𝒊+𝟏 = 𝒘𝒊 − LR * 𝒈(𝒘𝒊)

学习率(learning rate)控制更新的步伐

Momentum(动量,冲量):

结合当前梯度与上一次更新信息,用于当前更新

梯度下降:

𝒘𝒊+𝟏 = 𝒘𝒊 − 𝒍𝒓 ∗ 𝒈(𝒘𝒊 )

pytorch中更新公式:

𝒗𝒊 = 𝒎 ∗ 𝒗𝒊−𝟏 + 𝒈(𝒘𝒊 )

𝒘𝒊+𝟏 = 𝒘𝒊 − 𝒍𝒓 ∗ 𝒗𝒊

𝒗𝟏𝟎𝟎 = 𝒎 ∗ 𝒗𝟗𝟗 + 𝒈(𝒘𝟏𝟎𝟎)

= 𝒈(𝒘𝟏𝟎𝟎) + 𝒎 ∗ (𝒎 ∗ 𝒗𝟗𝟖 + 𝒈(𝒘𝟗𝟗))

= 𝒈(𝒘𝟏𝟎𝟎) + 𝒎 ∗ 𝒈(𝒘𝟗𝟗) + 𝒎𝟐 ∗ 𝒗𝟗𝟖

= 𝒈(𝒘𝟏𝟎𝟎) + 𝒎 ∗ 𝒈(𝒘𝟗𝟗) + 𝒎𝟐 ∗ 𝒈(𝒘𝟗𝟖) + 𝒎𝟑 ∗ 𝒗𝟗𝟕

1.optim.SGD

主要参数:

• params:管理的参数组

• lr:初始学习率

• momentum:动量系数,贝塔

• weight_decay:L2正则化系数

• nesterov:是否采用NAG

optim.SGD(params, lr=<object object>,

momentum=0, dampening=0,

weight_decay=0, nesterov=False)

优化器

Optimizer

  1. optim.SGD:随机梯度下降法

  2. optim.Adagrad:自适应学习率梯度下降法

  3. optim.RMSprop: Adagrad的改进

  4. optim.Adadelta: Adagrad的改进

  5. optim.Adam:RMSprop结合Momentum

  6. optim.Adamax:Adam增加学习率上限

  7. optim.SparseAdam:稀疏版的Adam

  8. optim.ASGD:随机平均梯度下降

  9. optim.Rprop:弹性反向传播

  10. optim.LBFGS:BFGS的改进

相关推荐
新智元2 分钟前
刚刚,英伟达祭出下一代 GPU!狂飙百万 token 巨兽,投 1 亿爆赚 50 亿
人工智能·openai
霍格沃兹_测试11 分钟前
从零开始搭建Qwen智能体:新手也能轻松上手指南
人工智能
SmartJavaAI20 分钟前
Java调用Whisper和Vosk语音识别(ASR)模型,实现高效实时语音识别(附源码)
java·人工智能·whisper·语音识别
七元权23 分钟前
论文阅读-SelectiveStereo
论文阅读·深度学习·双目深度估计·selectivestereo
山东小木23 分钟前
JBoltAI需求分析大师:基于SpringBoot的大模型智能需求文档生成解决方案
人工智能·spring boot·后端·需求分析·jboltai·javaai·aigs
君名余曰正则26 分钟前
【竞赛系列】机器学习实操项目08——全球城市计算AI挑战赛(数据可视化分析)
人工智能·机器学习·信息可视化
浪浪山齐天大圣29 分钟前
python数据可视化之Matplotlib(8)-Matplotlib样式系统深度解析:从入门到企业级应用
python·matplotlib·数据可视化
算家计算32 分钟前
一张图+一段音频=电影级视频!阿里Wan2.2-S2V-14B本地部署教程:实现丝滑口型同步
人工智能·开源·aigc
XINVRY-FPGA36 分钟前
XCVP1902-2MSEVSVA6865 AMD 赛灵思 XilinxVersal Premium FPGA
人工智能·嵌入式硬件·神经网络·fpga开发·云计算·腾讯云·fpga
算家计算38 分钟前
多年AI顽疾被攻克!OpenAI前CTO团队破解AI随机性难题,大模型可靠性迎来飞跃
人工智能·llm·资讯