nn.Identity 单位矩阵,同一矩阵

文章目录

  • [1. 说明](#1. 说明)
  • [2. pytorch 代码](#2. pytorch 代码)

1. 说明

在搭建网络结构中,为了保证搭建的网络具有高度扩展性和后续调试模型框架,在保证整体结构完整情况下,用nn.Identity 进行占位符处理。

2. pytorch 代码

  • pytorch代码
python 复制代码
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, use_dropout=True):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            # 根据配置决定使用 Dropout 还是 Identity
            nn.Identity() if not use_dropout else nn.Dropout(p=0.5),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(32 * 32 * 32, 10)  # 假设输入图像尺寸为 32x32

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

# 实例化模型时选择是否使用 Dropout
model_with_dropout = MyModel(use_dropout=True)
model_without_dropout = MyModel(use_dropout=False)

# 测试模型输出形状
x = torch.randn(4, 3, 32, 32)
print("Output with dropout:", model_with_dropout(x).shape)
print("Output without dropout:", model_without_dropout(x).shape)
  • result
python 复制代码
Output with dropout: torch.Size([4, 10])
Output without dropout: torch.Size([4, 10])
相关推荐
B站_计算机毕业设计之家17 分钟前
大数据YOLOv8无人机目标检测跟踪识别系统 深度学习 PySide界面设计 大数据 ✅
大数据·python·深度学习·信息可视化·数据挖掘·数据分析·flask
武子康8 小时前
AI研究-119 DeepSeek-OCR PyTorch FlashAttn 2.7.3 推理与部署 模型规模与资源详细分析
人工智能·深度学习·机器学习·ai·ocr·deepseek·deepseek-ocr
忙碌5449 小时前
AI大模型时代下的全栈技术架构:从深度学习到云原生部署实战
人工智能·深度学习·架构
听风吹等浪起11 小时前
基于改进TransUNet的港口船只图像分割系统研究
人工智能·深度学习·cnn·transformer
化作星辰11 小时前
深度学习_原理和进阶_PyTorch入门(2)后续语法3
人工智能·pytorch·深度学习
哥布林学者13 小时前
吴恩达深度学习课程二: 改善深层神经网络 第二周:优化算法(二)指数加权平均和学习率衰减
深度学习·ai
点云SLAM14 小时前
弱纹理图像特征匹配算法推荐汇总
人工智能·深度学习·算法·计算机视觉·机器人·slam·弱纹理图像特征匹配
Sunhen_Qiletian18 小时前
Python 类继承详解:深度学习神经网络架构的构建艺术
python·深度学习·神经网络
LHZSMASH!19 小时前
神经流形:大脑功能几何基础的革命性视角
人工智能·深度学习·神经网络·机器学习
忙碌54419 小时前
智能应用开发指南:深度学习、大数据与微服务的融合之道
大数据·深度学习·微服务