训练过程中可能遇到的问题

训练过程中可能遇到的问题

1. 前向传播正常,反向更新梯度时报错。

可能的原因是,你想给基线网络加含参数模块,但是这个模块你只在forward中穿插进去了,而在构建网络时,你将这个模块随便搞了个位置new了出来。这样在反向传播时,构建的传播图中没有你的新模块的位置。导致反向传播无法计算梯度并更新。

原来的代码是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        return x 

错误的代码是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
        new_block = nn.Block() # 错误,位置不对
    
    def forward(self, x): # forward没问题
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        brunch = new_block(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

正确的代码应该是:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        new_block = nn.Block() # 这里才对
        layer3 = nn.Block()
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2(x)
        brunch = new_block(x)
        x = layer3(x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

上面例子过于简单,很可能这样写并没有问题。但是当你的layer2是一个复合模块,且你的new_block需要从layer内部分支出来时。就会出问题。

那么在这个时候,我们就需要将基线网络的原本结构更改,将原本复合的layer2模块扁平化,然后再插入分支。

如下图所示:

你可能会想到:

PYTHON 复制代码
class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        input = nn.Block()
        layer1 = nn.Block()
        layer2 = nn.Block()
        new_block = nn.Block() # 你可能想到的位置1
        layer4 = nn.Block()
        layer5 = nn.Block()
        output = nn.Block()
        new_block = nn.Block() # 你可能想到的位置2
    
    def forward(self, x):
        x = input(x)
        x = layer1(x)
        x = layer2[0](x)
        brunch = new_block(x)
        x = layer2[1](x)
        x = layer4(x)
        x = layer5(x)
        x = output(x + brunch)
        return x 

很遗憾,这样的话,new_block是无法计算梯度的。正确的做法应该是吧图2中的架构扁平化为图1,再加入new_block。

那么有个问题,更改基线模型了之后(比如扁平化),加载预训练权重会无法对应,怎么办?

2. 修改基线模型后,如何最大化利用原基线模型的预训练权重?

如图所示

如果你将图3展平成图4,那么按图4的model去load图3的state_dict是不行的,layer2 layer3加载不了权重。当这种冲突比较少时,可以手动通过pop替换。当这种冲突多的时候。我们要怎么办呢?以下是我的方案:

  1. 将原基线模型重组成新的基线模型,但是保持元架构不变,即最小单位模块的数目及顺序/位置不变,即两个模型实质等同,只是元件的命名不同。
  2. 分别读取原基线模型的预训练权重和新基线模型的参数词典
  3. 逐层按张量形状,将原基线模型的权重更新到新基线模型的参数词典中
  4. 把新参数词典中的权重文件保存下来

当你需要更改模块时,直接在新基线模型的基础上进行更改,然后读取新保存的权重作为预训练权重即可。代码参考下方。

PYTHON 复制代码
def load_state_dict_and_drop(self, path:str = None, strict: bool = False):
    pretrain_state_dict = torch.load(path)
    model_state_dict = self.model.state_dict()
    pretrain_key_list = list(pretrain_state_dict.keys())
    model_key_list = list(model_state_dict.keys())
    key_nums = len(pretrain_key_list)
    for i in range(key_nums):
        mod = model_key_list[i]
        pre = pretrain_key_list[i]
        if model_state_dict[mod].shape != pretrain_state_dict[pre].shape:
            print(mod + "不匹配" + pre)
        else:
            model_state_dict[mod] = pretrain_state_dict[pre]
    # model_state_dict.pop("head.weight")
    # model_state_dict.pop("head.bias")
    missing_keys, unexpected_keys = self.load_state_dict(model_state_dict, True)
    torch.save(self.model.state_dict(), "/sj/hold-on/pretrain_model_ckpt/net_pritrain_my.pth")
    return (missing_keys, unexpected_keys)

原理在于state_dict是一个有序的dict。

相关推荐
熙梦数字化1 分钟前
企业资源计划(ERP)系统是什么?有哪些特点?
大数据·人工智能·erp
GISer_Jing4 分钟前
SSE Conf大会分享——大模型驱动的智能 可视分析与故事叙述
前端·人工智能·信息可视化
Wai-Ngai7 分钟前
自动驾驶控制算法——模型预测控制(MPC)
人工智能·机器学习·自动驾驶
北京耐用通信7 分钟前
突破协议壁垒:耐达讯自动化Ethernet/IP转CC-Link网关在工业互联中的核心应用
人工智能·网络协议·安全·自动化·信息与通信
扫描电镜8 分钟前
扫描电镜选购指南:智能、稳定与自动化的综合考量
人工智能·自动化·扫描电镜·自动扫描电镜
AI人工智能+9 分钟前
炫彩活体检测技术:利用RGB色光序列检测用户面部生物特征反应,能有效识别3D面具、Deepfake等伪造攻击
人工智能·人脸识别·炫彩活体检测
无代码专家11 分钟前
数字化转型下的订单管理全流程优化方案
大数据·运维·人工智能
QianCenRealSim14 分钟前
FSD入华“加速”中国自动驾驶产业的推动与重构
人工智能·重构·自动驾驶
roman_日积跬步-终至千里15 分钟前
【模式识别与机器学习(1+)】基础概念之:机器学习基础
人工智能·机器学习
itwangyang52016 分钟前
AIDD-人工智能药物设计-StoL:像搭乐高一样用扩散模型构建大分子 3D 构象
人工智能·3d