【求助帖(已解决)】用PyTorch搭建MLP网络时遇到奇怪的问题

(已解决,看最后)

求助:我在测试自己搭建的通用MLP网络时,发现它与等价的参数写死的MLP网络相比效果奇差无比,不知道是哪里出了问题,请大佬们帮忙看下。

我写的通用MLP网络:

python 复制代码
class MLP(nn.Module):
    def __init__(self, feature_num, class_num, *hidden_nums):
        super().__init__()
        self.feature_num = feature_num
        self.class_num = class_num
        self.hidden_nums = hidden_nums

        input_num = feature_num
        for i, hidden_num in enumerate(hidden_nums):
            self.__dict__['fc' + str(i)] = nn.Linear(input_num, hidden_num)
            input_num = hidden_num
        self.output = nn.Linear(input_num, class_num)

    def forward(self, x):
        for i in range(len(self.hidden_nums)):
            x = F.relu(self.__dict__['fc' + str(i)](x))
        x = self.output(x)[..., 0] if self.class_num == 1 else F.sigmoid(self.output(x))
        return x

按理说这样实例化时:

python 复制代码
model = MLP(57, 2, 30, 10)

它应该与下面这个网络等价:

python 复制代码
class MLPclassification(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(57, 30)
        self.fc1 = nn.Linear(30, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, x):
        x = F.relu(self.fc0(x))
        x = F.relu(self.fc1(x))
        x = F.sigmoid(self.output(x))
        return x

但当我用model = MLP(57, 2, 30, 10)训练网络时,在二分类问题中,它把所有数据都预测成了类别0:

而用 model = MLPclassification()训练网络时,预测的效果很好:

我检查了半天,不知道是哪里出了问题,有没有大佬懂的,帮忙看下,十分感谢!


解决了!我检查了nn.Module的__setattr__()方法(向对象的name属性赋值、即定义实例变量时自动调用的方法),发现__setattr__()会将Module类型的变量移到_modules属性下面:

所以批量定义全连接层时不能直接向__dict__属性赋值,这样会绕过__setattr__()方法的类型检查,导致最后优化器无法通过model.parameters()获取并更新隐藏层的权重。所以应该在__dict__['_modules']属性中批量定义全连接层,就能解决这个问题了。更新后的通用MLP网络代码如下:

python 复制代码
class MLP(nn.Module):
    def __init__(self, feature_num, class_num, *hidden_nums):
        super().__init__()
        self.feature_num = feature_num
        self.class_num = class_num
        self.hidden_nums = hidden_nums

        input_num = feature_num
        for i, hidden_num in enumerate(hidden_nums):
            self.__dict__['_modules']['fc' + str(i)] = nn.Linear(input_num, hidden_num)
            input_num = hidden_num
        self.output = nn.Linear(input_num, class_num)

    def forward(self, x):
        for i in range(len(self.hidden_nums)):
            x = F.relu(self.__dict__['_modules']['fc' + str(i)](x))
        x = self.output(x)[..., 0] if self.class_num == 1 else F.softmax(self.output(x), dim=-1)
        return x

预测效果非常好:

感悟:看来没啥事还是不要随便动下划线开头的东西,你不知道会不会牵动到别的地方,出了问题处理起来挺麻烦的。

相关推荐
m5655bj1 天前
使用 Python 高效复制 Excel 行、列、单元格
开发语言·python·excel
龙言龙论1 天前
身份证信息批量处理系统:从入门到实战(附exe工具+核心源码)
数据库·python
m0_626535201 天前
代码分析 长音频分割为短音频
javascript·python·音视频
Wpa.wk1 天前
自动化测试环境配置-java+python
java·开发语言·python·测试工具·自动化
带刺的坐椅1 天前
AI 应用工作流:LangGraph 和 Solon AI Flow,我该选谁?
java·python·ai·solon·flow·langgraph
CoovallyAIHub1 天前
超越YOLOv8/v11!自研RKM-YOLO为输电线路巡检精度、速度双提升
深度学习·算法·计算机视觉
工业互联网专业1 天前
图片推荐系统_django+spider
python·django·毕业设计·源码·课程设计·spider·图片推荐系统
Lwcah1 天前
Python | LGBM+SHAP可解释性分析回归预测及可视化算法
python·算法·回归
@一辈子爱你1 天前
归来九十余日:在时代的夹缝中,与你共筑一道光
python
BagMM1 天前
FC-CLIP 论文阅读 开放词汇的检测与分割的统一
人工智能·深度学习·计算机视觉