深度学习基础知识 学习率调度器的用法解析

深度学习基础知识 学习率调度器的用法解析

1、自定义学习率调度器**:**torch.optim.lr_scheduler.LambdaLR

实验代码:

python 复制代码
import torch
import torch.nn as nn

def lr_lambda(x):
    return x*2
    
net=nn.Sequential(nn.Conv2d(3,16,3,1,1))

optimizer=torch.optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

lr_scheduler=torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lr_lambda)

for _ in range(10):
    optimizer.step()
    lr_scheduler.step()
    print(optimizer.param_groups[0]['lr'])

打印结果:

分析数据变化如下图所示:

2、正儿八经的模型搭建流程以及学习率调度器的使用设置


代码:

python 复制代码
import torch
import torch.nn as nn
import numpy as np

def create_lr_scheduler(optimizer,
                        num_step:int,
                        epochs:int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3):
    assert num_step>0 and epochs>0
    if warmup is False:
        warmup_epochs=0
    
    def f(x):
        """
            根据step数,返回一个学习率倍率因子,
            注意在训练开始之前,pytorch会提前调用一次create_lr_scheduler.step()方法

        
        """

        if warmup is True and x <= (warmup_epochs * num_step):
            alpha=float(x) / (warmup_epochs * num_step)
            # warmup过程中,学习率因子(learning rate factor):warmup_factor -----> 1
            return warmup_factor * (1-alpha) + alpha
        else:
            # warmup后,学习率因子(learning rate factor):warmup_factor -----> 0
            return (1-(x - warmup_epochs * num_step) / (epochs-warmup_epochs * num_step)) ** 0.9
        
    return torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=f)


net=nn.Sequential(nn.Conv2d(3,16,1,1))
optimizer=torch.optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

lr_scheduler=create_lr_scheduler(optimizer=optimizer,num_step=5,epochs=20,warmup=True)

image=(np.random.rand(1,3,64,64)).astype(np.float32)
image_tensor=torch.tensor(image.copy(),dtype=torch.float32)
print(image.dtype)

for epoch in range(20):
    net.train()

    predict=net(image_tensor)
    optimizer.zero_grad()
    optimizer.step()
    lr_scheduler.step()
    print(optimizer.param_groups[0]['lr'])   # 打印学习率变化情况
相关推荐
AI营销实验室12 分钟前
原圈科技如何以多智能体赋能AI营销内容生产新范式
人工智能
视***间15 分钟前
智驱万物,视联未来 —— 视程空间以 AI 硬科技赋能全场景智能革新
人工智能·边缘计算·视程空间·ai算力开发板
一个java开发34 分钟前
mcp demo 智能天气服务:经纬度预报与城市警报
人工智能
阿里云大数据AI技术36 分钟前
OmniThoughtV:面向多模态深度思考的高质量数据蒸馏
人工智能
jkyy201440 分钟前
AI健康医疗开放平台:企业健康业务的“新基建”
大数据·人工智能·科技·健康医疗
hy15687861 小时前
coze编程-工作流-起起起---废(一句话生成工作流)
人工智能·coze·自动编程
hssfscv1 小时前
Javaweb 学习笔记——html+css
前端·笔记·学习
brave and determined1 小时前
CANN训练营 学习(day8)昇腾大模型推理调优实战指南
人工智能·算法·机器学习·ai实战·昇腾ai·ai推理·实战记录
Fuly10241 小时前
MCP协议的简介和简单实现
人工智能·langchain
Mr.Jessy1 小时前
JavaScript高级:深浅拷贝、异常处理、防抖及节流
开发语言·前端·javascript·学习