pytorch中的各种计算

对tensor矩阵的维度变换,加减乘除等是深度学习中的常用操作,本文对一些常用方法进行总结

矩阵乘法

混合矩阵相乘,官网

c 复制代码
torch.matmul(input, other, *, out=None) → Tensor

这个方法执行矩阵相乘操作,需要第一个矩阵的最后一个维度和第二个矩阵的第一个维度相同,即:假设我们有两个矩阵 A 和 B,它们的 size 分别为 (m, n) 和 (n, p),那么 A x B 的 size 为 (m, p)。

矩阵点乘,官网

c 复制代码
torch.mul(input, other, *, out=None) → Tensor

这个方法对矩阵做点积运算(也可简写为*),这个方法要求第一个矩阵的第一个维度和第二个矩阵的第一个维度对应。torch.dot()类似于mul(),它是向量(即只能是一维的张量)的对应位相乘再求和,返回一个tensor。

矩阵维度变换

tensor.view方法,用于调整矩阵的维度,这个方法要求矩阵在调整为度前后的元素个数必须是相同的,官网,例子:

c 复制代码
>>> t = torch.rand(4, 4)
>>> b = t.view(2, 8)
>>> t.storage().data_ptr() == b.storage().data_ptr()  # `t` and `b` share the same underlying data.
True
# Modifying view tensor changes base tensor as well.
>>> b[0][0] = 3.14
>>> t[0][0]
tensor(3.14)

torch中对矩阵的压缩和解压操作:torch.squeeze和torch.unsqueeze,这两种方法的作用是压缩矩阵中的某一个维度或者增加一个维度,官网,两种方法的详解可以参考我之前的笔记pytorch中的torch.squeeze和torch.unsqueeze

矩阵填充,官网torch.nn.functional.pad

c 复制代码
torch.nn.functional.pad(input, pad, mode='constant', value=None) → Tensor
Args:
	"""
	input:四维或者五维的tensor Variabe
	pad:不同Tensor的填充方式
		1.四维Tensor:传入四元素tuple(pad_l, pad_r, pad_t, pad_b),
		指的是(左填充,右填充,上填充,下填充),其数值代表填充次数
		2.六维Tensor:传入六元素tuple(pleft, pright, ptop, pbottom, pfront, pback),
		指的是(左填充,右填充,上填充,下填充,前填充,后填充),其数值代表填充次数
	mode: 'constant', 'reflect' or 'replicate'三种模式,指的是常量,反射,复制三种模式
	value:填充的数值,在"contant"模式下默认填充0,mode="reflect" or "replicate"时没有			

如果给入的填充次数是负数,该函数可以实现从该方向对矩阵的裁剪操作。

需要注意的是,本文中提到的所有方法都支持broadcast操作,也就是,除了参与操作的最后两个维度(矩阵),前面的所有维度都会被认为是batch,以torch,matmul为例,该方法使用两个tensor的后两个维度来计算,其他的维度都可以认为是batch。假设两个输入的维度分别是 i n p u t ( 1000 × 500 × 99 × 11 ) input(1000×500×99×11) input(1000×500×99×11), o t h e r ( 500 × 11 × 99 ) other(500×11×99) other(500×11×99),那么我们可以认为 t o r c h . m a t m u l ( i n p u t , o t h e r ) torch.matmul(input,other) torch.matmul(input,other) 首先是进行后两位矩阵乘法得到 ( 99 × 99 ) (99×99) (99×99) ,然后分析两个参数的batch size分别是 ( 1000 × 500 ) (1000×500) (1000×500)和 ( 500 ) (500) (500), 可以广播成为 ( 1000 × 500 ) (1000×500) (1000×500),因此最终输出的维度是 ( 1000 × 500 × 99 × 99 ) (1000×500×99×99) (1000×500×99×99)。

相关推荐
岑梓铭28 分钟前
(CentOs系统虚拟机)Standalone模式下安装部署“基于Python编写”的Spark框架
linux·python·spark·centos
边缘计算社区31 分钟前
首个!艾灵参编的工业边缘计算国家标准正式发布
大数据·人工智能·边缘计算
游客52042 分钟前
opencv中的各种滤波器简介
图像处理·人工智能·python·opencv·计算机视觉
一位小说男主42 分钟前
编码器与解码器:从‘乱码’到‘通话’
人工智能·深度学习
Eric.Lee20211 小时前
moviepy将图片序列制作成视频并加载字幕 - python 实现
开发语言·python·音视频·moviepy·字幕视频合成·图像制作为视频
Dontla1 小时前
vscode怎么设置anaconda python解释器(anaconda解释器、vscode解释器)
ide·vscode·python
深圳南柯电子1 小时前
深圳南柯电子|电子设备EMC测试整改:常见问题与解决方案
人工智能
Kai HVZ1 小时前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉
biter00881 小时前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
吃个糖糖1 小时前
35 Opencv 亚像素角点检测
人工智能·opencv·计算机视觉