TL;DR:

- `state_dict()`: a dict with two keys, `state` and `param_groups`.
    - The value under `state` is the optimizer state of each weight. A model has many weights, and `model.parameters()` yields them one by one; with Adam, every weight carries a momentum term and a variance term of the same shape as the weight itself, plus the step count that weight has been updated to. `state_dict()['state']` is a dict whose key-value items look like:

      position of the weight in `model.parameters()` : `{'step': tensor, 'exp_avg': tensor, 'exp_avg_sq': tensor}`

      where `exp_avg` is the exponential moving average of gradient values and `exp_avg_sq` is the exponential moving average of squared gradient values.
    - The value under `param_groups` is a list whose elements are each a dict of hyperparameters. Different weights may use different hyperparameters, which is why a list is needed; the `params` key inside each dict records which weights that hyperparameter configuration applies to. Each element of `state_dict()['param_groups']` looks like:

      `{'lr': 0.01, 'weight_decay': 0, ..., 'params': [positions of the weights this config applies to]}`

      (see the sketch after this list for how these positions map back to tensors).
- `state`: a defaultdict carrying roughly the same information as `state_dict()['state']` + `model.parameters()`. Each key-value item looks like:

  `param_tensor : {'step': tensor, 'exp_avg': tensor, 'exp_avg_sq': tensor}`
- `param_groups`: a list carrying roughly the same information as `state_dict()['param_groups']` + `model.parameters()`. Each element looks like:

  `{'params': [param1, param2, ...], 'lr': 0.01, 'weight_decay': 0, ...}`

  Note that where `state_dict()['param_groups']` stores index positions under `'params'`, here the entries are the parameter tensors themselves.
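To make the index convention concrete, here is a minimal standalone sketch (the two layers and the group split are hypothetical, mirroring the example below) that maps the integer positions in `state_dict()['param_groups']` back to the live tensors:

```python
import torch
from torch.optim import Adam

# Two hypothetical layers split across two param groups.
layer1 = torch.nn.Linear(5, 2)              # weight + bias -> indices 0, 1
layer2 = torch.nn.Linear(2, 5, bias=False)  # weight        -> index 2
optimizer = Adam([
    {'params': layer1.parameters(), 'lr': 0.05},
    {'params': layer2.parameters()},
], lr=0.01)

# Flatten the live tensors in the order the optimizer registered them;
# this is the order the indices in state_dict() refer to.
flat_params = [p for group in optimizer.param_groups for p in group['params']]

for group in optimizer.state_dict()['param_groups']:
    for idx in group['params']:             # idx is an int, not a tensor
        print(idx, '->', tuple(flat_params[idx].shape))
# Expected: 0 -> (2, 5), 1 -> (2,), 2 -> (5, 2)
```

The indices are global across all groups, in registration order, which is what lets `state_dict()` describe the grouping without holding references to the tensors themselves.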
Example code:
```python
import torch
from torch.nn import Module
from torch.optim import Adam


class MyModel(Module):
    def __init__(self, in_dim, hidden_dim):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=True)
        self.linear2 = torch.nn.Linear(in_features=hidden_dim, out_features=in_dim, bias=False)

    def forward(self, x):
        y = self.linear(x)
        out = self.linear2(y)
        return out


in_dim = 5
hidden_dim = 2
model = MyModel(in_dim=in_dim, hidden_dim=hidden_dim)
# linear gets its own lr=0.05; linear2 falls back to the global lr=0.01
optimizer = Adam([
    {'params': model.linear.parameters(), 'lr': 0.05},
    {'params': model.linear2.parameters()}
], lr=0.01)

x = torch.randn(in_dim)
out = model(x)
loss = torch.sum(out, dim=-1)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print('#' * 100)
print(optimizer.state_dict())
print('#' * 100)
print(optimizer.state)
print('#' * 100)
print(optimizer.param_groups)
```
Output:
```text
####################################################################################################
# state_dict()
{
'state': {
0: {
'step': tensor(1.),
'exp_avg': tensor([[ 0.0503, 0.0738, -0.0199, 0.0365, -0.0079],[ 0.0139, 0.0204, -0.0055, 0.0101, -0.0022]]),
'exp_avg_sq': tensor([[2.5308e-04, 5.4452e-04, 3.9464e-05, 1.3313e-04, 6.2210e-06],[1.9335e-05, 4.1600e-05, 3.0150e-06, 1.0171e-05, 4.7527e-07]])
},
1: {
'step': tensor(1.),
'exp_avg': tensor([0.0406, 0.0112]),
'exp_avg_sq': tensor([1.6472e-04, 1.2584e-05])
},
2: {
'step': tensor(1.),
'exp_avg': tensor([[-0.0268, 0.0085],[-0.0268, 0.0085],[-0.0268, 0.0085],[-0.0268, 0.0085],[-0.0268, 0.0085]]),
'exp_avg_sq': tensor([[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06]])
}
},
'param_groups': [
{'lr': 0.05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False,
'params': [0, 1]},
{'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False,
'params': [2]}]}
####################################################################################################
# state
defaultdict(<class 'dict'>,
{
Parameter containing: tensor([[-0.1744, -0.0656, 0.3184, -0.2081, 0.2448],[ 0.3069, -0.4000, -0.0727, 0.3283, 0.1722]], requires_grad=True): {
'step': tensor(1.),
'exp_avg': tensor([[ 0.0503, 0.0738, -0.0199, 0.0365, -0.0079],[ 0.0139, 0.0204, -0.0055, 0.0101, -0.0022]]),
'exp_avg_sq': tensor([[2.5308e-04, 5.4452e-04, 3.9464e-05, 1.3313e-04, 6.2210e-06],[1.9335e-05, 4.1600e-05, 3.0150e-06, 1.0171e-05, 4.7527e-07]])
},
Parameter containing: tensor([ 0.1764, -0.1476], requires_grad=True): {
'step': tensor(1.),
'exp_avg': tensor([0.0406, 0.0112]),
'exp_avg_sq': tensor([1.6472e-04, 1.2584e-05])
},
Parameter containing: tensor([[-0.2588, -0.5732],[-0.2472, 0.2319],[ 0.4441, -0.6283],[ 0.5832, 0.3760],[-0.0654, 0.6558]], requires_grad=True): {
'step': tensor(1.),
'exp_avg': tensor([[-0.0268, 0.0085],[-0.0268, 0.0085],[-0.0268, 0.0085],[-0.0268, 0.0085],[-0.0268, 0.0085]]),
'exp_avg_sq': tensor([[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06],[7.1624e-05, 7.2090e-06]])
}
}
)
####################################################################################################
# param_groups
[
{
'params': [
Parameter containing: tensor([[-0.1744, -0.0656, 0.3184, -0.2081, 0.2448],[ 0.3069, -0.4000, -0.0727, 0.3283, 0.1722]], requires_grad=True),
Parameter containing: tensor([ 0.1764, -0.1476], requires_grad=True)
],
'lr': 0.05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False
},
{
'params': [Parameter containing: tensor([[-0.2588, -0.5732],[-0.2472, 0.2319],[ 0.4441, -0.6283],[ 0.5832, 0.3760],[-0.0654, 0.6558]], requires_grad=True)],
'lr': 0.01, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': False
}
]
```
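Because `state_dict()` refers to parameters by position instead of holding the tensors themselves, it is the form you serialize for checkpointing. A minimal sketch, assuming the `model` and `optimizer` from the example above (the file name `optim.pt` is hypothetical):

```python
# Save the index-based snapshot of the optimizer state.
torch.save(optimizer.state_dict(), 'optim.pt')

# Later: rebuild an optimizer with the same group structure, then restore.
optimizer2 = Adam([
    {'params': model.linear.parameters(), 'lr': 0.05},
    {'params': model.linear2.parameters()}
], lr=0.01)
optimizer2.load_state_dict(torch.load('optim.pt'))

# load_state_dict() re-maps the integer positions onto optimizer2's own
# parameters, so the Adam moments land on the right tensors.
print(optimizer2.state_dict()['state'][0]['step'])  # tensor(1.)
```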