maxpool long数据类型报错

报错

RuntimeError: "max_pool2d" not implemented for 'Long'

源码:

python 复制代码
import torch
from torch import nn
from torch.nn import MaxPool2d

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]])

input = torch.reshape(input, (-1, 1, 5, 5))

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)
    def forward(self,input):
        output = self.maxpool1(input)
        return output

tudui = Tudui()
output = tudui(input)
print(output)

错误是说,max_pool2d不能操作long数据类型。也就是我们需要修改input的数据类型。

对input的tensor进行修改,增加一句

python 复制代码
dtype=torch.float32
python 复制代码
input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]], dtype=torch.float32)

修改后代码:

python 复制代码
import torch
from torch import nn
from torch.nn import MaxPool2d

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]], dtype=torch.float32)

input = torch.reshape(input, (-1, 1, 5, 5))

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)
    def forward(self,input):
        output = self.maxpool1(input)
        return output

tudui = Tudui()
output = tudui(input)
print(output)

运行结果:

运行结果中的数据类型都变成了float。

相关推荐
Francek Chen1 小时前
【自然语言处理】应用04:自然语言推断与数据集
人工智能·pytorch·深度学习·神经网络·自然语言处理
在屏幕前出油9 小时前
二、Python面向对象编程基础——理解self
开发语言·python
Java后端的Ai之路9 小时前
【神经网络基础】-神经网络学习全过程(大白话版)
人工智能·深度学习·神经网络·学习
阿方索9 小时前
python文件与数据格式化
开发语言·python
信创天地11 小时前
信创国产化数据库的厂商有哪些?分别用在哪个领域?
数据库·python·网络安全·系统架构·系统安全·运维开发
不哦罗密经11 小时前
python相关
服务器·前端·python
happybasic11 小时前
python字典中字段重复性的分析~~
开发语言·python
山海青风12 小时前
人工智能基础与应用 - 数据处理、建模与预测流程 6 模型训练
人工智能·python·机器学习
l木本I12 小时前
Reinforcement Learning for VLA(强化学习+VLA)
c++·人工智能·python·机器学习·机器人
颖风船12 小时前
锂电池SOC估计的一种算法(改进无迹卡尔曼滤波)
python·算法·信号处理