【踩坑记录】pytorch 自定义嵌套网络时部分网络输出不变的问题

问题描述

使用如下的自定义的多层嵌套网络进行训练:

python 复制代码
class FC1_bot(nn.Module):
    def __init__(self):
        super(FC1_bot, self).__init__()
        self.embeddings = nn.Sequential(
        	nn.Linear(10, 10)
        )
       
    def forward(self, x):
        emb = self.embeddings(x)
        return emb

    
class FC1_top(nn.Module):
    def __init__(self):
        super(FC1_top, self).__init__()
        self.prediction = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(10, 10)
        )
        
    def forward(self, x):
        logit = self.prediction(x)
        return logit


class FC1(nn.Module):
    def __init__(self, num):
        super(FC1, self).__init__()
        self.num = num

        self.bot = []
        for _ in range(num):
            self.bot.append(FC1_bot())

        self.top = FC1_top()
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = list(x)
        emb = []
        for i in range(self.num):
            emb.append(self.bot[i](x[i]))

        agg_emb = self._aggregate(emb)
        logit = self.top(agg_emb)

        pred = self.softmax(logit)

        return emb, pred
    
    def _aggregate(self, x):
        # Note: x is a list of tensors.
        return torch.cat(x, dim=1)

训练的代码如下:

python 复制代码
def train(self):
	# train entire model
	self.model.train()

	for epoch in range(self.args.epochs):
		...

解决办法

需要把所有用到的模型都变成训练模式,否则只有top模型在被训练。

python 复制代码
def train(self):
	# train entire model
	self.model.train()
	self.model.top.train()
	for i in range(self.args.num):
	    self.model.bot[i].train()

	for epoch in range(self.args.epochs):
		...
相关推荐
天天要nx5 分钟前
D105【python 接口自动化学习】- pytest进阶参数化用法
python·pytest
是十一月末15 分钟前
Opencv实现图片和视频的加噪、平滑处理
人工智能·python·opencv·计算机视觉·音视频
周盛欢27 分钟前
云服务器yum无法解析mirrorlist.centos.org
开发语言·python
三月七(爱看动漫的程序员)41 分钟前
HiQA: A Hierarchical Contextual Augmentation RAG for Multi-Documents QA---附录
人工智能·单片机·嵌入式硬件·物联网·机器学习·语言模型·自然语言处理
Schwertlilien1 小时前
图像处理-Ch1-数字图像基础
图像处理·人工智能·算法
程序员一诺1 小时前
【深度学习】嘿马深度学习笔记第10篇:卷积神经网络,学习目标【附代码文档】
人工智能·python·深度学习·算法
是我知白哒1 小时前
pdf转换文本:基于python的tesseract
python·pdf·ocr
MUTA️1 小时前
RT-DETR学习笔记(2)
人工智能·笔记·深度学习·学习·机器学习·计算机视觉
代码的乐趣1 小时前
支持selenium的chrome driver更新到131.0.6778.204
chrome·python·selenium
开发者每周简报2 小时前
求职市场变化
人工智能·面试·职场和发展