PyTorch 中Tensor常用数据结构(int, list, numpy array等)互相转换和实战示例

在 PyTorch 中,tensor 是一种强大且灵活的数据结构,可以与多种 Python 常用数据结构(如 int, list, numpy array 等)互相转换。下面是详细解释和代码示例:


1. Tensor ↔ int / float

转为 int / float(前提是 tensor 中只有一个元素)

python 复制代码
import torch

t = torch.tensor(3.14)
i = t.item()        # 转为 float
j = int(t.item())   # 强制转为 int

print(i)  # 3.14
print(j)  # 3

.item() 只能用于单元素张量:tensor.numel() == 1,否则会报错。


2. Tensor ↔ list

Tensor 转 list(Python 原生嵌套 list)

python 复制代码
t = torch.tensor([[1, 2], [3, 4]])
lst = t.tolist()
print(lst)  # [[1, 2], [3, 4]]

list 转 Tensor

python 复制代码
lst = [[1, 2], [3, 4]]
t = torch.tensor(lst)
print(t)  # tensor([[1, 2], [3, 4]])

支持嵌套 list(矩阵)、一维 list(向量)。


3. Tensor ↔ numpy.ndarray

PyTorch Tensor 和 NumPy array 之间可以无缝转换,共享内存(改变其中一个会影响另一个)。

Tensor → numpy array

python 复制代码
import numpy as np
t = torch.tensor([[1, 2], [3, 4]])
a = t.numpy()
print(type(a))  # <class 'numpy.ndarray'>

numpy array → Tensor

python 复制代码
a = np.array([[1, 2], [3, 4]])
t = torch.from_numpy(a)
print(type(t))  # <class 'torch.Tensor'>

numpy 数组必须是数值型(不能是对象数组等),否则会报错。


4. Tensor ↔ Python scalar 类型(int, float)

如果你从计算结果中获取单个数值,比如:

python 复制代码
t = torch.tensor([5.5])
val = float(t)   # 也可以使用 float(t.item())
print(val)       # 5.5

# 对于整型:
t2 = torch.tensor([3])
val2 = int(t2)   # 等效于 int(t2.item())
print(val2)      # 3

5. Tensor ↔ bytes(用于序列化,如保存到文件)

Tensor → bytes

python 复制代码
t = torch.tensor([1, 2, 3])
b = t.numpy().tobytes()

bytes → Tensor

python 复制代码
import numpy as np
b = b'\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00'
a = np.frombuffer(b, dtype=np.int32)
t = torch.from_numpy(a)
print(t)  # tensor([1, 2, 3], dtype=torch.int32)

6.实战示例

下面我们从三个实际应用场景来讲解 PyTorch 中 tensor 与其他类型(如 listnumpyint 等)相互转换的用途和技巧:


场景一:数据加载与预处理

读取图像数据(使用 PIL) → 转为 tensor
python 复制代码
from PIL import Image
from torchvision import transforms

img = Image.open('cat.jpg')  # 打开图片为 PIL.Image
to_tensor = transforms.ToTensor()
t = to_tensor(img)  # 转为 [C, H, W] 的 float32 Tensor

此时你获得了一个 Tensor,可以送入模型。但如果你想可视化或分析:

Tensor → numpy → 可视化或保存
python 复制代码
import matplotlib.pyplot as plt

img_np = t.permute(1, 2, 0).numpy()  # [H, W, C]
plt.imshow(img_np)
plt.show()

permute 是因为 ToTensor 会变成 [C,H,W],而 matplotlib 需要 [H,W,C]


场景二:模型推理后的结果处理(转为 Python 值)

假设你有一个分类网络,输出如下:

python 复制代码
output = torch.tensor([[0.1, 0.7, 0.2]])  # 假设输出为 batch_size=1 的 logits
pred_idx = output.argmax(dim=1)  # tensor([1])

你要拿到预测类别的整数值:

python 复制代码
pred_class = pred_idx.item()  # 1
print(type(pred_class))       # <class 'int'>

.item() 在推理阶段非常常用!


场景三:保存 Tensor 到磁盘 / 网络传输

Tensor 保存和加载时经常需要转为 numpy 或 byte 流:

保存为 bytes 再写入文件
python 复制代码
t = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
with open("tensor.bin", "wb") as f:
    f.write(t.numpy().tobytes())
从文件读回 tensor
python 复制代码
with open("tensor.bin", "rb") as f:
    byte_data = f.read()

import numpy as np
arr = np.frombuffer(byte_data, dtype=np.int32)
t2 = torch.from_numpy(arr)
print(t2)  # tensor([1, 2, 3, 4], dtype=torch.int32)

你必须记住原始 dtypeshape 才能正确还原!


场景四:构造 batch 时将 list 转为 Tensor

在训练时经常从数据集中拿到多个样本组成 batch(Python list):

python 复制代码
samples = [[1.0, 2.0], [3.0, 4.0]]
batch_tensor = torch.tensor(samples, dtype=torch.float32)
print(batch_tensor.shape)  # torch.Size([2, 2])

或者更通用的方式(可以处理动态 shape):

python 复制代码
batch_tensor = torch.stack([torch.tensor(s) for s in samples])

补充:在 with torch.no_grad() 中常用转换

推理阶段经常用 Tensor → numpy → list

python 复制代码
with torch.no_grad():
    output = model(input_tensor)
    pred = output.softmax(dim=1)
    top1_class = pred.argmax(dim=1).item()

小结对照表

转换类型 方法 注意事项
Tensor → int/float .item() 只能单元素
Tensor → list .tolist() 支持嵌套
list → Tensor torch.tensor(list) 自动推断类型
Tensor → ndarray .numpy() 共享内存
ndarray → Tensor torch.from_numpy(ndarray) 共享内存
Tensor → bytes tensor.numpy().tobytes() 用于存储
bytes → Tensor np.frombuffer + from_numpy 需知道 dtype
相关推荐
学术小八1 分钟前
2025年人工智能、虚拟现实与交互设计国际学术会议
人工智能·交互·vr
岁忧1 小时前
(LeetCode 面试经典 150 题 ) 11. 盛最多水的容器 (贪心+双指针)
java·c++·算法·leetcode·面试·go
仗剑_走天涯1 小时前
基于pytorch.nn模块实现线性模型
人工智能·pytorch·python·深度学习
chao_7891 小时前
二分查找篇——搜索旋转排序数组【LeetCode】两次二分查找
开发语言·数据结构·python·算法·leetcode
cnbestec2 小时前
协作机器人UR7e与UR12e:轻量化设计与高负载能力助力“小而美”智造升级
人工智能·机器人·协作机器人·ur协作机器人·ur7e·ur12e
zskj_zhyl2 小时前
毫米波雷达守护银发安全:七彩喜跌倒检测仪重构居家养老防线
人工智能·安全·重构
秋说3 小时前
【PTA数据结构 | C语言版】一元多项式求导
c语言·数据结构·算法
gaosushexiangji3 小时前
利用sCMOS科学相机测量激光散射强度
大数据·人工智能·数码相机·计算机视觉
Maybyy3 小时前
力扣61.旋转链表
算法·leetcode·链表
谭林杰4 小时前
B树和B+树
数据结构·b树