【数据类型】Python 与 PyTorch 常见数据类型对应

张量类型的获取、转化与判别
python3
>>> import torch
>>> a = torch.randn(2,3)
>>> a
tensor([[-1.7818, -0.2472, -2.0684],
[ 0.0117, 1.4698, -0.9359]])
用
a.type()获取数据类型,用.type(目标类型)和.目标类型()转化类型,用isinstance(a, 目标类型)进行类型合法化检测
-
类型获取
python3>>> a.type() ## 获取数据类型 'torch.FloatTensor' -
类型转化
python3>>> a.type(torch.DoubleTensor) ## 类型转换方法一 .type(目标类型) >>> a.double() ## 类型转换方法二 .目标类型() -
类型判别
python3>>> isinstance(a, torch.FloatTensor) ## 类型合法化检测 True