不同层设置不同学习率

使用预训练模型时,可能需要将

(1)预训练好的 backbone 的 参数学习率设置为较小值,

(2)而backbone 之外的部分,需要使用较大的学习率。

python 复制代码
from collections import OrderedDict
import torch.nn as nn
import torch.optim as optim

net = nn.Sequential(OrderedDict([
    ("linear1", nn.Linear(10, 20)),
    ("linear2", nn.Linear(20, 30)),
    ("linear3", nn.Linear(30, 40))]))


linear3_params = list(map(id, net.linear3.parameters()))
base_params = filter(lambda p: id(p) not in linear3_params, net.parameters())

optimizer = optim.SGD([
    {'params': base_params},
    {'params': net.linear3.parameters(), 'lr': 0.0005}],
    lr=0.001, momentum=0.9)


print(optimizer)
print(optimizer.param_groups[0]['lr'])
print(optimizer.param_groups[1]['lr'])
相关推荐
盼小辉丶几秒前
TensorFlow深度学习实战(21)——Transformer架构详解与实现
深度学习·tensorflow·transformer
SunsPlanter3 分钟前
Word-- 制作论文三线表
学习
武昌库里写JAVA25 分钟前
iview组件库:当后台返回到的数据与使用官网组件指定的字段不匹配时,进行修改某个属性名再将response数据渲染到页面上的处理
java·开发语言·spring boot·学习·课程设计
东京老树根1 小时前
SAP学习笔记 - 开发29 - 前端Fiori开发 Custom Controls(自定义控件)
笔记·学习
꧁坚持很酷꧂1 小时前
FreeRTOS学习01_移植FreeRTOS到STM32(图文详解)
stm32·嵌入式硬件·学习
IOT.FIVE.NO.13 小时前
Conda安装pytorch和cuda出现问题的解决记录
人工智能·pytorch·python
~Yogi4 小时前
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
学习·spring·缓存
Moonnnn.6 小时前
【单片机期末】单片机系统设计
笔记·单片机·嵌入式硬件·学习
行云流水剑8 小时前
【学习记录】使用 Kali Linux 与 Hashcat 进行 WiFi 安全分析:合法的安全测试指南
linux·学习·安全
yvestine8 小时前
自然语言处理——Transformer
人工智能·深度学习·自然语言处理·transformer