温故而知新,可以为师矣!
一、参考资料
二、PyTorch之模型迁移和迁移学习
1. 测试代码
python
import torch
from torchvision import models
"""
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
"""
if __name__ == "__main__":
    # Pick the GPU when available. The model and the input are moved to this
    # device below, so the device printed here is the one actually used.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("-----device:{}".format(device))
    print("-----Pytorch version:{}".format(torch.__version__))
    input_tensor = torch.zeros(1, 3, 224, 224).to(device)  # dummy single-image input
    print('input_tensor:', input_tensor.shape)
    pretrained_file = "model/resnet18-f37072fd.pth"  # pretrained weight file
    model = models.resnet18()  # build the network structure (random weights)
    # map_location lets a checkpoint saved on any device load on this one.
    pretrained_dict = torch.load(pretrained_file, map_location=device)
    # Strict load: every key in the checkpoint must match a key in the model.
    model.load_state_dict(pretrained_dict)
    model.to(device)
    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        out = model(input_tensor)
    print("out:", out.shape, out[0, 0:10])
2. 修改网络结构的迁移学习
修改网络结构,预训练模型文件不变。
修改resnet18网络结构,将网络层名称 `layer4` 改为 `layer44`。
源文件路径:/home/yoyo/miniconda3/envs/yolov5-pytorch/lib/python3.9/site-packages/torchvision/models/resnet.py
python
# Excerpt of torchvision/models/resnet.py; "# ..." marks code omitted from this
# listing, so the excerpt is not runnable on its own. The change under
# discussion is that the attribute ``layer4`` is renamed to ``layer44``.
class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNet, self).__init__()
        # ...
        # Renamed from ``layer4``. state_dict() keys are built from attribute
        # names, so this stage's keys now start with "layer44." and no longer
        # match the "layer4." keys stored in the pretrained checkpoint.
        self.layer44 = self._make_layer(block, 512, layers[3], stride=2,
                                        dilate=replace_stride_with_dilation[2])
        # ...

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        # ...
        # The forward pass must use the renamed attribute as well.
        x = self.layer44(x)
        # ...
        return x
重新执行测试代码,出现以下错误:
bash
Traceback (most recent call last):
File "/PATH/TO/torchvision_demo.py", line 32, in <module>
model.load_state_dict(model_dict) # 加载参数字典
File "/home/yoyo/miniconda3/envs/yolov5-pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "layer44.0.conv1.weight", "layer44.0.bn1.weight", "layer44.0.bn1.bias", "layer44.0.bn1.running_mean", "layer44.0.bn1.running_var", "layer44.0.conv2.weight", "layer44.0.bn2.weight", "layer44.0.bn2.bias", "layer44.0.bn2.running_mean", "layer44.0.bn2.running_var", "layer44.0.downsample.0.weight", "layer44.0.downsample.1.weight", "layer44.0.downsample.1.bias", "layer44.0.downsample.1.running_mean", "layer44.0.downsample.1.running_var", "layer44.1.conv1.weight", "layer44.1.bn1.weight", "layer44.1.bn1.bias", "layer44.1.bn1.running_mean", "layer44.1.bn1.running_var", "layer44.1.conv2.weight", "layer44.1.bn2.weight", "layer44.1.bn2.bias", "layer44.1.bn2.running_mean", "layer44.1.bn2.running_var".
Unexpected key(s) in state_dict: "layer4.0.conv1.weight", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.conv2.weight", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.1.conv1.weight", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.conv2.weight", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias".
Process finished with exit code 1
解释说明
`Missing key(s) in state_dict`:这一行列出的是网络结构(新模型)需要、但所加载的权重字典中不存在的参数名;`Unexpected key(s) in state_dict`:这一行列出的是所加载的权重字典(预训练模型)中存在、但网络结构中没有的参数名。
我们希望将原来预训练模型(`resnet18-f37072fd.pth`)的权重参数迁移到新的resnet18网络。当然,只能迁移二者名称(及形状)相同的权重参数,其余的权重参数仍保持随机初始化。
python
import torch
from torchvision import models
def transfer_model(pretrained_file, model):
    """Load into ``model`` the weights of ``pretrained_file`` that fit it.

    Checkpoint entries are filtered by ``transfer_state_dict`` before merging.
    Parameters that receive no pretrained value keep the model's current
    (e.g. randomly initialized) weights.

    :param pretrained_file: path to a checkpoint file holding a state_dict.
    :param model: the network to initialize; it is modified in place.
    :return: the same ``model`` instance with the transferred weights loaded.
    """
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # machines; load_state_dict then copies each tensor onto the device of the
    # corresponding model parameter.
    pretrained_dict = torch.load(pretrained_file, map_location="cpu")
    model_dict = model.state_dict()  # the network's own parameters/buffers
    # Drop pretrained entries the new network cannot accept before merging.
    pretrained_dict = transfer_state_dict(pretrained_dict, model_dict)
    # dict.update overwrites the model's entries with the matching pretrained
    # ones. Keys only the model has keep their current values, so the merged
    # dict is complete and the (strict) load below succeeds.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model
def transfer_state_dict(pretrained_dict, model_dict):
    """Keep only the pretrained entries that can be loaded into ``model_dict``.

    An entry is kept when its key exists in ``model_dict`` and, for values
    that have a ``shape``, that shape equals the shape of the model's entry.
    Same-named entries with a different shape (e.g. a classifier head resized
    for a different number of classes) are skipped. Passing them on would make
    ``load_state_dict`` fail with a size-mismatch error.

    url: https://blog.csdn.net/qq_34914551/article/details/87871134
    :param pretrained_dict: state_dict of the pretrained checkpoint.
    :param model_dict: state_dict of the target network.
    :return: a new dict holding the transferable entries of ``pretrained_dict``.
    """
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k not in model_dict:
            # The key exists only in the checkpoint; the network has no slot for it.
            print("Missing key(s) in state_dict :{}".format(k))
            continue
        old_shape = getattr(v, "shape", None)
        new_shape = getattr(model_dict[k], "shape", None)
        if old_shape is not None and new_shape is not None and old_shape != new_shape:
            print("Shape mismatch, skipped :{} {} -> {}".format(
                k, tuple(old_shape), tuple(new_shape)))
            continue
        state_dict[k] = v
    return state_dict
if __name__ == "__main__":
    # Dummy single-image input for one forward pass.
    input_tensor = torch.zeros(1, 3, 224, 224)
    print('input_tensor:', input_tensor.shape)

    # Build a fresh resnet18 and copy in every pretrained weight it accepts.
    pretrained_file = "model/resnet18-f37072fd.pth"
    model1 = transfer_model(pretrained_file, models.resnet18())

    model1.eval()
    out1 = model1(input_tensor)
    print("out1:", out1.shape, out1[0, 0:10])
输出结果
bash
input_tensor: torch.Size([1, 3, 224, 224])
Missing key(s) in state_dict :layer4.0.conv1.weight
Missing key(s) in state_dict :layer4.0.bn1.running_mean
Missing key(s) in state_dict :layer4.0.bn1.running_var
Missing key(s) in state_dict :layer4.0.bn1.weight
Missing key(s) in state_dict :layer4.0.bn1.bias
Missing key(s) in state_dict :layer4.0.conv2.weight
Missing key(s) in state_dict :layer4.0.bn2.running_mean
Missing key(s) in state_dict :layer4.0.bn2.running_var
Missing key(s) in state_dict :layer4.0.bn2.weight
Missing key(s) in state_dict :layer4.0.bn2.bias
Missing key(s) in state_dict :layer4.0.downsample.0.weight
Missing key(s) in state_dict :layer4.0.downsample.1.running_mean
Missing key(s) in state_dict :layer4.0.downsample.1.running_var
Missing key(s) in state_dict :layer4.0.downsample.1.weight
Missing key(s) in state_dict :layer4.0.downsample.1.bias
Missing key(s) in state_dict :layer4.1.conv1.weight
Missing key(s) in state_dict :layer4.1.bn1.running_mean
Missing key(s) in state_dict :layer4.1.bn1.running_var
Missing key(s) in state_dict :layer4.1.bn1.weight
Missing key(s) in state_dict :layer4.1.bn1.bias
Missing key(s) in state_dict :layer4.1.conv2.weight
Missing key(s) in state_dict :layer4.1.bn2.running_mean
Missing key(s) in state_dict :layer4.1.bn2.running_var
Missing key(s) in state_dict :layer4.1.bn2.weight
Missing key(s) in state_dict :layer4.1.bn2.bias
out: torch.Size([1, 1000]) tensor([-0.1838, -0.5729, -0.2731, -0.7303, -0.3759, -0.2288, -0.5441, 0.5422,
0.2055, -0.1339], grad_fn=<SliceBackward0>)
3. 修改预训练模型文件的迁移学习
修改预训练模型文件,以匹配新的网络结构。
前一章节仅修改网络结构,并未修改预训练模型文件。本章节将修改预训练模型 `model/resnet18-f37072fd.pth`,使其符合新的网络结构。
总体思路:只需要将预训练模型 `resnet18-f37072fd.pth` 的权重参数中所有前缀为 `layer4` 的名称修改为 `layer44` 即可。
python
import torch
from torchvision import models
def string_rename(old_string, new_string, start, end):
    """Return ``old_string`` with the slice ``[start:end]`` replaced by ``new_string``."""
    prefix, suffix = old_string[:start], old_string[end:]
    return prefix + new_string + suffix
def modify_model(pretrained_file, model, old_prefix, new_prefix):
    """Load a pretrained checkpoint into ``model`` after renaming key prefixes.

    :param pretrained_file: path to the checkpoint file (a state_dict).
    :param model: network whose state_dict keys use the new prefixes.
    :param old_prefix: list of prefixes used by the checkpoint, e.g. ["layer4"].
    :param new_prefix: list of replacements in the same order, e.g. ["layer44"].
    :return: the same ``model`` instance with the renamed weights loaded.
    """
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # machines; load_state_dict copies tensors onto the parameters' devices.
    pretrained_dict = torch.load(pretrained_file, map_location="cpu")
    model_dict = model.state_dict()  # the network's own parameters/buffers
    state_dict = modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix)
    # Strict (default) load: raises if the renaming left any key unmatched,
    # which verifies that the rewritten dict fits the new network exactly.
    model.load_state_dict(state_dict)
    return model
def modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix):
    """Rename checkpoint keys so they match the new network's key names.

    Keys already present in ``model_dict`` are copied unchanged. Any other key
    whose leading module path is ``old_prefix[i]`` is renamed to start with
    ``new_prefix[i]``; the first matching prefix wins. Matching respects the
    dotted module boundary: "layer4" matches "layer4.0.conv1.weight" but not
    "layer40.weight". A prefix that already ends with "." (e.g. "module.") is
    matched as given. Keys matching neither rule are dropped.

    :param pretrained_dict: state_dict of the pretrained checkpoint.
    :param model_dict: state_dict of the target network.
    :param old_prefix: prefixes to replace, e.g. ["layer4"].
    :param new_prefix: replacements, same length and order, e.g. ["layer44"].
    :return: a new dict with the kept and renamed entries.
    :raises ValueError: if ``old_prefix`` and ``new_prefix`` differ in length.
    """
    if len(old_prefix) != len(new_prefix):
        # zip() would silently ignore the extra prefixes.
        raise ValueError(
            "old_prefix and new_prefix must have the same length: "
            "{} != {}".format(len(old_prefix), len(new_prefix)))
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict:
            state_dict[k] = v
            continue
        for o, n in zip(old_prefix, new_prefix):
            boundary = o if o.endswith(".") else o + "."
            if k == o or k.startswith(boundary):
                kk = n + k[len(o):]  # swap the prefix, keep the rest of the key
                print("rename layer modules:{}-->{}".format(k, kk))
                state_dict[kk] = v
                break  # one rename per key; avoid duplicate entries
    return state_dict
if __name__ == "__main__":
    # Dummy single-image input for one forward pass.
    input_tensor = torch.zeros(1, 3, 224, 224)
    print('input_tensor:', input_tensor.shape)

    pretrained_file = "model/resnet18-f37072fd.pth"  # original checkpoint
    new_file = "model/new_model.pth"  # checkpoint rewritten with "layer44" keys

    # Load the original weights into the renamed network, then save them so
    # the new file can later be loaded directly, without renaming any keys.
    renamed = modify_model(pretrained_file, models.resnet18(),
                           old_prefix=["layer4"], new_prefix=["layer44"])
    torch.save(renamed.state_dict(), new_file)

    # Reload the saved file into a fresh network to confirm it fits.
    model2 = models.resnet18()
    model2.load_state_dict(torch.load(new_file))
    model2.eval()
    out2 = model2(input_tensor)
    print("out2:", out2.shape, out2[0, 0:10])
输出结果
bash
input_tensor: torch.Size([1, 3, 224, 224])
rename layer modules:layer4.0.conv1.weight-->layer44.0.conv1.weight
rename layer modules:layer4.0.bn1.running_mean-->layer44.0.bn1.running_mean
rename layer modules:layer4.0.bn1.running_var-->layer44.0.bn1.running_var
rename layer modules:layer4.0.bn1.weight-->layer44.0.bn1.weight
rename layer modules:layer4.0.bn1.bias-->layer44.0.bn1.bias
rename layer modules:layer4.0.conv2.weight-->layer44.0.conv2.weight
rename layer modules:layer4.0.bn2.running_mean-->layer44.0.bn2.running_mean
rename layer modules:layer4.0.bn2.running_var-->layer44.0.bn2.running_var
rename layer modules:layer4.0.bn2.weight-->layer44.0.bn2.weight
rename layer modules:layer4.0.bn2.bias-->layer44.0.bn2.bias
rename layer modules:layer4.0.downsample.0.weight-->layer44.0.downsample.0.weight
rename layer modules:layer4.0.downsample.1.running_mean-->layer44.0.downsample.1.running_mean
rename layer modules:layer4.0.downsample.1.running_var-->layer44.0.downsample.1.running_var
rename layer modules:layer4.0.downsample.1.weight-->layer44.0.downsample.1.weight
rename layer modules:layer4.0.downsample.1.bias-->layer44.0.downsample.1.bias
rename layer modules:layer4.1.conv1.weight-->layer44.1.conv1.weight
rename layer modules:layer4.1.bn1.running_mean-->layer44.1.bn1.running_mean
rename layer modules:layer4.1.bn1.running_var-->layer44.1.bn1.running_var
rename layer modules:layer4.1.bn1.weight-->layer44.1.bn1.weight
rename layer modules:layer4.1.bn1.bias-->layer44.1.bn1.bias
rename layer modules:layer4.1.conv2.weight-->layer44.1.conv2.weight
rename layer modules:layer4.1.bn2.running_mean-->layer44.1.bn2.running_mean
rename layer modules:layer4.1.bn2.running_var-->layer44.1.bn2.running_var
rename layer modules:layer4.1.bn2.weight-->layer44.1.bn2.weight
rename layer modules:layer4.1.bn2.bias-->layer44.1.bn2.bias
out2: torch.Size([1, 1000]) tensor([-0.0694, 0.6170, -1.9313, -0.9805, -0.8599, 0.7094, -2.0857, -1.5393,
-2.0534, -1.2726], grad_fn=<SliceBackward0>)
4. 去除网络结构的某些层
python
import torch
import torchvision.models as models
from collections import OrderedDict
if __name__ == "__main__":
    # Build resnet18 with random weights. Passing ``pretrained`` positionally
    # (``resnet18(False)``) is deprecated in newer torchvision; the no-argument
    # call means "no pretrained weights" on both old and new releases.
    resnet18 = models.resnet18()
    print("resnet18", resnet18)

    # Option 1: named_children() keeps the submodule names, so layers can be
    # removed by name and the Sequential keeps those names as attributes.
    resnet18_v1 = OrderedDict(resnet18.named_children())
    # Remove the classification head: avgpool and fc.
    resnet18_v1.pop("avgpool")
    resnet18_v1.pop("fc")
    resnet18_v1 = torch.nn.Sequential(resnet18_v1)
    print("resnet18_v1", resnet18_v1)

    # Option 2: children() yields the same submodules without names; dropping
    # the last two (avgpool, fc) keeps the same layers, named by position.
    resnet18_v2 = torch.nn.Sequential(*list(resnet18.children())[:-2])
    print("resnet18_v2", resnet18_v2)