深度学习基础知识 学习率调度器的用法解析

深度学习基础知识 学习率调度器的用法解析

1、自定义学习率调度器**:**torch.optim.lr_scheduler.LambdaLR

实验代码:

python 复制代码
import torch
import torch.nn as nn

def lr_lambda(x):
    return x*2
    
net=nn.Sequential(nn.Conv2d(3,16,3,1,1))

optimizer=torch.optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

lr_scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lr_lambda)

for _ in range(10):
    optimizer.step()
    lr_scheduler.step()
    print(optimizer.param_groups[0]['lr'])

打印结果:

分析数据变化如下图所示:

2、正儿八经的模型搭建流程以及学习率调度器的使用设置


代码:

python 复制代码
import torch
import torch.nn as nn
import numpy as np

def create_lr_scheduler(optimizer,
                        num_step:int,
                        epochs:int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3):
    assert num_step>0 and epochs>0
    if warmup is False:
        warmup_epochs=0
    
    def f(x):
        """
            根据step数,返回一个学习率倍率因子,
            注意在训练开始之前,pytorch会提前调用一次create_lr_scheduler.step()方法

        
        """

        if warmup is True and x <= (warmup_epochs * num_step):
            alpha=float(x) / (warmup_epochs * num_step)
            # warmup过程中,学习率因子(learning rate factor):warmup_factor -----> 1
            return warmup_factor * (1-alpha) + alpha
        else:
            # warmup后,学习率因子(learning rate factor):warmup_factor -----> 0
            return (1-(x - warmup_epochs * num_step) / (epochs-warmup_epochs * num_step)) ** 0.9
        
    return torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=f)


net=nn.Sequential(nn.Conv2d(3,16,1,1))
optimizer=torch.optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

lr_scheduler=create_lr_scheduler(optimizer=optimizer,num_step=5,epochs=20,warmup=True)

image=(np.random.rand(1,3,64,64)).astype(np.float32)
image_tensor=torch.tensor(image.copy(),dtype=torch.float32)
print(image.dtype)

for epoch in range(20):
    net.train()

    predict=net(image_tensor)
    optimizer.zero_grad()
    optimizer.step()
    lr_scheduler.step()
    print(optimizer.param_groups[0]['lr'])   # 打印学习率变化情况
相关推荐
knight_2024几秒前
嵌入式学习日志————对射式红外传感器计次
stm32·单片机·嵌入式硬件·学习
山烛7 分钟前
KNN 算法中的各种距离:从原理到应用
人工智能·python·算法·机器学习·knn·k近邻算法·距离公式
盲盒Q17 分钟前
《频率之光:归途之光》
人工智能·硬件架构·量子计算
墨染点香25 分钟前
第七章 Pytorch构建模型详解【构建CIFAR10模型结构】
人工智能·pytorch·python
go546315846526 分钟前
基于分组规则的Excel数据分组优化系统设计与实现
人工智能·学习·生成对抗网络·数学建模·语音识别
茫茫人海一粒沙32 分钟前
vLLM 的“投机取巧”:Speculative Decoding 如何加速大语言模型推理
人工智能·语言模型·自然语言处理
诗酒当趁年华34 分钟前
【NLP实践】二、自训练数据实现中文文本分类并提供RestfulAPI服务
人工智能·自然语言处理·分类
●VON43 分钟前
重生之我在暑假学习微服务第二天《MybatisPlus-下篇》
java·学习·微服务·架构·mybatis-plus
Yu_Lijing1 小时前
MySQL进阶学习与初阶复习第四天
数据库·学习·mysql
静心问道1 小时前
Idefics3:构建和更好地理解视觉-语言模型:洞察与未来方向
人工智能·多模态·ai技术应用