nn.Identity 单位矩阵,同一矩阵

文章目录

  • [1. 说明](#1. 说明)
  • [2. pytorch 代码](#2. pytorch 代码)

1. 说明

在搭建网络结构中,为了保证搭建的网络具有高度扩展性和后续调试模型框架,在保证整体结构完整情况下,用nn.Identity 进行占位符处理。

2. pytorch 代码

  • pytorch代码
python 复制代码
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, use_dropout=True):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            # 根据配置决定使用 Dropout 还是 Identity
            nn.Identity() if not use_dropout else nn.Dropout(p=0.5),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(32 * 32 * 32, 10)  # 假设输入图像尺寸为 32x32

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

# 实例化模型时选择是否使用 Dropout
model_with_dropout = MyModel(use_dropout=True)
model_without_dropout = MyModel(use_dropout=False)

# 测试模型输出形状
x = torch.randn(4, 3, 32, 32)
print("Output with dropout:", model_with_dropout(x).shape)
print("Output without dropout:", model_without_dropout(x).shape)
  • result
python 复制代码
Output with dropout: torch.Size([4, 10])
Output without dropout: torch.Size([4, 10])
相关推荐
会Tk矩阵群控的小木1 小时前
小红书矩阵系统开发:私域流量转化与管理完整技术实现
矩阵·新媒体运营·开源软件·个人开发·tk
AI_yangxi2 小时前
短视频矩阵系统服务商
大数据·人工智能·矩阵
硅谷秋水2 小时前
SkillOpt:自演化智体技能的执行策略
大数据·人工智能·深度学习·机器学习·语言模型
硅谷秋水3 小时前
Qwen-VLA:跨任务、环境与机器人形态的视觉-语言-动作统一建模
人工智能·深度学习·算法·计算机视觉·语言模型·机器人
装不满的克莱因瓶3 小时前
实现矩阵的转置:从数学原理到 NumPy 实战
线性代数·机器学习·矩阵·数据分析·numpy·特征分解
YOLO数据集集合4 小时前
智慧电网红外热成像数据集|电力设备组件识别目标检测深度学习数据集
人工智能·深度学习·yolo·目标检测·计算机视觉
hengsf1234564 小时前
Transformer初探
人工智能·深度学习·transformer
吃好睡好便好5 小时前
矩阵旋转的计算
学习·线性代数·算法·矩阵
weixin_468466855 小时前
空洞卷积与膨胀卷积新手入门指南
图像处理·人工智能·深度学习·ai·机器视觉·卷积·空洞卷积