【踩坑记录】pytorch 自定义嵌套网络时部分网络输出不变的问题

问题描述

使用如下的自定义的多层嵌套网络进行训练:

python 复制代码
class FC1_bot(nn.Module):
    def __init__(self):
        super(FC1_bot, self).__init__()
        self.embeddings = nn.Sequential(
        	nn.Linear(10, 10)
        )
       
    def forward(self, x):
        emb = self.embeddings(x)
        return emb

    
class FC1_top(nn.Module):
    def __init__(self):
        super(FC1_top, self).__init__()
        self.prediction = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(10, 10)
        )
        
    def forward(self, x):
        logit = self.prediction(x)
        return logit


class FC1(nn.Module):
    def __init__(self, num):
        super(FC1, self).__init__()
        self.num = num

        self.bot = []
        for _ in range(num):
            self.bot.append(FC1_bot())

        self.top = FC1_top()
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = list(x)
        emb = []
        for i in range(self.num):
            emb.append(self.bot[i](x[i]))

        agg_emb = self._aggregate(emb)
        logit = self.top(agg_emb)

        pred = self.softmax(logit)

        return emb, pred
    
    def _aggregate(self, x):
        # Note: x is a list of tensors.
        return torch.cat(x, dim=1)

训练的代码如下:

python 复制代码
def train(self):
	# train entire model
	self.model.train()

	for epoch in range(self.args.epochs):
		...

解决办法

需要把所有用到的模型都变成训练模式,否则只有top模型在被训练。

python 复制代码
def train(self):
	# train entire model
	self.model.train()
	self.model.top.train()
	for i in range(self.args.num):
	    self.model.bot[i].train()

	for epoch in range(self.args.epochs):
		...
相关推荐
用户5191495848455 分钟前
在AI技术唾手可得的时代,挖掘JavaScript学习资源的新需求成为关键
人工智能·aigc
北邮刘老师13 分钟前
【未来】智能体互联时代的商业模式变化和挑战:从HOM到AOM
人工智能·大模型·智能体·智能体互联网
东方芷兰14 分钟前
LLM 笔记 —— 03 大语言模型安全性评定
人工智能·笔记·python·语言模型·自然语言处理·nlp·gpt-3
小树苗19316 分钟前
Berachain稳定币使用指南:HONEY与跨链稳定币的协同之道
大数据·人工智能·区块链
MediaTea18 分钟前
Python 库手册:keyword 关键字查询
开发语言·python
java1234_小锋22 分钟前
Scikit-learn Python机器学习 - 模型保存及加载
python·机器学习·scikit-learn
睿思达DBA_WGX24 分钟前
使用 python-docx 库操作 word 文档(1):文件操作
开发语言·python·word
攻城狮7号29 分钟前
快手推出KAT系列编码大模型,甚至还有开源版本?
人工智能·ai编程·kat-coder·快手kat·快手开源模型
说私域31 分钟前
互联网新热土视角下开源AI大模型与S2B2C商城小程序的县域市场渗透策略研究
人工智能·小程序·开源
IT_陈寒31 分钟前
Python 3.12新特性实战:5个让你的代码提速30%的性能优化技巧
前端·人工智能·后端