1、加减乘除
import torch
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# 加法
c = a + b # tensor([[ 6, 8], [10, 12]])
# 减法
d = a - b # tensor([[-4, -4], [-4, -4]])
# 乘法(逐元素相乘,不是矩阵乘法)
e = a * b # tensor([[ 5, 12], [21, 32]])
# 除法
f = a / b # tensor([[0.2000, 0.3333], [0.4286, 0.5000]])
2、矩阵乘法的三种
a = torch.tensor([[3., 3.],
[3., 3.]])
b = torch.ones(2, 2) # tensor([[1., 1.],
# [1., 1.])
torch.mm(a, b)
# tensor([[6., 6.],
# [6., 6.])
torch.matmul(a, b)
# tensor([[6., 6.],
# [6., 6.])
a@b
# tensor([[6., 6.],
# [6., 6.])
3、tensor运算
import torch
x = torch.tensor([1, 2, 3, 4])
# 平方
x2 = x ** 2 # tensor([1, 4, 9, 16])
# 立方
x3 = x ** 3 # tensor([ 1, 8, 27, 64])
# 任意次方(例如 0.5 次方即开平方)
x_half = x ** 0.5 # tensor([1.0000, 1.4142, 1.7321, 2.0000])
torch.pow()等价与**
y = torch.pow(x, 2) # 同 x ** 2
y = torch.pow(2, x) # 也可以底数为标量:2^x
平方根相关函数
torch.sqrt(x):计算 x ** 0.5,更高效
torch.rsqrt(x):计算 1 / sqrt(x)(倒数平方根)
torch.square(x):计算 x ** 2
a = torch.tensor([4.0, 9.0, 16.0])
print(torch.sqrt(a)) # tensor([2., 3., 4.])
print(torch.square(a)) # tensor([16., 81., 256.])
指数与幂
torch.exp(x):计算 e^x
torch.log(x):自然对数
注意:对负数开平方会得到 nan(实数域),如需复数请用 torch.complex64
a.floor()向下取整
a.ceil()向上取整
a.trunc()取整数部分
a.frac()取小数部分
a.round()四舍五入
4、clamp() 裁剪函数用法
torch.clamp()(及 tensor.clamp())用于将张量中的每个元素限制在指定的数值范围 [min, max] 内。
小于
min的元素变为min大于
max的元素变为max其余元素保持不变
import torch
x = torch.tensor([-2, 1, 3, 5, 8])
# 裁剪到 [0, 5] 之间
y = torch.clamp(x, min=0, max=5) # tensor([0, 1, 3, 5, 5])
# 只设置下限
y = x.clamp(min=0) # tensor([0, 1, 3, 5, 8])
# 只设置上限
y = x.clamp(max=4) # tensor([-2, 1, 3, 4, 4])