【踩坑记录】pytorch 自定义嵌套网络时部分网络输出不变的问题

问题描述

使用如下的自定义的多层嵌套网络进行训练:

python 复制代码
class FC1_bot(nn.Module):
    def __init__(self):
        super(FC1_bot, self).__init__()
        self.embeddings = nn.Sequential(
        	nn.Linear(10, 10)
        )
       
    def forward(self, x):
        emb = self.embeddings(x)
        return emb

    
class FC1_top(nn.Module):
    def __init__(self):
        super(FC1_top, self).__init__()
        self.prediction = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(10, 10)
        )
        
    def forward(self, x):
        logit = self.prediction(x)
        return logit


class FC1(nn.Module):
    def __init__(self, num):
        super(FC1, self).__init__()
        self.num = num

        self.bot = []
        for _ in range(num):
            self.bot.append(FC1_bot())

        self.top = FC1_top()
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = list(x)
        emb = []
        for i in range(self.num):
            emb.append(self.bot[i](x[i]))

        agg_emb = self._aggregate(emb)
        logit = self.top(agg_emb)

        pred = self.softmax(logit)

        return emb, pred
    
    def _aggregate(self, x):
        # Note: x is a list of tensors.
        return torch.cat(x, dim=1)

训练的代码如下:

python 复制代码
def train(self):
	# train entire model
	self.model.train()

	for epoch in range(self.args.epochs):
		...

解决办法

需要把所有用到的模型都变成训练模式,否则只有top模型在被训练。

python 复制代码
def train(self):
	# train entire model
	self.model.train()
	self.model.top.train()
	for i in range(self.args.num):
	    self.model.bot[i].train()

	for epoch in range(self.args.epochs):
		...
相关推荐
墩墩冰10 小时前
计算机图形学 实现直线段的反走样
人工智能·机器学习
Pyeako10 小时前
深度学习--卷积神经网络(下)
人工智能·python·深度学习·卷积神经网络·数据增强·保存最优模型·数据预处理dataset
OPEN-Source10 小时前
大模型实战:搭建一张“看得懂”的大模型应用可观测看板
人工智能·python·langchain·rag·deepseek
廖圣平10 小时前
从零开始,福袋直播间脚本研究【七】《添加分组和比特浏览器》
python
B站_计算机毕业设计之家10 小时前
豆瓣电影数据可视化分析系统 | Python Flask框架 requests Echarts 大数据 人工智能 毕业设计源码(建议收藏)✅
大数据·python·机器学习·数据挖掘·flask·毕业设计·echarts
zzz的学习笔记本10 小时前
AI智能体时代的记忆 笔记(由大模型生成)
人工智能·智能体
AGI-四顾10 小时前
文生图模型选型速览
人工智能·ai
大尚来也10 小时前
一篇搞懂AI通识:用大白话讲清人工智能的核心逻辑
人工智能
Coder_Boy_10 小时前
Deeplearning4j+ Spring Boot 电商用户复购预测案例
java·人工智能·spring boot·后端·spring
风指引着方向10 小时前
动态形状算子支持:CANN ops-nn 的灵活推理方案
人工智能·深度学习·神经网络