torch常用函数

目录

一、 torch.bmm

torch.bmm 是 PyTorch 中的一个函数,用于执行批矩阵乘法( matrix multiplication)操作。

它的输入是三维张量,形状为 (batch, n, m) 和 (batch, m, p):

其中 n 是第一个矩阵的列数,m 是两个矩阵共享的维度,p 是第二个矩阵的列数。

torch.bmm 将批中的每对矩阵相乘,返回一个新的三维张量,形状为 (batch, n, p)。

二、 torch.einsum

torch.einsum是pytorch上的一个强大的函数,用于矩阵相关的计算,注意,这里没有限定为矩阵乘法。torch.einsum基于爱因斯坦求和约定执行张量操作,能够用简洁的表达式实现复杂的多维数组操作,从而避免繁琐的张量操作组合(如reshape、permute、bmm等),减少错误率。需要说明的是,尽管einsum函数内部进行了大量计算优化,但其主要优势在于表达式简洁,如果与单步reshape等pytorch实现的矩阵运算操作相比,其运算速度与内存占用不一定占优势。

1.矩阵乘法:'ij,jk->ik' 表示形状为(i,j)与形状为(j,k)的矩阵进行矩阵乘法,得到新矩阵形状为(i,k)。这也是torch.einsum最常规的用法。

2.维度调换:'ij->ji'表示形状为(i,j)的矩阵维度调换成为形状为(j,i)的矩阵。

torch.einsum还有多种用法,遇到再来添加

三、python中变量前面有个*

在Python中,变量前面的星号(*)有多种用法,主要与函数参数或解包序列有关。

1、在函数参数中,星号(*)用来表示任意多个参数,这些参数会被当作元组传递。例如:

python 复制代码
def fun(*args):
    for i in args:
        print(i)
 
fun(1, 2, 3, 4)

2、在函数参数中,星号(*)还可以用来解包序列。例如:

python 复制代码
def fun(a, b, c, d):
    print(a, b, c, d)
 
args = (1, 2, 3, 4)
fun(*args)

3、在函数参数中,星号(*)还可以与命名参数,或者字典一起使用。例如:

python 复制代码
def fun(*args, a=1):
    print(args, a)
 
fun(1, 2, 3, a=4)

def fun(*args, **kwargs):
    print(args, kwargs)
 
fun(1, 2, 3, a=4, b=5)

4、 在解包列表或元组时,星号(*)也可以用来解包选定项。例如:

python 复制代码
lst = [1, 2, 3, 4, 5]
a, *b, c = lst
print(a, b, c)

四、numpy.prod

计算元素和

python 复制代码
print(np.prod([[1., 2.], [3., 4.]], axis=0))按列计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=1))按行计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=0))计算所有元素和

五、torch.chunk

对于一个输入tensor,torch.chunk方法会按照dim指定的维度将输入tensor划分为若干个chunk,划分的数量为chunks。

python 复制代码
torch.chunk(input, chunks, dim=0) 

temp=torch.randn((4,6))
print(torch.chunk(temp,2,0))行方向分块
print(torch.chunk(temp,2,1))列方向分块

六、torch.contiguous

在PyTorch中,.contiguous()方法的作用是确保张量在内存中是连续存储的。当你对张量执行某些操作,如transpose()、permute()、narrow()、expand()等之后,得到的张量可能不再在内存中连续排列。这些操作通常返回一个张量的视图,它们改变的是数据访问的方式,而不是实际的数据存储方式。

在内存中连续排列的张量有一个特性:对于张量中任意两个相邻的元素,它们在物理内存中的位置也是相邻的。换句话说,张量在物理存储上的排列顺序与在张量形式上的逻辑排列顺序一致。

当调用.contiguous()时,如果张量已经是连续的,这个函数实际上不会做任何事;但如果不是,PyTorch将会重新分配内存并确保张量的数据连续排列。这涉及到复制数据到新的内存区域,并返回一个新的张量,该张量在内存中实际是连续的。

调用view之前最好先contiguous,也就是x.contiguous().view()

python 复制代码
import torch

# 创建一个非连续张量
x = torch.arange(12).view(3, 4).permute(1,0)  # 移动维度
print(x.is_contiguous())  # False

# 使用 .contiguous() 来确保张量是连续的
y = x.contiguous()
print(y.is_contiguous())  # True

六、torch.clamp

在PyTorch中,clamp函数是一个非常实用的操作,它允许你将张量(Tensor)中的元素值限制在一个指定的范围内。这个函数特别有用,比如在图像处理中调整像素值、在神经网络中防止梯度爆炸或消失时限制激活函数的输出等场景。

clamp函数的基本用法如下:

python 复制代码
torch.clamp(input, min, max) → Tensor

input:输入的张量。

min:元素值的下限。所有小于min的元素都将被设置为min。

max:元素值的上限。所有大于max的元素都将被设置为max。

如果min或max是None,则相应的边界将不被限制。例如,如果只指定了min而没有指定max,则所有小于min的元素会被设置为min,而大于min的元素则保持不变。

python 复制代码
import torch

# 创建一个张量
x = torch.tensor([-5.0, -2.0, 0.0, 3.0, 5.0, 7.0, 10.0])

# 使用clamp函数限制值在0到5之间
y = x.clamp( 0, 5)

print(y)
# 输出: tensor([0., 0., 0., 3., 5., 5., 5.])

七、torchvision.utils.make_grid

torchvision.utils.make_grid 是 PyTorch 中的一个非常有用的函数,它可以将多个图像(通常是 tensor 格式)拼接成一个网格(grid)图像。这个函数在展示或保存多张图像时特别有用,比如在训练过程中可视化模型输出的多个样本。

python 复制代码
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

参数解释

tensor (Tensor or list of Tensors): 一个 4D mini-batch Tensor,形状为 (B, C, H, W),或者是一个包含这样的 Tensor 的列表。Tensor 中的图像应该位于 [0, 1] 范围内,如果是其他范围,则需要通过 normalize 参数进行归一化。

nrow (int, optional): 网格中每行的图像数量。默认值为 8。

padding (int, optional): 图像之间的填充大小。默认值为 2。

normalize (bool, optional): 如果为 True,则将所有图像缩放到 [0, 1] 范围内,并假设输入图像位于指定的 range 内。默认值为 False。

range (tuple, optional): 用于归一化的值范围(min, max)。只有当 normalize=True 时才有效。默认值为 (0, 1)。

scale_each (bool, optional): 如果为 True,则单独对每张图像进行缩放,而不是统一对整个 mini-batch 进行缩放。默认值为 False。

pad_value (float, optional): 用于填充的值。默认值为 0。

返回值

一个 3D Tensor,形状为 (C, H*(nrow + (nrow-1)padding), W(ncol + (ncol-1)*padding)),其中 C 是通道数,H 和 W 是原始图像的高度和宽度,nrow 和 ncol 是通过输入 Tensor 的批量大小自动计算的列数(ncol = ceil(B / nrow))。

八、 arr[..., ::-1]

在Python中,表达式arr[..., ::-1]通常用于NumPy数组或兼容NumPy索引的数组(如PyTorch张量)上,用于对数组进行切片操作。这个表达式的目的是沿着数组的最后一个维度(或轴)反转元素的顺序。

这里的...是NumPy引入的省略号(ellipsis),它用于表示在指定位置选择多个未明确指出的维度。在这个上下文中,...意味着"选择所有前面的维度",而::-1是一个切片操作,表示"从末尾开始到开头,步长为-1",即反转当前维度的元素顺序。

示例

假设我们有一个形状为(3, 4, 3)的三维NumPy数组,代表三个颜色通道(RGB)的图像批次,每个图像的大小为4x3像素。如果我们想要将每个图像的RGB通道顺序更改为BGR,我们可以使用arr[..., ::-1]来实现这一点。

python 复制代码
python
import numpy as np  
  
# 创建一个形状为(3, 4, 3)的随机数组,模拟RGB图像批次  
arr = np.random.randint(0, 256, size=(3, 4, 3), dtype=np.uint8)  
  
# 打印原始数组的形状和一部分内容  
print("Original shape:", arr.shape)  
print("Original array (first image):\n", arr[0])  
  
# 使用arr[..., ::-1]反转最后一个维度(颜色通道)  
arr_bgr = arr[..., ::-1]  
  
# 打印反转后数组的形状和一部分内容  
print("Modified shape:", arr_bgr.shape)  # 形状保持不变  
print("Modified array (first image, now in BGR):\n", arr_bgr[0])

在这个例子中,arr[..., ::-1]会保持数组的形状不变(因为我们只反转了最后一个维度),但是会改变最后一个维度的元素顺序,从而将RGB通道更改为BGR通道。

这种操作在处理图像数据时非常有用,因为不同的图像处理库和框架可能期望不同的颜色通道顺序。例如,OpenCV默认使用BGR顺序,而PIL(Python Imaging Library)和matplotlib则使用RGB顺序。

九、 yield

在Python中,yield 关键字用于从函数中返回一个生成器(generator)。生成器是一个可以记住上一次返回位置的对象,并在下一次迭代时从该位置继续执行。这使得它们非常适合用于需要逐个处理大量数据的场景,因为它们可以按需生成数据,从而节省内存。

当你调用一个包含 yield 的函数时,该函数不会立即执行其代码,而是返回一个迭代器(即生成器)。然后,你可以通过迭代这个生成器来逐步执行函数中的代码。每次迭代时,yield 语句会"暂停"函数的执行,并返回紧随其后的值给迭代器的调用者。当迭代器再次请求下一个值时,函数会从上次暂停的位置继续执行,直到遇到下一个 yield 语句或函数结束。

示例

下面是一个简单的使用 yield 的例子,该函数生成一个斐波那契数列(Fibonacci sequence):

python 复制代码
python
def fibonacci(n):  
    a, b = 0, 1  
    count = 0  
    while count < n:  
        yield a  
        a, b = b, a + b  
        count += 1  
for num in fibonacci(10):  
  print(num)

使用生成器

在这个例子中,fibonacci 函数是一个生成器函数,它使用 yield 来逐个返回斐波那契数列中的数。当我们使用 for 循环迭代 fibonacci(10) 时,函数会在每次迭代时执行到下一个 yield 语句,并返回当前的 a 值。当函数内部的状态(即 a 和 b 的值以及 count)被保存起来,并在下一次迭代时恢复,直到生成了 n 个数为止。

注意事项

使用 yield 的函数会返回一个生成器对象,而不是一次性返回所有值。

生成器只能迭代一次。一旦生成器迭代完成,它就不能再次从头开始迭代。

生成器非常适合用于实现迭代器协议,因为它们提供了惰性求值(lazy evaluation)的能力,即只有在需要时才计算值。

在生成器中,return 语句会立即停止迭代,但可以通过 return 语句返回一个值给迭代器的 StopIteration 异常(在Python 3.3及以后版本中,如果生成器因为 return 语句而终止,则 return 语句后的值(如果有的话)会被用作 StopIteration 异常的 value 属性)。如果生成器中没有 return 语句,或者 return 语句没有值,则迭代会在自然结束时停止。

相关推荐
skywalk816321 分钟前
copyparty 是一款使用单个 Python 文件实现的内网文件共享工具,具有跨平台、低资源占用等特点,适合需要本地化文件管理的场景
开发语言·python
BYSJMG26 分钟前
计算机毕设选题:基于Python+MySQL校园美食推荐系统【源码+文档+调试】
大数据·开发语言·python·mysql·django·课程设计·美食
FairyGirlhub2 小时前
神经网络的初始化:权重与偏置的数学策略
人工智能·深度学习·神经网络
大写-凌祁7 小时前
零基础入门深度学习:从理论到实战,GitHub+开源资源全指南(2025最新版)
人工智能·深度学习·开源·github
CodeCraft Studio7 小时前
PDF处理控件Aspose.PDF教程:使用 Python 将 PDF 转换为 Base64
开发语言·python·pdf·base64·aspose·aspose.pdf
wan5555cn7 小时前
多张图片生成视频模型技术深度解析
人工智能·笔记·深度学习·算法·音视频
格林威8 小时前
机器视觉检测的光源基础知识及光源选型
人工智能·深度学习·数码相机·yolo·计算机视觉·视觉检测
困鲲鲲8 小时前
Python中内置装饰器
python
摩羯座-185690305949 小时前
Python数据可视化基础:使用Matplotlib绘制图表
大数据·python·信息可视化·matplotlib
爱隐身的官人9 小时前
cfshow-web入门-php特性
python·php·ctf