nn.Identity 单位矩阵,同一矩阵

文章目录

  • [1. 说明](#1. 说明)
  • [2. pytorch 代码](#2. pytorch 代码)

1. 说明

在搭建网络结构中,为了保证搭建的网络具有高度扩展性和后续调试模型框架,在保证整体结构完整情况下,用nn.Identity 进行占位符处理。

2. pytorch 代码

  • pytorch代码
python 复制代码
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, use_dropout=True):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            # 根据配置决定使用 Dropout 还是 Identity
            nn.Identity() if not use_dropout else nn.Dropout(p=0.5),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(32 * 32 * 32, 10)  # 假设输入图像尺寸为 32x32

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

# 实例化模型时选择是否使用 Dropout
model_with_dropout = MyModel(use_dropout=True)
model_without_dropout = MyModel(use_dropout=False)

# 测试模型输出形状
x = torch.randn(4, 3, 32, 32)
print("Output with dropout:", model_with_dropout(x).shape)
print("Output without dropout:", model_without_dropout(x).shape)
  • result
python 复制代码
Output with dropout: torch.Size([4, 10])
Output without dropout: torch.Size([4, 10])
相关推荐
程序员打怪兽14 小时前
详解Visual Transformer (ViT)网络模型
深度学习
CoovallyAIHub3 天前
仿生学突破:SILD模型如何让无人机在电力线迷宫中发现“隐形威胁”
深度学习·算法·计算机视觉
CoovallyAIHub3 天前
从春晚机器人到零样本革命:YOLO26-Pose姿态估计实战指南
深度学习·算法·计算机视觉
CoovallyAIHub3 天前
Le-DETR:省80%预训练数据,这个实时检测Transformer刷新SOTA|Georgia Tech & 北交大
深度学习·算法·计算机视觉
CoovallyAIHub3 天前
强化学习凭什么比监督学习更聪明?RL的“聪明”并非来自算法,而是因为它学会了“挑食”
深度学习·算法·计算机视觉
CoovallyAIHub3 天前
YOLO-IOD深度解析:打破实时增量目标检测的三重知识冲突
深度学习·算法·计算机视觉
用户1474853079743 天前
AI-动手深度学习环境搭建-d2l
深度学习
OpenBayes贝式计算3 天前
解决视频模型痛点,TurboDiffusion 高效视频扩散生成系统;Google Streetview 涵盖多个国家的街景图像数据集
人工智能·深度学习·机器学习
OpenBayes贝式计算3 天前
OCR教程汇总丨DeepSeek/百度飞桨/华中科大等开源创新技术,实现OCR高精度、本地化部署
人工智能·深度学习·机器学习
在人间耕耘4 天前
HarmonyOS Vision Kit 视觉AI实战:把官方 Demo 改造成一套能长期复用的组件库
人工智能·深度学习·harmonyos