PyTorch深度学习总结
第五章 PyTorch中张量(Tensor)计算操作
文章目录
前言
上文介绍了PyTorch中张量(Tensor)
的拆分
和拼接
操作,本文将介绍张量
的计算
操作。
一、张量比较大小
函数 | 描述 |
---|---|
torch.allclose() | 比较两个元素是否接近 |
torch.eq() | 逐元素比较是否相等 |
torch.equal() | 判断两个张量是否具有相同的形状和元素 |
torch.ge() | 逐元素比较大于等于 |
torch.gt() | 逐元素比较大于 |
torch.le() | 逐元素比较小于等于 |
torch.lt() | 逐元素比较小于 |
torch.ne() | 逐元素比较不等于 |
torch.isnan() | 判断是否为缺失值 |
1、torch.allclose()
函数用法:
torch.allclose(A, B, rtol=,atol=)
判断是否接近的公式如下:
∣ A − B ∣ ≤ a t o l + r t o l × ∣ B ∣ . |A-B| \leq atol+rtol\times|B|\,. ∣A−B∣≤atol+rtol×∣B∣.
python# 引入库 import torch # 创建张量A A = torch.tensor([10.0, 10.0]) # 测试函数 print(torch.allclose(A, A, rtol=0.1, atol=0.01,equal_nan=False))
输出结果为:False
2、torch.eq()和torch.equal()
①函数用法:
torch.eq(A, B)
主要比较
元素
之间的关系,即两个对应元素是否相等
。
python# 测试函数 print(torch.eq(A, A))
输出结果为:tensor([True, True])
②函数用法:
torch.equal(A, B)
主要比较
张量
之间的关系,即两个张量形状和大小是否相等
。
python# 测试函数 print(torch.equal(A, A))
输出结果为:True
3、ge、gt、le、lt、ne函数
函数用法:
torch.ge(A, B)
主要用于
逐元素
比较,看是否大于等于
( ≥ \geq ≥)。注:以上几个函数除本身意义不同外,其他用法几乎相同;故本文只针对
torch.ge()
进行展示。
python# 生成张量 B = torch.tensor([11.0, 9.0]) # 测试函数 print(torch.ge(A, B))
输出结果为:tensor([False, True])
4、torch.isnan()
函数用法:
torch.isnan(A)
判断
张量A
对应元素是否为缺失值
。
pythonprint(torch.isnan(A)) print(torch.isnan(torch.tensor([0, 1, float("nan")])))
输出结果:
tensor([False, False])
tensor([False, False, True])
二、基本运算
1、四则运算(加减乘除)
生成试验数组:
python# 引入库 import torch # 生成张量 A = torch.arange(6).reshape(2,3) B = torch.linspace(1, 6, steps=6).reshape(2,3) # 在1-6之间生成5个等步长的元素组成张量 print(A, B)
输出结果为:
tensor([[0, 1, 2], [3, 4, 5]])
tensor([[1., 2., 3.], [4., 5., 6.]])
加减乘除运算为:+
、-
、*
、/
整除://
幂运算为:torch.pow()
或**
示例:
pythonprint(A+B) print(A-B) print(A*B) print(A/B) print(B//A) print(A**2) print(torch.pow(A, 2))
输出结果为:
tensor([[ 1., 3., 5.], [ 7., 9., 11.]])
tensor([[-1., -1., -1.], [-1., -1., -1.]])
tensor([[ 0., 2., 6.], [12., 20., 30.]])
tensor([[0.0000, 0.5000, 0.6667], [0.7500, 0.8000, 0.8333]])
tensor([[inf, 2., 1.], [1., 1., 1.]])
tensor([[ 0, 1, 4], [ 9, 16, 25]])
tensor([[ 0, 1, 4], [ 9, 16, 25]])
2、其他计算
函数 | 描述 |
---|---|
torch.exp() | 张量的指数函数 |
torch.log() | 张量的对数函数 |
torch.sqrt() | 张量的平方根 |
torch.clamp_max() | 根据最大值裁剪 |
torch.clamp_min() | 根据最小值裁剪 |
torch.clamp() | 根据范围裁剪 |
torch.t() | 计算矩阵的转置 |
torch.matmul() | 计算矩阵的转置 |
torch.inverse() | 计算矩阵的逆矩阵 |
torch.trace() | 计算矩阵的迹 |