对pytorch optimizer中state_dict、state、param_groups的简要理解

先说结论:

  • state_dict():一个dict,里面有两个key(stateparam_groups),

    • state这个key对应的value是各个权重对应的优化器状态。具体来说,一个model有很多权重,model.parameters()会打印出该模型的各层的权重,比如使用Adam,每层权重都有一个momentum和variance,形状与权重相同,还有该层当前更新到的步数。state_dict()['state']是一个dict,每个key-value item结构如下:

      该权重在model.parameters()中的位置 : {
      	'step': tensor, 
      	'exp_avg': tensor, # exp_avg: exponential moving average of gradient values
      	'exp_avg_sq: tensor # exp_avg_sq: exponential moving average of squared gradient values
      
    • param_groups这个key对应的value是一个list,其中每个元素都是超参数组成的一个dict,因为不同的权重可以使用不同的超参数,所以需要使用list来表示,而且dict中params表示该超参数配置作用于哪些权重。state_dict()['param_groups']是一个list,每个元素结构如下

      {'lr': 0.01, 'weight_decay': 0,  ...  , 'params', [该超参数配置作用于的权重的位置]}
      
  • state:是一个defaultdict,包含的信息类似于state_dict()['state']+model.parameters(),具体来说,每个key-value item结构如下:

    param_tensor :{
    	'step': tensor, 
    	'exp_avg': tensor, 
    	'exp_avg_sq': tensor,	
    }
    
  • param_groups:是一个list,包含的信息类似于state_dict()['param_groups']+model.parameters(),具体来说,每个元素结构如下:

    {
    	'params': [param1, param2, ...]
    	'lr': 0.01, 
    	'weight_decay': 0, 
    	...
    	# 注意相较于state_dict()['param_groups'],原来'params'这个key对应的是param的索引位置,现在直接就是tensor了
    }
    

示例代码:

python 复制代码
import torch
from torch.nn import Module
from torch.optim import Adam


class MyModel(Module):
    def __init__(self, in_dim, hidden_dim):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=True)
        self.linear2 = torch.nn.Linear(in_features=hidden_dim, out_features=in_dim, bias=False)
    def forward(self, x):
        y = self.linear(x)
        out = self.linear2(y)
        return out


in_dim = 5
hidden_dim = 2
model = MyModel(in_dim=in_dim, hidden_dim=hidden_dim)

optimier = Adam([
    {'params': model.linear.parameters(), 'lr': 0.05},
    {'params': model.linear2.parameters()}
], lr=0.01)


x = torch.randn((in_dim))
out = model(x)
loss = torch.sum(out, dim=-1)
optimier.zero_grad()
loss.backward()
optimier.step()

print('#' * 100)
print(optimier.state_dict())

print('#' * 100)
print(optimier.state)

print('#' * 100)
print(optimier.param_groups)

输出:

json 复制代码
####################################################################################################
# state_dict()
{
    'state': {
        0: {
            'step': tensor(1.), 
            'exp_avg': tensor([[ 0.0503,  0.0738, -0.0199,  0.0365, -0.0079],[ 0.0139,  0.0204, -0.0055,  0.0101, -0.0022]]), 
            'exp_avg_sq': tensor([[2.5308e-04, 5.4452e-04, 3.9464e-05, 1.3313e-04, 6.2210e-06],[1.9335e-05, 4.1600e-05, 3.0150e-06, 1.0171e-05, 4.7527e-07]])
        }, 
        1: {
            'step': tensor(1.), 
            'exp_avg': tensor([0.0406, 0.0112]), 
            'exp_avg_sq': tensor([1.6472e-04, 1.2584e-05])
        }, 
        2: {
            'step': tensor(1.), 
            'exp_avg': tensor([[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085]]), 
            'exp_avg_sq': tensor([[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06]])
        }
    }, 
    'param_groups': [
        {'lr': 0.05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False, 
         'params': [0, 1]}, 
        {'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False, 
         'params': [2]}]}

####################################################################################################
# state
defaultdict(<class 'dict'>, 
    {
        Parameter containing: tensor([[-0.1744, -0.0656,  0.3184, -0.2081,  0.2448],[ 0.3069, -0.4000, -0.0727,  0.3283,  0.1722]], requires_grad=True): {
            'step': tensor(1.), 
            'exp_avg': tensor([[ 0.0503,  0.0738, -0.0199,  0.0365, -0.0079],[ 0.0139,  0.0204, -0.0055,  0.0101, -0.0022]]), 
            'exp_avg_sq': tensor([[2.5308e-04, 5.4452e-04, 3.9464e-05, 1.3313e-04, 6.2210e-06],[1.9335e-05, 4.1600e-05, 3.0150e-06, 1.0171e-05, 4.7527e-07]])
        }, 
        Parameter containing: tensor([ 0.1764, -0.1476], requires_grad=True): {
            'step': tensor(1.), 
            'exp_avg': tensor([0.0406, 0.0112]), 
            'exp_avg_sq': tensor([1.6472e-04, 1.2584e-05])
        }, 
        Parameter containing: tensor([[-0.2588, -0.5732],[-0.2472,  0.2319],[ 0.4441, -0.6283],[ 0.5832,  0.3760],[-0.0654,  0.6558]], requires_grad=True): {
            'step': tensor(1.), 
            'exp_avg': tensor([[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085],[-0.0268,  0.0085]]), 
            'exp_avg_sq': tensor([[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06]])
        }
    }
)
####################################################################################################
# param_groups
[
    {
        'params': [
            Parameter containing: tensor([[-0.1744, -0.0656,  0.3184, -0.2081,  0.2448],[ 0.3069, -0.4000, -0.0727,  0.3283,  0.1722]], requires_grad=True), 
            Parameter containing: tensor([ 0.1764, -0.1476], requires_grad=True)
        ], 
        'lr': 0.05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False
        }, 
    {
        'params': [Parameter containing: tensor([[-0.2588, -0.5732],[-0.2472,  0.2319],[ 0.4441, -0.6283],[ 0.5832,  0.3760],[-0.0654,  0.6558]], requires_grad=True)], 
        'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False
    }
]
相关推荐
佚明zj38 分钟前
全卷积和全连接
人工智能·深度学习
qzhqbb3 小时前
基于统计方法的语言模型
人工智能·语言模型·easyui
冷眼看人间恩怨4 小时前
【话题讨论】AI大模型重塑软件开发:定义、应用、优势与挑战
人工智能·ai编程·软件开发
2401_883041084 小时前
新锐品牌电商代运营公司都有哪些?
大数据·人工智能
AI极客菌5 小时前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭5 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^5 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt
Power20246666 小时前
NLP论文速读|LongReward:基于AI反馈来提升长上下文大语言模型
人工智能·深度学习·机器学习·自然语言处理·nlp
数据猎手小k6 小时前
AIDOVECL数据集:包含超过15000张AI生成的车辆图像数据集,目的解决旨在解决眼水平分类和定位问题。
人工智能·分类·数据挖掘
好奇龙猫6 小时前
【学习AI-相关路程-mnist手写数字分类-win-硬件:windows-自我学习AI-实验步骤-全连接神经网络(BPnetwork)-操作流程(3) 】
人工智能·算法