训练过程中可能遇到的问题

训练过程中可能遇到的问题

1. 前向传播正常,反向更新梯度时报错。

可能的原因是,你想给基线网络加含参数模块,但是这个模块你只在forward中穿插进去了,而在构建网络时,你将这个模块随便搞了个位置new了出来。这样在反向传播时,构建的传播图中没有你的新模块的位置。导致反向传播无法计算梯度并更新。

原来的代码是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        return x 

错误的代码是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
        new_block = nn.Block() # 错误,位置不对
    
    def forward(self, x): # forward没问题
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        brunch = new_block(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

正确的代码应该是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        new_block = nn.Block() # 这里才对
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        brunch = new_block(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

上面例子过于简单,很可能这样写并没有问题。但是当你的layer2是一个复合模块,且你的new_block需要从layer内部分支出来时。就会出问题。

那么在这个时候,我们就需要将基线网络的原本结构更改,将原本复合的layer2模块扁平化,然后再插入分支。

如下图所示:

你可能会想到:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        new_block = nn.Block() # 你可能想到的位置1
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
        new_block = nn.Block() # 你可能想到的位置2
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2[0](x)
        brunch = new_block(x)
        x = layer2[1](x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

很遗憾,这样的话,new_block是无法计算梯度的。正确的做法应该是吧图2中的架构扁平化为图1,再加入new_block。

那么有个问题,更改基线模型了之后(比如扁平化),加载预训练权重会无法对应,怎么办?

2. 修改基线模型后,如何最大化利用原基线模型的预训练权重?

如图所示

如果你将图3展平成图4,那么按图4的model去load图3的state_dict是不行的,layer2 layer3加载不了权重。当这种冲突比较少时,可以手动通过pop替换。当这种冲突多的时候。我们要怎么办呢?以下是我的方案:

  1. 将原基线模型重组成新的基线模型,但是保持元架构不变,即最小单位模块的数目及顺序/位置不变,即两个模型实质等同,只是元件的命名不同。
  2. 分别读取原基线模型的预训练权重和新基线模型的参数词典
  3. 逐层按张量形状,将原基线模型的权重更新到新基线模型的参数词典中
  4. 把新参数词典中的权重文件保存下来

当你需要更改模块时,直接在新基线模型的基础上进行更改,然后读取新保存的权重作为预训练权重即可。代码参考下方。

PYTHON 复制代码
def load_state_dict_and_drop(self, path:str = None, strict: bool = False):
    pretrain_state_dict = torch.load(path)
    model_state_dict = self.model.state_dict()
    pretrain_key_list = list(pretrain_state_dict.keys())
    model_key_list = list(model_state_dict.keys())
    key_nums = len(pretrain_key_list)
    for i in range(key_nums):
        mod = model_key_list[i]
        pre = pretrain_key_list[i]
        if model_state_dict[mod].shape != pretrain_state_dict[pre].shape:
            print(mod + "不匹配" + pre)
        else:
            model_state_dict[mod] = pretrain_state_dict[pre]
    # model_state_dict.pop("head.weight")
    # model_state_dict.pop("head.bias")
    missing_keys, unexpected_keys = self.load_state_dict(model_state_dict, True)
    torch.save(self.model.state_dict(), "/sj/hold-on/pretrain_model_ckpt/net_pritrain_my.pth")
    return (missing_keys, unexpected_keys)

原理在于state_dict是一个有序的dict。

相关推荐
xingshanchang2 小时前
PyTorch 不支持旧GPU的异常状态与解决方案:CUDNN_STATUS_NOT_SUPPORTED_ARCH_MISMATCH
人工智能·pytorch·python
reddingtons3 小时前
Adobe Firefly AI驱动设计:实用技巧与创新思维路径
大数据·人工智能·adobe·illustrator·photoshop·premiere·indesign
CertiK3 小时前
IBW 2025: CertiK首席商务官出席,探讨AI与Web3融合带来的安全挑战
人工智能·安全·web3
Deepoch4 小时前
Deepoc 大模型在无人机行业应用效果的方法
人工智能·科技·ai·语言模型·无人机
Deepoch4 小时前
Deepoc 大模型:无人机行业的智能变革引擎
人工智能·科技·算法·ai·动态规划·无人机
kngines4 小时前
【字节跳动】数据挖掘面试题0003:有一个文件,每一行是一个数字,如何用 MapReduce 进行排序和求每个用户每个页面停留时间
人工智能·数据挖掘·mapreduce·面试题
Binary_ey4 小时前
AR衍射光波导设计遇瓶颈,OAS 光学软件来破局
人工智能·软件需求·光学软件·光波导
昵称是6硬币4 小时前
YOLOv11: AN OVERVIEW OF THE KEY ARCHITECTURAL ENHANCEMENTS目标检测论文精读(逐段解析)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
平和男人杨争争5 小时前
机器学习2——贝叶斯理论下
人工智能·机器学习
静心问道5 小时前
XLSR-Wav2Vec2:用于语音识别的无监督跨语言表示学习
人工智能·学习·语音识别