torch常用函数

目录

一、 torch.bmm

torch.bmm 是 PyTorch 中的一个函数,用于执行批矩阵乘法( matrix multiplication)操作。

它的输入是三维张量,形状为 (batch, n, m) 和 (batch, m, p):

其中 n 是第一个矩阵的列数,m 是两个矩阵共享的维度,p 是第二个矩阵的列数。

torch.bmm 将批中的每对矩阵相乘,返回一个新的三维张量,形状为 (batch, n, p)。

二、 torch.einsum

torch.einsum是pytorch上的一个强大的函数,用于矩阵相关的计算,注意,这里没有限定为矩阵乘法。torch.einsum基于爱因斯坦求和约定执行张量操作,能够用简洁的表达式实现复杂的多维数组操作,从而避免繁琐的张量操作组合(如reshape、permute、bmm等),减少错误率。需要说明的是,尽管einsum函数内部进行了大量计算优化,但其主要优势在于表达式简洁,如果与单步reshape等pytorch实现的矩阵运算操作相比,其运算速度与内存占用不一定占优势。

1.矩阵乘法:'ij,jk->ik' 表示形状为(i,j)与形状为(j,k)的矩阵进行矩阵乘法,得到新矩阵形状为(i,k)。这也是torch.einsum最常规的用法。

2.维度调换:'ij->ji'表示形状为(i,j)的矩阵维度调换成为形状为(j,i)的矩阵。

torch.einsum还有多种用法,遇到再来添加

三、python中变量前面有个*

在Python中,变量前面的星号(*)有多种用法,主要与函数参数或解包序列有关。

1、在函数参数中,星号(*)用来表示任意多个参数,这些参数会被当作元组传递。例如:

python 复制代码
def fun(*args):
    for i in args:
        print(i)
 
fun(1, 2, 3, 4)

2、在函数参数中,星号(*)还可以用来解包序列。例如:

python 复制代码
def fun(a, b, c, d):
    print(a, b, c, d)
 
args = (1, 2, 3, 4)
fun(*args)

3、在函数参数中,星号(*)还可以与命名参数,或者字典一起使用。例如:

python 复制代码
def fun(*args, a=1):
    print(args, a)
 
fun(1, 2, 3, a=4)

def fun(*args, **kwargs):
    print(args, kwargs)
 
fun(1, 2, 3, a=4, b=5)

4、 在解包列表或元组时,星号(*)也可以用来解包选定项。例如:

python 复制代码
lst = [1, 2, 3, 4, 5]
a, *b, c = lst
print(a, b, c)

四、numpy.prod

计算元素和

python 复制代码
print(np.prod([[1., 2.], [3., 4.]], axis=0))按列计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=1))按行计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=0))计算所有元素和

五、torch.chunk

对于一个输入tensor,torch.chunk方法会按照dim指定的维度将输入tensor划分为若干个chunk,划分的数量为chunks。

python 复制代码
torch.chunk(input, chunks, dim=0) 

temp=torch.randn((4,6))
print(torch.chunk(temp,2,0))行方向分块
print(torch.chunk(temp,2,1))列方向分块

六、torch.contiguous

在PyTorch中,.contiguous()方法的作用是确保张量在内存中是连续存储的。当你对张量执行某些操作,如transpose()、permute()、narrow()、expand()等之后,得到的张量可能不再在内存中连续排列。这些操作通常返回一个张量的视图,它们改变的是数据访问的方式,而不是实际的数据存储方式。

在内存中连续排列的张量有一个特性:对于张量中任意两个相邻的元素,它们在物理内存中的位置也是相邻的。换句话说,张量在物理存储上的排列顺序与在张量形式上的逻辑排列顺序一致。

当调用.contiguous()时,如果张量已经是连续的,这个函数实际上不会做任何事;但如果不是,PyTorch将会重新分配内存并确保张量的数据连续排列。这涉及到复制数据到新的内存区域,并返回一个新的张量,该张量在内存中实际是连续的。

调用view之前最好先contiguous,也就是x.contiguous().view()

python 复制代码
import torch

# 创建一个非连续张量
x = torch.arange(12).view(3, 4).permute(1,0)  # 移动维度
print(x.is_contiguous())  # False

# 使用 .contiguous() 来确保张量是连续的
y = x.contiguous()
print(y.is_contiguous())  # True

六、torch.clamp

在PyTorch中,clamp函数是一个非常实用的操作,它允许你将张量(Tensor)中的元素值限制在一个指定的范围内。这个函数特别有用,比如在图像处理中调整像素值、在神经网络中防止梯度爆炸或消失时限制激活函数的输出等场景。

clamp函数的基本用法如下:

python 复制代码
torch.clamp(input, min, max) → Tensor

input:输入的张量。

min:元素值的下限。所有小于min的元素都将被设置为min。

max:元素值的上限。所有大于max的元素都将被设置为max。

如果min或max是None,则相应的边界将不被限制。例如,如果只指定了min而没有指定max,则所有小于min的元素会被设置为min,而大于min的元素则保持不变。

python 复制代码
import torch

# 创建一个张量
x = torch.tensor([-5.0, -2.0, 0.0, 3.0, 5.0, 7.0, 10.0])

# 使用clamp函数限制值在0到5之间
y = x.clamp( 0, 5)

print(y)
# 输出: tensor([0., 0., 0., 3., 5., 5., 5.])

七、torchvision.utils.make_grid

torchvision.utils.make_grid 是 PyTorch 中的一个非常有用的函数,它可以将多个图像(通常是 tensor 格式)拼接成一个网格(grid)图像。这个函数在展示或保存多张图像时特别有用,比如在训练过程中可视化模型输出的多个样本。

python 复制代码
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

参数解释

tensor (Tensor or list of Tensors): 一个 4D mini-batch Tensor,形状为 (B, C, H, W),或者是一个包含这样的 Tensor 的列表。Tensor 中的图像应该位于 [0, 1] 范围内,如果是其他范围,则需要通过 normalize 参数进行归一化。

nrow (int, optional): 网格中每行的图像数量。默认值为 8。

padding (int, optional): 图像之间的填充大小。默认值为 2。

normalize (bool, optional): 如果为 True,则将所有图像缩放到 [0, 1] 范围内,并假设输入图像位于指定的 range 内。默认值为 False。

range (tuple, optional): 用于归一化的值范围(min, max)。只有当 normalize=True 时才有效。默认值为 (0, 1)。

scale_each (bool, optional): 如果为 True,则单独对每张图像进行缩放,而不是统一对整个 mini-batch 进行缩放。默认值为 False。

pad_value (float, optional): 用于填充的值。默认值为 0。

返回值

一个 3D Tensor,形状为 (C, H*(nrow + (nrow-1)padding), W(ncol + (ncol-1)*padding)),其中 C 是通道数,H 和 W 是原始图像的高度和宽度,nrow 和 ncol 是通过输入 Tensor 的批量大小自动计算的列数(ncol = ceil(B / nrow))。

八、 arr[..., ::-1]

在Python中,表达式arr[..., ::-1]通常用于NumPy数组或兼容NumPy索引的数组(如PyTorch张量)上,用于对数组进行切片操作。这个表达式的目的是沿着数组的最后一个维度(或轴)反转元素的顺序。

这里的...是NumPy引入的省略号(ellipsis),它用于表示在指定位置选择多个未明确指出的维度。在这个上下文中,...意味着"选择所有前面的维度",而::-1是一个切片操作,表示"从末尾开始到开头,步长为-1",即反转当前维度的元素顺序。

示例

假设我们有一个形状为(3, 4, 3)的三维NumPy数组,代表三个颜色通道(RGB)的图像批次,每个图像的大小为4x3像素。如果我们想要将每个图像的RGB通道顺序更改为BGR,我们可以使用arr[..., ::-1]来实现这一点。

python 复制代码
python
import numpy as np  
  
# 创建一个形状为(3, 4, 3)的随机数组,模拟RGB图像批次  
arr = np.random.randint(0, 256, size=(3, 4, 3), dtype=np.uint8)  
  
# 打印原始数组的形状和一部分内容  
print("Original shape:", arr.shape)  
print("Original array (first image):\n", arr[0])  
  
# 使用arr[..., ::-1]反转最后一个维度(颜色通道)  
arr_bgr = arr[..., ::-1]  
  
# 打印反转后数组的形状和一部分内容  
print("Modified shape:", arr_bgr.shape)  # 形状保持不变  
print("Modified array (first image, now in BGR):\n", arr_bgr[0])

在这个例子中,arr[..., ::-1]会保持数组的形状不变(因为我们只反转了最后一个维度),但是会改变最后一个维度的元素顺序,从而将RGB通道更改为BGR通道。

这种操作在处理图像数据时非常有用,因为不同的图像处理库和框架可能期望不同的颜色通道顺序。例如,OpenCV默认使用BGR顺序,而PIL(Python Imaging Library)和matplotlib则使用RGB顺序。

九、 yield

在Python中,yield 关键字用于从函数中返回一个生成器(generator)。生成器是一个可以记住上一次返回位置的对象,并在下一次迭代时从该位置继续执行。这使得它们非常适合用于需要逐个处理大量数据的场景,因为它们可以按需生成数据,从而节省内存。

当你调用一个包含 yield 的函数时,该函数不会立即执行其代码,而是返回一个迭代器(即生成器)。然后,你可以通过迭代这个生成器来逐步执行函数中的代码。每次迭代时,yield 语句会"暂停"函数的执行,并返回紧随其后的值给迭代器的调用者。当迭代器再次请求下一个值时,函数会从上次暂停的位置继续执行,直到遇到下一个 yield 语句或函数结束。

示例

下面是一个简单的使用 yield 的例子,该函数生成一个斐波那契数列(Fibonacci sequence):

python 复制代码
python
def fibonacci(n):  
    a, b = 0, 1  
    count = 0  
    while count < n:  
        yield a  
        a, b = b, a + b  
        count += 1  
for num in fibonacci(10):  
  print(num)

使用生成器

在这个例子中,fibonacci 函数是一个生成器函数,它使用 yield 来逐个返回斐波那契数列中的数。当我们使用 for 循环迭代 fibonacci(10) 时,函数会在每次迭代时执行到下一个 yield 语句,并返回当前的 a 值。当函数内部的状态(即 a 和 b 的值以及 count)被保存起来,并在下一次迭代时恢复,直到生成了 n 个数为止。

注意事项

使用 yield 的函数会返回一个生成器对象,而不是一次性返回所有值。

生成器只能迭代一次。一旦生成器迭代完成,它就不能再次从头开始迭代。

生成器非常适合用于实现迭代器协议,因为它们提供了惰性求值(lazy evaluation)的能力,即只有在需要时才计算值。

在生成器中,return 语句会立即停止迭代,但可以通过 return 语句返回一个值给迭代器的 StopIteration 异常(在Python 3.3及以后版本中,如果生成器因为 return 语句而终止,则 return 语句后的值(如果有的话)会被用作 StopIteration 异常的 value 属性)。如果生成器中没有 return 语句,或者 return 语句没有值,则迭代会在自然结束时停止。

相关推荐
小毕超10 分钟前
基于 PyTorch 从零手搓一个GPT Transformer 对话大模型
pytorch·gpt·transformer
denghai邓海18 分钟前
红黑树删除之向上调整
python·b+树
千天夜37 分钟前
激活函数解析:神经网络背后的“驱动力”
人工智能·深度学习·神经网络
封步宇AIGC43 分钟前
量化交易系统开发-实时行情自动化交易-3.4.1.2.A股交易数据
人工智能·python·机器学习·数据挖掘
何曾参静谧44 分钟前
「Py」Python基础篇 之 Python都可以做哪些自动化?
开发语言·python·自动化
m0_523674211 小时前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路
人工智能·深度学习·目标检测·机器学习·语言模型·自然语言处理·数据挖掘
Prejudices1 小时前
C++如何调用Python脚本
开发语言·c++·python
我狠狠地刷刷刷刷刷1 小时前
中文分词模拟器
开发语言·python·算法
Jam-Young1 小时前
Python的装饰器
开发语言·python
Mr.咕咕1 小时前
Django 搭建数据管理web——商品管理
前端·python·django