PyTorch 张量(Tensors)全面指南:从基础到实战

文章目录

什么是张量?

张量(Tensors)是 PyTorch 中的核心数据结构,类似于数组和矩阵,但具有更强大的功能。在深度学习中,我们使用张量来表示:

  • 模型的输入和输出数据
  • 模型的参数(权重和偏置)
  • 中间计算过程中的数据

张量与 NumPy 的 ndarrays 类似,但有两大关键优势:

  1. GPU 加速:可在 GPU 或其他硬件加速器上运行
  2. 自动微分:支持自动求导,这对深度学习至关重要
python 复制代码
import torch
import numpy as np

张量初始化方法

1. 直接从数据创建

python 复制代码
data = [[1, 2], [3, 4]]
x_data = torch.tensor(data)

2. 从 NumPy 数组转换

python 复制代码
np_array = np.array(data)
x_np = torch.from_numpy(np_array)

3. 基于现有张量创建

python 复制代码
x_ones = torch.ones_like(x_data)  # 保留原始张量属性
x_rand = torch.rand_like(x_data, dtype=torch.float)  # 覆盖数据类型

print(f"Ones Tensor:\n{x_ones}")
print(f"Random Tensor:\n{x_rand}")

4. 使用随机值或常量

python 复制代码
shape = (2, 3)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)

print(f"Random Tensor:\n{rand_tensor}")
print(f"Ones Tensor:\n{ones_tensor}")
print(f"Zeros Tensor:\n{zeros_tensor}")

张量属性

每个张量都有三个关键属性:

python 复制代码
tensor = torch.rand(3, 4)

print(f"Shape: {tensor.shape}")    # 形状
print(f"Datatype: {tensor.dtype}") # 数据类型
print(f"Device: {tensor.device}")  # 存储设备 (CPU/GPU)

张量操作

设备转移

python 复制代码
# 转移到GPU(如果可用)
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor = tensor.to(device)
print(f"Device after transfer: {tensor.device}")

索引和切片

python 复制代码
tensor = torch.ones(4, 4)
print(f"First row: {tensor[0]}")
print(f"First column: {tensor[:, 0]}")
print(f"Last column: {tensor[..., -1]}")

tensor[:, 1] = 0  # 修改第二列
print(tensor)

连接张量

python 复制代码
t1 = torch.cat([tensor, tensor, tensor], dim=1)
print(f"Concatenated tensor:\n{t1}")

算术运算

python 复制代码
# 矩阵乘法(三种等效方式)
y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)
y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)

# 逐元素乘法(三种等效方式)
z1 = tensor * tensor
z2 = tensor.mul(tensor)
z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)

单元素张量转换

python 复制代码
agg = tensor.sum()
agg_item = agg.item()  # 转换为Python标量
print(f"Sum value: {agg_item}, Type: {type(agg_item)}")

原地操作(In-place Operations)

原地操作直接修改张量内容,使用 _ 后缀表示:

python 复制代码
print("Original tensor:")
print(tensor)

tensor.add_(5)  # 原地加5
print("\nAfter in-place addition:")
print(tensor)

注意:虽然原地操作节省内存,但在自动微分中可能导致梯度计算问题,应谨慎使用。

PyTorch 与 NumPy 互操作

张量转 NumPy 数组

python 复制代码
t = torch.ones(5)
n = t.numpy()
print(f"Tensor: {t}\nNumPy: {n}")

# 修改张量会影响NumPy数组
t.add_(1)
print(f"\nAfter modification:\nTensor: {t}\nNumPy: {n}")

NumPy 数组转张量

python 复制代码
n = np.ones(5)
t = torch.from_numpy(n)
print(f"NumPy: {n}\nTensor: {t}")

# 修改NumPy数组会影响张量
np.add(n, 1, out=n)
print(f"\nAfter modification:\nNumPy: {n}\nTensor: {t}")

张量操作总结表

操作类型 方法示例 说明
创建 torch.tensor(), torch.rand(), torch.zeros() 多种初始化方式
属性 .shape, .dtype, .device 获取张量元数据
索引 tensor[0], tensor[:, 1] 类似NumPy的索引
运算 torch.matmul(), tensor.sum() 矩阵运算和归约
连接 torch.cat(), torch.stack() 合并多个张量
转换 .numpy(), torch.from_numpy() 与NumPy互转

最佳实践与注意事项

  1. 设备管理:明确张量所在的设备(CPU/GPU),避免不必要的设备间传输
  2. 数据类型 :注意操作中的数据类型一致性,使用 .dtype 检查
  3. 内存共享:PyTorch 和 NumPy 数组共享内存,修改一个会影响另一个
  4. 自动微分:避免在需要梯度的计算图中使用原地操作
  5. 性能优化:对大规模数据使用 GPU 加速,对小规模操作可能 CPU 更高效
python 复制代码
# 高效设备转移示例
if torch.cuda.is_available():
    tensor = tensor.to('cuda')
    
# 保持数据类型一致
float_tensor = torch.rand(3, dtype=torch.float32)
int_tensor = torch.tensor([1, 2, 3], dtype=torch.int32)
result = float_tensor + int_tensor.float()  # 显式转换

掌握张量操作是使用 PyTorch 进行深度学习的基础。通过本文介绍的各种方法,您可以高效地创建、操作和转换张量,为构建复杂模型奠定坚实基础!

官方文档:https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html

相关推荐
诸葛箫声7 分钟前
基于PyTorch的CIFAR-10图像分类项目总结
人工智能·pytorch·分类
en-route17 分钟前
从零开始学神经网络——GRU(门控循环单元)
人工智能·深度学习·gru
说私域23 分钟前
基于开源AI大模型AI智能名片S2B2C商城小程序的产地优势产品营销策略研究
人工智能·小程序·开源
说私域25 分钟前
蒸汽机革命后工业生产方式的变革与AI智能名片S2B2C商城小程序的影响
大数据·人工智能·小程序
一人の梅雨31 分钟前
亚马逊 MWS 关键字 API 实战:关键字搜索商品列表接口深度解析与优化方案
python·spring
MongoVIP35 分钟前
AI提示词应用
人工智能·职场和发展·简历优化·简历制作
深圳UMI1 小时前
AI笔记在学习与工作中的高效运用
大数据·人工智能
大模型真好玩1 小时前
深入浅出LangGraph AI Agent智能体开发教程(八)—LangGraph底层API实现ReACT智能体
人工智能·agent·deepseek
IT_陈寒1 小时前
告别低效!用这5个Python技巧让你的数据处理速度提升300% 🚀
前端·人工智能·后端
北京耐用通信2 小时前
神秘魔法?耐达讯自动化Modbus TCP 转 Profibus 如何为光伏逆变器编织通信“天网”
网络·人工智能·网络协议·网络安全·自动化·信息与通信