通俗易懂理解PyTorch之模型迁移和迁移学习

温故而知新,可以为师矣!

一、参考资料

Pytorch模型迁移和迁移学习,导入部分模型参数

二、PyTorch之模型迁移和迁移学习

1. 测试代码

python 复制代码
import torch
from torchvision import models


"""
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
"""


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 获取设备信息
    print("-----device:{}".format(device))
    print("-----Pytorch version:{}".format(torch.__version__))

    input_tensor = torch.zeros(1, 3, 224, 224)  # 定义输入
    print('input_tensor:', input_tensor.shape)

    pretrained_file = "model/resnet18-f37072fd.pth"  # 预训练模型文件

    model = models.resnet18()  # 实例化网络结构
    pretrained_dict = torch.load(pretrained_file)  # 加载预训练模型的权重参数/字典
    model.load_state_dict(pretrained_dict)  # 加载网络结构的参数/字典,并绑定模型的权重参数/字典
    model.eval()

    out = model(input_tensor)  # 执行推理
    print("out:", out.shape, out[0, 0:10])

2. 修改网络结构的迁移学习

修改网络结构,预训练模型文件不变。

修改resnet18网络结构,将网络层名称 layer4 改为 layer44

源文件路径:/home/yoyo/miniconda3/envs/yolov5-pytorch/lib/python3.9/site-packages/torchvision/models/resnet.py

python 复制代码
class ResNet(nn.Module):

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNet, self).__init__()
        
        # ...
        self.layer44 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        # ...

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        # ...
        x = self.layer44(x)
        # ...

        return x

重新执行测试代码,出现以下错误:

bash 复制代码
Traceback (most recent call last):
  File "/PATH/TO/torchvision_demo.py", line 32, in <module>
    model.load_state_dict(model_dict)  # 加载参数字典
  File "/home/yoyo/miniconda3/envs/yolov5-pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "layer44.0.conv1.weight", "layer44.0.bn1.weight", "layer44.0.bn1.bias", "layer44.0.bn1.running_mean", "layer44.0.bn1.running_var", "layer44.0.conv2.weight", "layer44.0.bn2.weight", "layer44.0.bn2.bias", "layer44.0.bn2.running_mean", "layer44.0.bn2.running_var", "layer44.0.downsample.0.weight", "layer44.0.downsample.1.weight", "layer44.0.downsample.1.bias", "layer44.0.downsample.1.running_mean", "layer44.0.downsample.1.running_var", "layer44.1.conv1.weight", "layer44.1.bn1.weight", "layer44.1.bn1.bias", "layer44.1.bn1.running_mean", "layer44.1.bn1.running_var", "layer44.1.conv2.weight", "layer44.1.bn2.weight", "layer44.1.bn2.bias", "layer44.1.bn2.running_mean", "layer44.1.bn2.running_var". 
	Unexpected key(s) in state_dict: "layer4.0.conv1.weight", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.conv2.weight", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.1.conv1.weight", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.conv2.weight", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias". 

Process finished with exit code 1

解释说明

  • Missing key(s) in state_dict: 这一行表示网络结构的权重参数;
  • Unexpected key(s) in state_dict: 这一行表示预训练模型的权重参数。

我们希望将原来预训练模型的权重参数(resnet18-f37072fd.pth)迁移到新的resnet18网络,当然只能迁移二者相同的权重参数,不同的权重参数还是随机初始化的。

python 复制代码
import torch
from torchvision import models


def transfer_model(pretrained_file, model):
    '''
    只导入pretrained_file部分模型参数
    tensor([-0.7119,  0.0688, -1.7247, -1.7182, -1.2161, -0.7323, -2.1065, -0.5433,-1.5893, -0.5562]
    update:
        D.update([E, ]**F) -> None.  Update D from dict/iterable E and F.
        If E is present and has a .keys() method, then does:  for k in E: D[k] = E[k]
        If E is present and lacks a .keys() method, then does:  for k, v in E: D[k] = v
        In either case, this is followed by: for k in F:  D[k] = F[k]
    :param pretrained_file:
    :param model:
    :return:
    '''
    pretrained_dict = torch.load(pretrained_file)  # 加载预训练模型的权重参数/字典
    model_dict = model.state_dict()  # 加载网络结构的权重参数/字典

    # 在合并前(update),需要去除pretrained_dict一些不需要的参数
    pretrained_dict = transfer_state_dict(pretrained_dict, model_dict)  # 去除pretrained_dict一些不需要的字典
    # pretrained_dict中有,但model_dict没有,则会增加到model_dict
    # pretrained_dict和model_dict都有,则更新为pretrained_dict的键值对
    model_dict.update(pretrained_dict)  # 更新(合并)网络结构的权重参数
    model.load_state_dict(model_dict)  # 加载网络结构的参数/字典,并绑定模型的权重参数/字典
    return model


def transfer_state_dict(pretrained_dict, model_dict):
    '''
    根据 model_dict,去除 pretrained_dict 一些不需要的字典,以便迁移到新的网络
    url: https://blog.csdn.net/qq_34914551/article/details/87871134
    :param pretrained_dict: 预训练模型的权重参数/字典
    :param model_dict: 网络结构的权重参数/字典
    :return:
    '''
    # state_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict.keys():
            # state_dict.setdefault(k, v)
            state_dict[k] = v
        else:
            print("Missing key(s) in state_dict :{}".format(k))
    return state_dict


if __name__ == "__main__":
    input_tensor = torch.zeros(1, 3, 224, 224)  # 定义输入
    print('input_tensor:', input_tensor.shape)

    pretrained_file = "model/resnet18-f37072fd.pth"  # 预训练模型文件

    model1 = models.resnet18()  # 实例化网络结构

    model1 = transfer_model(pretrained_file, model1)  # 模型迁移

    model1.eval()

    out1 = model1(input_tensor)  # 执行推理
    print("out1:", out1.shape, out1[0, 0:10])

输出结果

bash 复制代码
input_tensor: torch.Size([1, 3, 224, 224])
Missing key(s) in state_dict :layer4.0.conv1.weight
Missing key(s) in state_dict :layer4.0.bn1.running_mean
Missing key(s) in state_dict :layer4.0.bn1.running_var
Missing key(s) in state_dict :layer4.0.bn1.weight
Missing key(s) in state_dict :layer4.0.bn1.bias
Missing key(s) in state_dict :layer4.0.conv2.weight
Missing key(s) in state_dict :layer4.0.bn2.running_mean
Missing key(s) in state_dict :layer4.0.bn2.running_var
Missing key(s) in state_dict :layer4.0.bn2.weight
Missing key(s) in state_dict :layer4.0.bn2.bias
Missing key(s) in state_dict :layer4.0.downsample.0.weight
Missing key(s) in state_dict :layer4.0.downsample.1.running_mean
Missing key(s) in state_dict :layer4.0.downsample.1.running_var
Missing key(s) in state_dict :layer4.0.downsample.1.weight
Missing key(s) in state_dict :layer4.0.downsample.1.bias
Missing key(s) in state_dict :layer4.1.conv1.weight
Missing key(s) in state_dict :layer4.1.bn1.running_mean
Missing key(s) in state_dict :layer4.1.bn1.running_var
Missing key(s) in state_dict :layer4.1.bn1.weight
Missing key(s) in state_dict :layer4.1.bn1.bias
Missing key(s) in state_dict :layer4.1.conv2.weight
Missing key(s) in state_dict :layer4.1.bn2.running_mean
Missing key(s) in state_dict :layer4.1.bn2.running_var
Missing key(s) in state_dict :layer4.1.bn2.weight
Missing key(s) in state_dict :layer4.1.bn2.bias
out: torch.Size([1, 1000]) tensor([-0.1838, -0.5729, -0.2731, -0.7303, -0.3759, -0.2288, -0.5441,  0.5422,
         0.2055, -0.1339], grad_fn=<SliceBackward0>)

3. 修改预训练模型文件的迁移学习

修改预训练模型文件,以匹配新的网络结构。

前一章节仅修改网络结构,并未修改预训练模型文件。本章节,将修改预训练模型 model/resnet18-f37072fd.pth 以符合新的网络结构。

总体思路:只需要将预训练模型 resnet18-f37072fd.pth 的权重参数中所有前缀为 layer4 的名称,修改为 layer44 即可

python 复制代码
import torch
from torchvision import models


def string_rename(old_string, new_string, start, end):
    new_string = old_string[:start] + new_string + old_string[end:]
    return new_string


def modify_model(pretrained_file, model, old_prefix, new_prefix):
    '''
    :param pretrained_file:
    :param model:
    :param old_prefix:
    :param new_prefix:
    :return:
    '''
    pretrained_dict = torch.load(pretrained_file)  # 加载预训练模型的权重参数/字典
    model_dict = model.state_dict()  # 加载网络结构的权重参数/字典
    state_dict = modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix)
    model.load_state_dict(state_dict)  # 加载网络结构的参数/字典,并绑定模型的权重参数/字典
    return model


def modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix):
    '''
    修改 model dict
    :param pretrained_dict: 预训练模型的权重参数/字典
    :param model_dict: 网络结构的权重参数/字典
    :param old_prefix: ["layer4"]
    :param new_prefix: ["layer44"]
    :return:
    '''
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict.keys():
            # state_dict.setdefault(k, v)
            state_dict[k] = v
        else:
            for o, n in zip(old_prefix, new_prefix):
                prefix = k[:len(o)]
                if prefix == o:
                    kk = string_rename(old_string=k, new_string=n, start=0, end=len(o))
                    print("rename layer modules:{}-->{}".format(k, kk))
                    state_dict[kk] = v
    return state_dict


if __name__ == "__main__":
    input_tensor = torch.zeros(1, 3, 224, 224)  # 定义输入
    print('input_tensor:', input_tensor.shape)

    pretrained_file = "model/resnet18-f37072fd.pth"  # 预训练模型文件

    new_file = "model/new_model.pth"  # 修改后的预训练模型文件

    model = models.resnet18()  # 实例化网络结构

    new_model = modify_model(pretrained_file, model, old_prefix=["layer4"], new_prefix=["layer44"])
    torch.save(new_model.state_dict(), new_file)

    model2 = models.resnet18()  # 实例化网络结构
    model2.load_state_dict(torch.load(new_file))
    model2.eval()
    out2 = model2(input_tensor)  # 执行推理
    print("out2:", out2.shape, out2[0, 0:10])

输出结果

bash 复制代码
input_tensor: torch.Size([1, 3, 224, 224])
rename layer modules:layer4.0.conv1.weight-->layer44.0.conv1.weight
rename layer modules:layer4.0.bn1.running_mean-->layer44.0.bn1.running_mean
rename layer modules:layer4.0.bn1.running_var-->layer44.0.bn1.running_var
rename layer modules:layer4.0.bn1.weight-->layer44.0.bn1.weight
rename layer modules:layer4.0.bn1.bias-->layer44.0.bn1.bias
rename layer modules:layer4.0.conv2.weight-->layer44.0.conv2.weight
rename layer modules:layer4.0.bn2.running_mean-->layer44.0.bn2.running_mean
rename layer modules:layer4.0.bn2.running_var-->layer44.0.bn2.running_var
rename layer modules:layer4.0.bn2.weight-->layer44.0.bn2.weight
rename layer modules:layer4.0.bn2.bias-->layer44.0.bn2.bias
rename layer modules:layer4.0.downsample.0.weight-->layer44.0.downsample.0.weight
rename layer modules:layer4.0.downsample.1.running_mean-->layer44.0.downsample.1.running_mean
rename layer modules:layer4.0.downsample.1.running_var-->layer44.0.downsample.1.running_var
rename layer modules:layer4.0.downsample.1.weight-->layer44.0.downsample.1.weight
rename layer modules:layer4.0.downsample.1.bias-->layer44.0.downsample.1.bias
rename layer modules:layer4.1.conv1.weight-->layer44.1.conv1.weight
rename layer modules:layer4.1.bn1.running_mean-->layer44.1.bn1.running_mean
rename layer modules:layer4.1.bn1.running_var-->layer44.1.bn1.running_var
rename layer modules:layer4.1.bn1.weight-->layer44.1.bn1.weight
rename layer modules:layer4.1.bn1.bias-->layer44.1.bn1.bias
rename layer modules:layer4.1.conv2.weight-->layer44.1.conv2.weight
rename layer modules:layer4.1.bn2.running_mean-->layer44.1.bn2.running_mean
rename layer modules:layer4.1.bn2.running_var-->layer44.1.bn2.running_var
rename layer modules:layer4.1.bn2.weight-->layer44.1.bn2.weight
rename layer modules:layer4.1.bn2.bias-->layer44.1.bn2.bias
out2: torch.Size([1, 1000]) tensor([-0.0694,  0.6170, -1.9313, -0.9805, -0.8599,  0.7094, -2.0857, -1.5393,
        -2.0534, -1.2726], grad_fn=<SliceBackward0>)

4. 去除网络结构的某些层

python 复制代码
import torch
import torchvision.models as models
from collections import OrderedDict


if __name__ == "__main__":
    resnet18 = models.resnet18(False)
    print("resnet18", resnet18)

    # use named_children()
    resnet18_v1 = OrderedDict(resnet18.named_children())
    # remove avgpool,fc
    resnet18_v1.pop("avgpool")
    resnet18_v1.pop("fc")
    resnet18_v1 = torch.nn.Sequential(resnet18_v1)
    print("resnet18_v1", resnet18_v1)
    # use children
    resnet18_v2 = torch.nn.Sequential(*list(resnet18.children())[:-2])
    print(resnet18_v2, resnet18_v2)
    
相关推荐
四口鲸鱼爱吃盐4 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
leaf_leaves_leaf5 小时前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零15 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志
四口鲸鱼爱吃盐5 小时前
Pytorch | 从零构建MobileNet对CIFAR10进行分类
人工智能·pytorch·分类
苏言の狗5 小时前
Pytorch中关于Tensor的操作
人工智能·pytorch·python·深度学习·机器学习
四口鲸鱼爱吃盐11 小时前
Pytorch | 利用VMI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
四口鲸鱼爱吃盐11 小时前
Pytorch | 利用PI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python
电子海鸥11 小时前
迁移学习--fasttext概述
人工智能·机器学习·迁移学习
love you joyfully1 天前
目标检测与R-CNN——pytorch与paddle实现目标检测与R-CNN
人工智能·pytorch·目标检测·cnn·paddle
这个男人是小帅1 天前
【AutoDL】通过【SSH远程连接】【vscode】
运维·人工智能·pytorch·vscode·深度学习·ssh