深度学习基础知识 学习率调度器的用法解析

深度学习基础知识 学习率调度器的用法解析

1、自定义学习率调度器**:**torch.optim.lr_scheduler.LambdaLR

实验代码:

python 复制代码
import torch
import torch.nn as nn

def lr_lambda(x):
    return x*2
    
net=nn.Sequential(nn.Conv2d(3,16,3,1,1))

optimizer=torch.optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

lr_scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lr_lambda)

for _ in range(10):
    optimizer.step()
    lr_scheduler.step()
    print(optimizer.param_groups[0]['lr'])

打印结果:

分析数据变化如下图所示:

2、正儿八经的模型搭建流程以及学习率调度器的使用设置


代码:

python 复制代码
import torch
import torch.nn as nn
import numpy as np

def create_lr_scheduler(optimizer,
                        num_step:int,
                        epochs:int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3):
    assert num_step>0 and epochs>0
    if warmup is False:
        warmup_epochs=0
    
    def f(x):
        """
            根据step数,返回一个学习率倍率因子,
            注意在训练开始之前,pytorch会提前调用一次create_lr_scheduler.step()方法

        
        """

        if warmup is True and x <= (warmup_epochs * num_step):
            alpha=float(x) / (warmup_epochs * num_step)
            # warmup过程中,学习率因子(learning rate factor):warmup_factor -----> 1
            return warmup_factor * (1-alpha) + alpha
        else:
            # warmup后,学习率因子(learning rate factor):warmup_factor -----> 0
            return (1-(x - warmup_epochs * num_step) / (epochs-warmup_epochs * num_step)) ** 0.9
        
    return torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=f)


net=nn.Sequential(nn.Conv2d(3,16,1,1))
optimizer=torch.optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

lr_scheduler=create_lr_scheduler(optimizer=optimizer,num_step=5,epochs=20,warmup=True)

image=(np.random.rand(1,3,64,64)).astype(np.float32)
image_tensor=torch.tensor(image.copy(),dtype=torch.float32)
print(image.dtype)

for epoch in range(20):
    net.train()

    predict=net(image_tensor)
    optimizer.zero_grad()
    optimizer.step()
    lr_scheduler.step()
    print(optimizer.param_groups[0]['lr'])   # 打印学习率变化情况
相关推荐
爱喝水的鱼丶11 小时前
SAP-ABAP:SAP 简单报表输出开发系列(共6篇) 第四篇:SAP 报表异常处理机制:数据校验与消息提示规范落地
开发语言·数据库·学习·算法·sap·abap
机器之心11 小时前
当Token飙到天文数字,高通用「计算连续体」重搭智能体新基建
人工智能·openai
weixin_4684668511 小时前
液态神经网络新手入门与实战指南
人工智能·深度学习·神经网络·ai·机器视觉·液态神经网络
机器之心11 小时前
一夜之间,ChatGPT与Codex合并了
人工智能·openai
机器之心11 小时前
老黄的Cosmos 3刚发一天,就被一家中国公司反超了
人工智能·openai
标书畅畅行11 小时前
钛投标标书查重系统技术架构与功能实现解析
大数据·人工智能
Stick_ZYZ11 小时前
从“能调用工具”到“能稳定执行任务”:Agent 工程化的下一步
java·人工智能·后端·spring·ai
宸一11 小时前
Day 4:用后端思维拆解Agent核心架构——三元组、工具调用、错误处理
人工智能
KaMeidebaby11 小时前
卡梅德生物技术快报|蛋白翻译后修饰:YAP/TAZ 分子调控机制与靶向干预技术
前端·人工智能·物联网·百度·新浪微博
阿里云大数据AI技术12 小时前
DataWorks Data Agent:从增强到自主,数据智能体的范式跃迁
人工智能·agent