PyTorch 中Tensor常用数据结构(int, list, numpy array等)互相转换和实战示例

在 PyTorch 中,tensor 是一种强大且灵活的数据结构,可以与多种 Python 常用数据结构(如 int, list, numpy array 等)互相转换。下面是详细解释和代码示例:


1. Tensor ↔ int / float

转为 int / float(前提是 tensor 中只有一个元素)

python 复制代码
import torch

t = torch.tensor(3.14)
i = t.item()        # 转为 float
j = int(t.item())   # 强制转为 int

print(i)  # 3.14
print(j)  # 3

.item() 只能用于单元素张量:tensor.numel() == 1,否则会报错。


2. Tensor ↔ list

Tensor 转 list(Python 原生嵌套 list)

python 复制代码
t = torch.tensor([[1, 2], [3, 4]])
lst = t.tolist()
print(lst)  # [[1, 2], [3, 4]]

list 转 Tensor

python 复制代码
lst = [[1, 2], [3, 4]]
t = torch.tensor(lst)
print(t)  # tensor([[1, 2], [3, 4]])

支持嵌套 list(矩阵)、一维 list(向量)。


3. Tensor ↔ numpy.ndarray

PyTorch Tensor 和 NumPy array 之间可以无缝转换,共享内存(改变其中一个会影响另一个)。

Tensor → numpy array

python 复制代码
import numpy as np
t = torch.tensor([[1, 2], [3, 4]])
a = t.numpy()
print(type(a))  # <class 'numpy.ndarray'>

numpy array → Tensor

python 复制代码
a = np.array([[1, 2], [3, 4]])
t = torch.from_numpy(a)
print(type(t))  # <class 'torch.Tensor'>

numpy 数组必须是数值型(不能是对象数组等),否则会报错。


4. Tensor ↔ Python scalar 类型(int, float)

如果你从计算结果中获取单个数值,比如:

python 复制代码
t = torch.tensor([5.5])
val = float(t)   # 也可以使用 float(t.item())
print(val)       # 5.5

# 对于整型:
t2 = torch.tensor([3])
val2 = int(t2)   # 等效于 int(t2.item())
print(val2)      # 3

5. Tensor ↔ bytes(用于序列化,如保存到文件)

Tensor → bytes

python 复制代码
t = torch.tensor([1, 2, 3])
b = t.numpy().tobytes()

bytes → Tensor

python 复制代码
import numpy as np
b = b'\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00'
a = np.frombuffer(b, dtype=np.int32)
t = torch.from_numpy(a)
print(t)  # tensor([1, 2, 3], dtype=torch.int32)

6.实战示例

下面我们从三个实际应用场景来讲解 PyTorch 中 tensor 与其他类型(如 listnumpyint 等)相互转换的用途和技巧:


场景一:数据加载与预处理

读取图像数据(使用 PIL) → 转为 tensor
python 复制代码
from PIL import Image
from torchvision import transforms

img = Image.open('cat.jpg')  # 打开图片为 PIL.Image
to_tensor = transforms.ToTensor()
t = to_tensor(img)  # 转为 [C, H, W] 的 float32 Tensor

此时你获得了一个 Tensor,可以送入模型。但如果你想可视化或分析:

Tensor → numpy → 可视化或保存
python 复制代码
import matplotlib.pyplot as plt

img_np = t.permute(1, 2, 0).numpy()  # [H, W, C]
plt.imshow(img_np)
plt.show()

permute 是因为 ToTensor 会变成 [C,H,W],而 matplotlib 需要 [H,W,C]


场景二:模型推理后的结果处理(转为 Python 值)

假设你有一个分类网络,输出如下:

python 复制代码
output = torch.tensor([[0.1, 0.7, 0.2]])  # 假设输出为 batch_size=1 的 logits
pred_idx = output.argmax(dim=1)  # tensor([1])

你要拿到预测类别的整数值:

python 复制代码
pred_class = pred_idx.item()  # 1
print(type(pred_class))       # <class 'int'>

.item() 在推理阶段非常常用!


场景三:保存 Tensor 到磁盘 / 网络传输

Tensor 保存和加载时经常需要转为 numpy 或 byte 流:

保存为 bytes 再写入文件
python 复制代码
t = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
with open("tensor.bin", "wb") as f:
    f.write(t.numpy().tobytes())
从文件读回 tensor
python 复制代码
with open("tensor.bin", "rb") as f:
    byte_data = f.read()

import numpy as np
arr = np.frombuffer(byte_data, dtype=np.int32)
t2 = torch.from_numpy(arr)
print(t2)  # tensor([1, 2, 3, 4], dtype=torch.int32)

你必须记住原始 dtypeshape 才能正确还原!


场景四:构造 batch 时将 list 转为 Tensor

在训练时经常从数据集中拿到多个样本组成 batch(Python list):

python 复制代码
samples = [[1.0, 2.0], [3.0, 4.0]]
batch_tensor = torch.tensor(samples, dtype=torch.float32)
print(batch_tensor.shape)  # torch.Size([2, 2])

或者更通用的方式(可以处理动态 shape):

python 复制代码
batch_tensor = torch.stack([torch.tensor(s) for s in samples])

补充:在 with torch.no_grad() 中常用转换

推理阶段经常用 Tensor → numpy → list

python 复制代码
with torch.no_grad():
    output = model(input_tensor)
    pred = output.softmax(dim=1)
    top1_class = pred.argmax(dim=1).item()

小结对照表

转换类型 方法 注意事项
Tensor → int/float .item() 只能单元素
Tensor → list .tolist() 支持嵌套
list → Tensor torch.tensor(list) 自动推断类型
Tensor → ndarray .numpy() 共享内存
ndarray → Tensor torch.from_numpy(ndarray) 共享内存
Tensor → bytes tensor.numpy().tobytes() 用于存储
bytes → Tensor np.frombuffer + from_numpy 需知道 dtype
相关推荐
FL16238631294 分钟前
红花识别分割数据集labelme格式144张1类别
人工智能·深度学习
先做个垃圾出来………8 分钟前
1. 两数之和
算法·leetcode·职场和发展
程序员JerrySUN20 分钟前
OpenCV 全解读:核心、源码结构与图像/视频渲染能力深度对比
linux·人工智能·驱动开发·opencv·计算机视觉·缓存·音视频
张较瘦_25 分钟前
[论文阅读] 人工智能 + 软件工程 | GitHub Marketplace中CI Actions的功能冗余与演化规律研究
论文阅读·人工智能·软件工程
神器阿龙33 分钟前
排序算法-冒泡排序
数据结构·算法·排序算法
瘦的可以下饭了43 分钟前
Tensorboard
pytorch
C++ 老炮儿的技术栈1 小时前
在vscode 如何运行a.nut 程序(Squirrel语言)
c语言·开发语言·c++·ide·vscode·算法·编辑器
martian6651 小时前
深度学习核心:神经网络-激活函数 - 原理、实现及在医学影像领域的应用
人工智能·深度学习·神经网络·机器学习·医学影像·影像大模型
HKUST_ZJH1 小时前
交互 Codeforces Round 1040 Interactive RBS
c++·算法·交互