【踩坑记录】pytorch 自定义嵌套网络时部分网络输出不变的问题

问题描述

使用如下的自定义的多层嵌套网络进行训练:

python 复制代码
class FC1_bot(nn.Module):
    def __init__(self):
        super(FC1_bot, self).__init__()
        self.embeddings = nn.Sequential(
        	nn.Linear(10, 10)
        )
       
    def forward(self, x):
        emb = self.embeddings(x)
        return emb

    
class FC1_top(nn.Module):
    def __init__(self):
        super(FC1_top, self).__init__()
        self.prediction = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(10, 10)
        )
        
    def forward(self, x):
        logit = self.prediction(x)
        return logit


class FC1(nn.Module):
    def __init__(self, num):
        super(FC1, self).__init__()
        self.num = num

        self.bot = []
        for _ in range(num):
            self.bot.append(FC1_bot())

        self.top = FC1_top()
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = list(x)
        emb = []
        for i in range(self.num):
            emb.append(self.bot[i](x[i]))

        agg_emb = self._aggregate(emb)
        logit = self.top(agg_emb)

        pred = self.softmax(logit)

        return emb, pred
    
    def _aggregate(self, x):
        # Note: x is a list of tensors.
        return torch.cat(x, dim=1)

训练的代码如下:

python 复制代码
def train(self):
	# train entire model
	self.model.train()

	for epoch in range(self.args.epochs):
		...

解决办法

需要把所有用到的模型都变成训练模式,否则只有top模型在被训练。

python 复制代码
def train(self):
	# train entire model
	self.model.train()
	self.model.top.train()
	for i in range(self.args.num):
	    self.model.bot[i].train()

	for epoch in range(self.args.epochs):
		...
相关推荐
shchojj3 小时前
Software Applications - Lifecycle of a generative AI project
人工智能
人工智能技术咨询.3 小时前
知识图谱:AI的超级大脑
人工智能
roman_日积跬步-终至千里3 小时前
【系统架构师】从软件架构师考试内容看 AI 时代的软件工程管理
人工智能·系统架构·软件工程
工业机器人销售服务3 小时前
不锈钢制品美容焊手:法奥机器人施焊成型焊缝色泽均匀,防腐性能与母材保持一致
大数据·人工智能
hnxaoli3 小时前
统信小程序(十三)循环键鼠操作程序
python·小程序
这是谁的博客?3 小时前
AI 领域精选新闻(2026-05-21)
人工智能·gpt·ai·google·大模型·gemini·新闻
逆境不可逃3 小时前
【与我学 ClaudeCode】规划与协调篇 之 Task System :持久化任务图与多 Agent 协作骨架
人工智能·agent
code 小楊3 小时前
2026两大新王对决:Qwen3\.7\-Max vs Gemini 3\.5 Flash 全维度深度测评(能力、对比、选型、优劣)
大数据·人工智能
程序员学习Chat3 小时前
计算机视觉-Backbone超详细整理(上)-卷积时代
人工智能·计算机视觉
华普微HOPERF3 小时前
智能手表集成数字气压传感器,就能实现楼层定位功能?
人工智能·计算机视觉·智能手表