训练过程中可能遇到的问题

训练过程中可能遇到的问题

1. 前向传播正常,反向更新梯度时报错。

可能的原因是,你想给基线网络加含参数模块,但是这个模块你只在forward中穿插进去了,而在构建网络时,你将这个模块随便搞了个位置new了出来。这样在反向传播时,构建的传播图中没有你的新模块的位置。导致反向传播无法计算梯度并更新。

原来的代码是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        return x 

错误的代码是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
        new_block = nn.Block() # 错误,位置不对
    
    def forward(self, x): # forward没问题
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        brunch = new_block(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

正确的代码应该是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        new_block = nn.Block() # 这里才对
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        brunch = new_block(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

上面例子过于简单,很可能这样写并没有问题。但是当你的layer2是一个复合模块,且你的new_block需要从layer内部分支出来时。就会出问题。

那么在这个时候,我们就需要将基线网络的原本结构更改,将原本复合的layer2模块扁平化,然后再插入分支。

如下图所示:

你可能会想到:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        new_block = nn.Block() # 你可能想到的位置1
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
        new_block = nn.Block() # 你可能想到的位置2
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2[0](x)
        brunch = new_block(x)
        x = layer2[1](x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

很遗憾,这样的话,new_block是无法计算梯度的。正确的做法应该是吧图2中的架构扁平化为图1,再加入new_block。

那么有个问题,更改基线模型了之后(比如扁平化),加载预训练权重会无法对应,怎么办?

2. 修改基线模型后,如何最大化利用原基线模型的预训练权重?

如图所示

如果你将图3展平成图4,那么按图4的model去load图3的state_dict是不行的,layer2 layer3加载不了权重。当这种冲突比较少时,可以手动通过pop替换。当这种冲突多的时候。我们要怎么办呢?以下是我的方案:

  1. 将原基线模型重组成新的基线模型,但是保持元架构不变,即最小单位模块的数目及顺序/位置不变,即两个模型实质等同,只是元件的命名不同。
  2. 分别读取原基线模型的预训练权重和新基线模型的参数词典
  3. 逐层按张量形状,将原基线模型的权重更新到新基线模型的参数词典中
  4. 把新参数词典中的权重文件保存下来

当你需要更改模块时,直接在新基线模型的基础上进行更改,然后读取新保存的权重作为预训练权重即可。代码参考下方。

PYTHON 复制代码
def load_state_dict_and_drop(self, path:str = None, strict: bool = False):
    pretrain_state_dict = torch.load(path)
    model_state_dict = self.model.state_dict()
    pretrain_key_list = list(pretrain_state_dict.keys())
    model_key_list = list(model_state_dict.keys())
    key_nums = len(pretrain_key_list)
    for i in range(key_nums):
        mod = model_key_list[i]
        pre = pretrain_key_list[i]
        if model_state_dict[mod].shape != pretrain_state_dict[pre].shape:
            print(mod + "不匹配" + pre)
        else:
            model_state_dict[mod] = pretrain_state_dict[pre]
    # model_state_dict.pop("head.weight")
    # model_state_dict.pop("head.bias")
    missing_keys, unexpected_keys = self.load_state_dict(model_state_dict, True)
    torch.save(self.model.state_dict(), "/sj/hold-on/pretrain_model_ckpt/net_pritrain_my.pth")
    return (missing_keys, unexpected_keys)

原理在于state_dict是一个有序的dict。

相关推荐
这儿有一堆花1 天前
Pixel 与 iPhone 安全性对比:硬件芯片、系统更新和实际防护谁更可靠
人工智能·chatgpt
AC赳赳老秦1 天前
测试工程师:OpenClaw自动化测试脚本生成,批量执行测试用例
大数据·linux·人工智能·python·django·测试用例·openclaw
Rubin智造社1 天前
04月18日AI每日参考:Claude Design上线冲击设计圈,OpenAI高管接连出走
人工智能·anthropic·claude design·openai高管·metr·ai拟人化监管
人工智能AI技术1 天前
面试官内部面经,仅限应届生看
人工智能
rainbow7242441 天前
AI学习路线分享:通用型认证与算法认证学习体验对比
人工智能·学习·算法
IT_陈寒1 天前
Java集合的这个坑,我调试了整整3小时才爬出来
前端·人工智能·后端
Simon_lca1 天前
验厂不翻车!Acushnet 11 项核心政策 + 自查要点,一文搞定
大数据·人工智能·经验分享·算法·制造
2501_948114241 天前
2026 深度评测:Qwen 3.6-Plus 全模态逻辑链融合架构解析与高可用接入实践
人工智能·gpt·ai·架构·claude
水如烟1 天前
孤能子视角:AI分形定律,结构依赖度 = AI能效比,以及科研“结构偏见“端倪
人工智能
一江寒逸1 天前
人工智能的“记忆灵魂”:深度拆解大模型时代的上下文技术体系、实战与未来
人工智能