pytorch冻结参数训练的坑

由于项目需要训练一个主干网络接多个分支的模型,所以先训练一个主干网络加第一个分支,再用另外的数据训练第二个分支,训练的过程中需要冻结主干网络部分,后面的分支训练过程也一样需要冻结主干网络部分。

冻结模型的方式

python 复制代码
for name, para in model.named_parameters():
      # 冻结backbone的权重
	if name.split(".")[0] == "backbone":
          para.requires_grad = False        # 或者用para.requires_grad_(False),一个是通过属性直接赋值,一个是通过函数赋值
    else:
          para.requires_grad = True
python 复制代码
# 可以打印需要更新梯度的参数
for name, value in model.named_parameters():
    print(name, "\t更新梯度:",value.requires_grad)

坑1:这样做并不能冻结batchnorm层的参数,所以需要在训练中手动冻结。如:

python 复制代码
def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('SyncBatchNorm') != -1 or classname.find('InstanceNorm2d') != -1 or classname.find('BatchNorm2d') != -1:          #SyncBatchNorm, InstanceNorm2d
        if m.num_features in [32, 64, 96, 128, 256, 384, 768, 192, 1152, 224]:      # 需要冻结的BN层的通道数
            m.eval()

def train():
	for epoch in range(max_epoch):
		model.train()
		if args.freeze:
			model.apply(fix_bn)
			model.backbone[5][0].block[0][1].eval()   # 假如需要冻结的BN层通道数和不需要冻结的BN层通道数一样,则需要单独写
		for batch_idx, (data, target) in enumerate(train_loader):
			...
			

坑2:用了冻结训练(freeze)就不要用EMA方式更新模型了,不然收敛缓慢不说,还会造成前面冻结的参数产生变化,可以从EMA的代码看出端倪:

python 复制代码
class EMA:
    def __init__(self, model, decay=0.9999):
        super().__init__()
        import copy
        self.decay = decay
        self.model = copy.deepcopy(model)

        self.model.eval()

    def update_fn(self, model, fn):
        with torch.no_grad():
            e_std = self.model.state_dict().values()
            #m_std = model.module.state_dict().values()   # multi-gpu
            m_std = model.state_dict().values()          # single-gpu
            for e, m in zip(e_std, m_std):
                e.copy_(fn(e, m))

    def update(self, model):
        self.update_fn(model, fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

可以看出EMA的方式更新模型方式,大部分是结合上一个模型的参数的,即:

model_update = decay*model(t-1) + (1-decay)*model(t)       # model(t-1) 代表上一次迭代模型的参数,model(t)代表当前迭代得到的模型参数

虽然冻结了backbone的参数,阻止了梯度在backbone中反向传播,但参数由于经过如上乘法及加法运算,由于精度原因,还是会发生微小变化,虽然训练次数增加,这个变化会扩大,从而达不到冻结训练的效果。而且从计算公式可以看出来,采用EMA的方式更新模型参数,参数会更新得很慢,会造成网络难以学习的"错觉"。我在这里困住了3天,有怀疑过是否是网络设计问题,是否是多GPU同步的问题,是否是参数设置,如学习率过小,权重衰减过大,或者dropout设置过大等等,最终一步一步排除定位到EMA的问题。

以这次的经验来看,EMA只适合在上一次训练得到模型的基础上,这一次加了额外的数据,需要在上一次的基础上做微调的情况。

相关推荐
Sxiaocai6 分钟前
使用 PyTorch 实现并训练 VGGNet 用于 MNIST 分类
pytorch·深度学习·分类
GL_Rain7 分钟前
【OpenCV】Could NOT find TIFF (missing: TIFF_LIBRARY TIFF_INCLUDE_DIR)
人工智能·opencv·计算机视觉
shansjqun12 分钟前
教学内容全覆盖:航拍杂草检测与分类
人工智能·分类·数据挖掘
狸克先生14 分钟前
如何用AI写小说(二):Gradio 超简单的网页前端交互
前端·人工智能·chatgpt·交互
baiduopenmap29 分钟前
百度世界2024精选公开课:基于地图智能体的导航出行AI应用创新实践
前端·人工智能·百度地图
小任同学Alex32 分钟前
浦语提示词工程实践(LangGPT版,服务器上部署internlm2-chat-1_8b,踩坑很多才完成的详细教程,)
人工智能·自然语言处理·大模型
新加坡内哥谈技术38 分钟前
微软 Ignite 2024 大会
人工智能
nuclear20111 小时前
使用Python 在Excel中创建和取消数据分组 - 详解
python·excel数据分组·创建excel分组·excel分类汇总·excel嵌套分组·excel大纲级别·取消excel分组
江瀚视野1 小时前
Q3净利增长超预期,文心大模型调用量大增,百度未来如何分析?
人工智能
Lucky小小吴1 小时前
有关django、python版本、sqlite3版本冲突问题
python·django·sqlite