PyTorch 实现动态输入

使用 PyTorch 实现动态输入:支持训练和推理输入维度不一致的 CNN 和 LSTM/GRU 模型

在深度学习中,处理不同大小的输入数据是一个常见的挑战。许多实际应用需要模型能够灵活地处理可变长度的输入。本文将介绍如何使用 PyTorch 实现支持动态输入的 CNN 和 LSTM/GRU 模型,并打印每一层的输入和输出。

  • 卷积神经网络(CNN):CNN 通常用于处理图像数据。它通过卷积层提取局部特征,并能够处理不同大小的输入图像。通过使用全局池化层,CNN 可以将不同大小的特征图转换为固定大小的输出。

  • 长短期记忆网络(LSTM)和门控循环单元(GRU):LSTM 和 GRU 是处理序列数据的 RNN 变体。它们能够捕捉时间序列中的长期依赖关系,并支持可变长度的输入序列。

模型搭建

1. CNN 模型

我们将构建一个简单的 CNN 模型,支持动态输入大小,并打印每一层的输入和输出。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicCNN(nn.Module):
    def __init__(self):
        super(DynamicCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # 自适应池化层
        self.fc = nn.Linear(32, 10)  # 输出10个类别

    def forward(self, x):
        print(f'Input to CNN: {x.shape}')
        x = F.relu(self.conv1(x))
        print(f'Output after conv1: {x.shape}')
        x = F.relu(self.conv2(x))
        print(f'Output after conv2: {x.shape}')
        x = self.pool(x)
        print(f'Output after pooling: {x.shape}')
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc(x)
        print(f'Output after fc: {x.shape}')
        return x

# 创建模型
cnn_model = DynamicCNN()

# 测试动态输入
input_tensor_cnn = torch.randn(1, 3, 64, 64)  # 输入形状为 (batch_size, channels, height, width)
output_cnn = cnn_model(input_tensor_cnn)
python 复制代码
Input to CNN: torch.Size([1, 3, 55, 64])
Output after conv1: torch.Size([1, 16, 53, 62])
Output after conv2: torch.Size([1, 32, 51, 60])
Output after pooling: torch.Size([1, 32, 1, 1])
Output after fc: torch.Size([1, 10])
python 复制代码
Input to CNN: torch.Size([1, 3, 64, 64])
Output after conv1: torch.Size([1, 16, 62, 62])
Output after conv2: torch.Size([1, 32, 60, 60])
Output after pooling: torch.Size([1, 32, 1, 1])
Output after fc: torch.Size([1, 10])

2. LSTM/GRU 模型

接下来,我们将构建一个支持动态输入的 LSTM 模型,并打印每一层的输入和输出。

python 复制代码
import torch
import torch.nn as nn


class DynamicLSTM(nn.Module):
    def __init__(self):
        super(DynamicLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
        self.fc = nn.Linear(20, 1)  # 输出一个值

    def forward(self, x):
        print(f'Input to LSTM: {x.shape}')
        x, _ = self.lstm(x)
        print(f'Output after LSTM: {x.shape}')
        x = self.fc(x[:, -1, :])  # 取最后一个时间步的输出
        print(f'Output after fc: {x.shape}')
        return x


# 创建模型
lstm_model = DynamicLSTM()

# 测试动态输入
input_tensor_lstm = torch.randn(5, 15, 10)  # 输入形状为 (batch_size, seq_length, input_size)
output_lstm = lstm_model(input_tensor_lstm)
python 复制代码
Input to LSTM: torch.Size([5, 15, 10])
Output after LSTM: torch.Size([5, 15, 20])
Output after fc: torch.Size([5, 1])
python 复制代码
Input to LSTM: torch.Size([5, 20, 10])
Output after LSTM: torch.Size([5, 20, 20])
Output after fc: torch.Size([5, 1])

代码说明

  1. DynamicCNN :该模型包含两个卷积层和一个全连接层。使用自适应平均池化层将特征图的大小调整为 (1, 1),从而支持不同大小的输入图像。每一层的输入和输出形状在前向传播中被打印出来。

  2. DynamicLSTM:该模型包含一个 LSTM 层和一个全连接层。LSTM 层能够处理可变长度的输入序列,输出的形状在前向传播中被打印出来。

相关推荐
数智顾问16 小时前
(107页PPT)数字化采购发展报告(附下载方式)
大数据·人工智能
数据的世界0116 小时前
重构智慧书-第15条:广纳智士
人工智能
Blossom.11816 小时前
基于图神经网络+大模型的网络安全APT检测系统:从流量日志到攻击链溯源的实战落地
人工智能·分布式·深度学习·安全·web安全·开源软件·embedding
一水鉴天16 小时前
整体设计 定稿 之 5 讨论问题汇总 和新建 表述总表/项目结构表 文档分析,到读表工具核心设计讨论(豆包助手)
数据库·人工智能·重构
roman_日积跬步-终至千里16 小时前
【计算机视觉(12)】神经网络与反向传播基础篇:从线性分类器到多层感知机
人工智能·神经网络·计算机视觉
源于花海16 小时前
迁移学习的基本方法——基于样本、特征、模型、关系的迁移
人工智能·机器学习·迁移学习
Xiaoxiaoxiao020916 小时前
GAEA AudioVisual:用“真实情绪数据”训练 AI 的一次尝试
人工智能
Deepoch16 小时前
具身智能:正打破农业机器人的“自动化孤岛”
人工智能·机器人·自动化·具身模型·deepoc
志凌海纳SmartX16 小时前
AI知识科普丨学习框架和推理引擎有什么区别?
人工智能
电化学仪器白超16 小时前
《可编程固定阻值电子负载的制作与自动化标定技术》
python·单片机·嵌入式硬件·自动化