训练过程中可能遇到的问题

训练过程中可能遇到的问题

1. 前向传播正常,反向更新梯度时报错。

可能的原因是,你想给基线网络加含参数模块,但是这个模块你只在forward中穿插进去了,而在构建网络时,你将这个模块随便搞了个位置new了出来。这样在反向传播时,构建的传播图中没有你的新模块的位置。导致反向传播无法计算梯度并更新。

原来的代码是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        return x 

错误的代码是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
        new_block = nn.Block() # 错误,位置不对
    
    def forward(self, x): # forward没问题
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        brunch = new_block(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

正确的代码应该是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        new_block = nn.Block() # 这里才对
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        brunch = new_block(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

上面例子过于简单,很可能这样写并没有问题。但是当你的layer2是一个复合模块,且你的new_block需要从layer内部分支出来时。就会出问题。

那么在这个时候,我们就需要将基线网络的原本结构更改,将原本复合的layer2模块扁平化,然后再插入分支。

如下图所示:

你可能会想到:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        new_block = nn.Block() # 你可能想到的位置1
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
        new_block = nn.Block() # 你可能想到的位置2
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2[0](x)
        brunch = new_block(x)
        x = layer2[1](x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

很遗憾,这样的话,new_block是无法计算梯度的。正确的做法应该是吧图2中的架构扁平化为图1,再加入new_block。

那么有个问题,更改基线模型了之后(比如扁平化),加载预训练权重会无法对应,怎么办?

2. 修改基线模型后,如何最大化利用原基线模型的预训练权重?

如图所示

如果你将图3展平成图4,那么按图4的model去load图3的state_dict是不行的,layer2 layer3加载不了权重。当这种冲突比较少时,可以手动通过pop替换。当这种冲突多的时候。我们要怎么办呢?以下是我的方案:

  1. 将原基线模型重组成新的基线模型,但是保持元架构不变,即最小单位模块的数目及顺序/位置不变,即两个模型实质等同,只是元件的命名不同。
  2. 分别读取原基线模型的预训练权重和新基线模型的参数词典
  3. 逐层按张量形状,将原基线模型的权重更新到新基线模型的参数词典中
  4. 把新参数词典中的权重文件保存下来

当你需要更改模块时,直接在新基线模型的基础上进行更改,然后读取新保存的权重作为预训练权重即可。代码参考下方。

PYTHON 复制代码
def load_state_dict_and_drop(self, path:str = None, strict: bool = False):
    pretrain_state_dict = torch.load(path)
    model_state_dict = self.model.state_dict()
    pretrain_key_list = list(pretrain_state_dict.keys())
    model_key_list = list(model_state_dict.keys())
    key_nums = len(pretrain_key_list)
    for i in range(key_nums):
        mod = model_key_list[i]
        pre = pretrain_key_list[i]
        if model_state_dict[mod].shape != pretrain_state_dict[pre].shape:
            print(mod + "不匹配" + pre)
        else:
            model_state_dict[mod] = pretrain_state_dict[pre]
    # model_state_dict.pop("head.weight")
    # model_state_dict.pop("head.bias")
    missing_keys, unexpected_keys = self.load_state_dict(model_state_dict, True)
    torch.save(self.model.state_dict(), "/sj/hold-on/pretrain_model_ckpt/net_pritrain_my.pth")
    return (missing_keys, unexpected_keys)

原理在于state_dict是一个有序的dict。

相关推荐
weixin_40938312几秒前
强化lora训练 这次好点 下次在训练数据增加正常对话
人工智能·深度学习·机器学习·qwen
喜欢吃豆1 分钟前
大语言模型混合专家(MoE)架构深度技术综述
人工智能·语言模型·架构·moe
老蒋新思维2 分钟前
创客匠人:当知识IP遇上系统化AI,变现效率如何实现阶跃式突破?
大数据·网络·人工智能·网络协议·tcp/ip·重构·创客匠人
有一个好名字4 分钟前
Spring AI 工具调用(Tool Calling):解锁智能应用新能力
java·人工智能·spring
Das15 分钟前
【计算机视觉】07_几何变换
人工智能·计算机视觉
却道天凉_好个秋6 分钟前
OpenCV(四十六):OBR特征检测
人工智能·opencv·计算机视觉
JosieBook8 分钟前
【大模型】用 AI Ping 免费体验 GLM-4.7 与 MiniMax M2.1:从配置到实战的完整教程
数据库·人工智能·redis
deephub13 分钟前
Anthropic 开源 Bloom:基于 LLM 的自动化行为评估框架
人工智能·python·自动化·大语言模型·行为评估
十铭忘13 分钟前
动作识别9——TSN训练实验
人工智能·深度学习·机器学习
小真zzz14 分钟前
当前集成Nano Banana Pro模型的AI PPT工具排名与分析
开发语言·人工智能·ai·powerpoint·ppt