本文通过示例代码全面讲解PyTorch中张量的基本操作,包含创建、运算、广播机制、索引切片等核心功能,并提供完整的代码和输出结果。
1. 张量创建与基本属性
python
import torch
# 创建连续数值张量
x = torch.arange(12, dtype=torch.float32)
print("原始张量:\n", x)
print("形状:", x.shape)
print("元素总数:", x.numel())
# 创建全零/全一张量
zero = torch.zeros(2, 3, 4)
print("\n三维零张量:\n", zero)
one = torch.ones(3, 4)
print("\n全一张量:\n", one)
# 手动创建张量
a = torch.tensor([[1,2,3,4], [5,6,7,8], [9,10,11,12]])
print("\n自定义张量:\n", a)
输出结果:
bash
原始张量:
tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])
形状: torch.Size([12])
元素总数: 12
三维零张量:
tensor([[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]])
全一张量:
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
自定义张量:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
2. 张量重塑与转置
python
x = x.reshape(3, 4)
print("重塑后的3x4张量:\n", x)
print("转置张量:\n", x.T)
输出结果:
bash
重塑后的3x4张量:
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]])
转置张量:
tensor([[ 0., 4., 8.],
[ 1., 5., 9.],
[ 2., 6., 10.],
[ 3., 7., 11.]])
3. 数学运算
python
# 矩阵减法
print("x - one:\n", x - one)
# 指数运算
b = torch.exp(a)
print("\n指数运算结果:\n", b)
输出结果:
bash
x - one:
tensor([[-1., 0., 1., 2.],
[ 3., 4., 5., 6.],
[ 7., 8., 9., 10.]])
指数运算结果:
tensor([[2.7183e+00, 7.3891e+00, 2.0086e+01, 5.4598e+01],
[1.4841e+02, 4.0343e+02, 1.0966e+03, 2.9810e+03],
[8.1031e+03, 2.2026e+04, 5.9874e+04, 1.6275e+05]])
4. 张量拼接与比较
python
# 行拼接
c = torch.cat((x, one), dim=0)
print("行拼接结果:\n", c)
# 列拼接
d = torch.cat((x, one), dim=1)
print("\n列拼接结果:\n", d)
# 张量比较
print("\n张量比较:\n", x == a)
输出结果:
bash
行拼接结果:
tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]])
列拼接结果:
tensor([[ 0., 1., 2., 3., 1., 1., 1., 1.],
[ 4., 5., 6., 7., 1., 1., 1., 1.],
[ 8., 9., 10., 11., 1., 1., 1., 1.]])
张量比较:
tensor([[False, False, False, False],
[False, False, False, False],
[False, False, False, False]])
5. 广播机制
python
e = torch.arange(3).reshape(3, 1)
print("广播加法:\n", x + e)
输出结果:
bash
广播加法:
tensor([[ 0., 1., 2., 3.],
[ 5., 6., 7., 8.],
[10., 11., 12., 13.]])
6. 索引与切片
python
print("最后一行:", x[-1])
print("第二到第三行:\n", x[1:3])
x[1, 2] = 100 # 修改单个元素
x[0:2, 1:3] = 0 # 修改子区域
print("\n修改后的张量:\n", x)
输出结果:
bash
最后一行: tensor([ 8., 9., 10., 11.])
第二到第三行:
tensor([[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]])
修改后的张量:
tensor([[ 0., 0., 0., 3.],
[ 4., 0., 0., 7.],
[ 8., 9., 10., 11.]])
7. 内存地址管理
python
before = id(x)
x = x + a # 新内存分配
# x += a # 原地操作
print("内存地址是否变化:", before == id(x))
D = x.clone()
print("克隆张量地址对比:", before == id(D))
输出结果:
bash
内存地址是否变化: False
克隆张量地址对比: False
8. PyTorch与NumPy转换
python
A = x.numpy()
B = torch.tensor(A)
print("类型转换:", type(A), type(B))
输出结果:
bash
类型转换: <class 'numpy.ndarray'> <class 'torch.Tensor'>
9. 统计操作
python
sum_a = a.sum(axis=1, keepdims=True)
print("按行求和:\n", sum_a)
print("归一化结果:\n", a / sum_a)
print("按列累加:\n", a.cumsum(axis=0))
输出结果:
bash
按行求和:
tensor([[10],
[26],
[42]])
归一化结果:
tensor([[0.1000, 0.2000, 0.3000, 0.4000],
[0.1923, 0.2308, 0.2692, 0.3077],
[0.2143, 0.2381, 0.2619, 0.2857]])
按列累加:
tensor([[ 1, 2, 3, 4],
[ 6, 8, 10, 12],
[15, 18, 21, 24]])
通过本文的示例代码,您可以快速掌握PyTorch张量操作的核心功能。建议读者在实际项目中多加练习以巩固知识!