【踩坑记录】pytorch 自定义嵌套网络时部分网络输出不变的问题

问题描述

使用如下的自定义的多层嵌套网络进行训练:

python 复制代码
class FC1_bot(nn.Module):
    def __init__(self):
        super(FC1_bot, self).__init__()
        self.embeddings = nn.Sequential(
        	nn.Linear(10, 10)
        )
       
    def forward(self, x):
        emb = self.embeddings(x)
        return emb

    
class FC1_top(nn.Module):
    def __init__(self):
        super(FC1_top, self).__init__()
        self.prediction = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(10, 10)
        )
        
    def forward(self, x):
        logit = self.prediction(x)
        return logit


class FC1(nn.Module):
    def __init__(self, num):
        super(FC1, self).__init__()
        self.num = num

        self.bot = []
        for _ in range(num):
            self.bot.append(FC1_bot())

        self.top = FC1_top()
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = list(x)
        emb = []
        for i in range(self.num):
            emb.append(self.bot[i](x[i]))

        agg_emb = self._aggregate(emb)
        logit = self.top(agg_emb)

        pred = self.softmax(logit)

        return emb, pred
    
    def _aggregate(self, x):
        # Note: x is a list of tensors.
        return torch.cat(x, dim=1)

训练的代码如下:

python 复制代码
def train(self):
	# train entire model
	self.model.train()

	for epoch in range(self.args.epochs):
		...

解决办法

需要把所有用到的模型都变成训练模式,否则只有top模型在被训练。

python 复制代码
def train(self):
	# train entire model
	self.model.train()
	self.model.top.train()
	for i in range(self.args.num):
	    self.model.bot[i].train()

	for epoch in range(self.args.epochs):
		...
相关推荐
王者鳜錸15 分钟前
讯飞语音唤醒+语音识别+语音合成+文生图完整集成实战
人工智能·文生图·语音识别·xcode·语音生图
小羊羔heihei20 分钟前
Python列表操作全攻略
经验分享·笔记·python·学习·其他·交友
码农小白AI21 分钟前
AI报告文档审核助力排气烟度精准管控:IACheck守护绿色动力环境与合规发展新底线
大数据·人工智能
2501_9083298523 分钟前
实战:用OpenCV和Python进行人脸识别
jvm·数据库·python
深圳市快瞳科技有限公司26 分钟前
高精度宠物鼻纹识别算法原理解析:从图像采集到特征匹配
人工智能·计算机视觉·智慧城市
DX_水位流量监测27 分钟前
德希科技在线 pH 传感器
人工智能·科技·水质监测·水质传感器·水质厂家·供水水质监测·污水监测
热点速递30 分钟前
苹果首款AI穿戴硬件“Apple Pin”曝光:配iPhone的“AI眼睛”,能否突破独立局限?
人工智能·业界资讯
Java后端的Ai之路1 小时前
Milvus 向量数据库从入门到精通:AI 时代的“记忆中枢“实战指南(建议收藏!)
数据库·人工智能·milvus·向量数据库·rag
xixixi777771 小时前
AI的“血管”:从大模型需求看6G、高速光纤与智算中心网络的技术变革
人工智能·ai·大模型·算力·通信·光纤·政策