torch常用函数

目录

一、 torch.bmm

torch.bmm 是 PyTorch 中的一个函数,用于执行批矩阵乘法( matrix multiplication)操作。

它的输入是三维张量,形状为 (batch, n, m) 和 (batch, m, p):

其中 n 是第一个矩阵的列数,m 是两个矩阵共享的维度,p 是第二个矩阵的列数。

torch.bmm 将批中的每对矩阵相乘,返回一个新的三维张量,形状为 (batch, n, p)。

二、 torch.einsum

torch.einsum是pytorch上的一个强大的函数,用于矩阵相关的计算,注意,这里没有限定为矩阵乘法。torch.einsum基于爱因斯坦求和约定执行张量操作,能够用简洁的表达式实现复杂的多维数组操作,从而避免繁琐的张量操作组合(如reshape、permute、bmm等),减少错误率。需要说明的是,尽管einsum函数内部进行了大量计算优化,但其主要优势在于表达式简洁,如果与单步reshape等pytorch实现的矩阵运算操作相比,其运算速度与内存占用不一定占优势。

1.矩阵乘法:'ij,jk->ik' 表示形状为(i,j)与形状为(j,k)的矩阵进行矩阵乘法,得到新矩阵形状为(i,k)。这也是torch.einsum最常规的用法。

2.维度调换:'ij->ji'表示形状为(i,j)的矩阵维度调换成为形状为(j,i)的矩阵。

torch.einsum还有多种用法,遇到再来添加

三、python中变量前面有个*

在Python中,变量前面的星号(*)有多种用法,主要与函数参数或解包序列有关。

1、在函数参数中,星号(*)用来表示任意多个参数,这些参数会被当作元组传递。例如:

python 复制代码
def fun(*args):
    for i in args:
        print(i)
 
fun(1, 2, 3, 4)

2、在函数参数中,星号(*)还可以用来解包序列。例如:

python 复制代码
def fun(a, b, c, d):
    print(a, b, c, d)
 
args = (1, 2, 3, 4)
fun(*args)

3、在函数参数中,星号(*)还可以与命名参数,或者字典一起使用。例如:

python 复制代码
def fun(*args, a=1):
    print(args, a)
 
fun(1, 2, 3, a=4)

def fun(*args, **kwargs):
    print(args, kwargs)
 
fun(1, 2, 3, a=4, b=5)

4、 在解包列表或元组时,星号(*)也可以用来解包选定项。例如:

python 复制代码
lst = [1, 2, 3, 4, 5]
a, *b, c = lst
print(a, b, c)

四、numpy.prod

计算元素和

python 复制代码
print(np.prod([[1., 2.], [3., 4.]], axis=0))按列计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=1))按行计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=0))计算所有元素和

五、torch.chunk

对于一个输入tensor,torch.chunk方法会按照dim指定的维度将输入tensor划分为若干个chunk,划分的数量为chunks。

python 复制代码
torch.chunk(input, chunks, dim=0) 

temp=torch.randn((4,6))
print(torch.chunk(temp,2,0))行方向分块
print(torch.chunk(temp,2,1))列方向分块

六、torch.contiguous

在PyTorch中,.contiguous()方法的作用是确保张量在内存中是连续存储的。当你对张量执行某些操作,如transpose()、permute()、narrow()、expand()等之后,得到的张量可能不再在内存中连续排列。这些操作通常返回一个张量的视图,它们改变的是数据访问的方式,而不是实际的数据存储方式。

在内存中连续排列的张量有一个特性:对于张量中任意两个相邻的元素,它们在物理内存中的位置也是相邻的。换句话说,张量在物理存储上的排列顺序与在张量形式上的逻辑排列顺序一致。

当调用.contiguous()时,如果张量已经是连续的,这个函数实际上不会做任何事;但如果不是,PyTorch将会重新分配内存并确保张量的数据连续排列。这涉及到复制数据到新的内存区域,并返回一个新的张量,该张量在内存中实际是连续的。

调用view之前最好先contiguous,也就是x.contiguous().view()

python 复制代码
import torch

# 创建一个非连续张量
x = torch.arange(12).view(3, 4).permute(1,0)  # 移动维度
print(x.is_contiguous())  # False

# 使用 .contiguous() 来确保张量是连续的
y = x.contiguous()
print(y.is_contiguous())  # True

六、torch.clamp

在PyTorch中,clamp函数是一个非常实用的操作,它允许你将张量(Tensor)中的元素值限制在一个指定的范围内。这个函数特别有用,比如在图像处理中调整像素值、在神经网络中防止梯度爆炸或消失时限制激活函数的输出等场景。

clamp函数的基本用法如下:

python 复制代码
torch.clamp(input, min, max) → Tensor

input:输入的张量。

min:元素值的下限。所有小于min的元素都将被设置为min。

max:元素值的上限。所有大于max的元素都将被设置为max。

如果min或max是None,则相应的边界将不被限制。例如,如果只指定了min而没有指定max,则所有小于min的元素会被设置为min,而大于min的元素则保持不变。

python 复制代码
import torch

# 创建一个张量
x = torch.tensor([-5.0, -2.0, 0.0, 3.0, 5.0, 7.0, 10.0])

# 使用clamp函数限制值在0到5之间
y = x.clamp( 0, 5)

print(y)
# 输出: tensor([0., 0., 0., 3., 5., 5., 5.])

七、torchvision.utils.make_grid

torchvision.utils.make_grid 是 PyTorch 中的一个非常有用的函数,它可以将多个图像(通常是 tensor 格式)拼接成一个网格(grid)图像。这个函数在展示或保存多张图像时特别有用,比如在训练过程中可视化模型输出的多个样本。

python 复制代码
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

参数解释

tensor (Tensor or list of Tensors): 一个 4D mini-batch Tensor,形状为 (B, C, H, W),或者是一个包含这样的 Tensor 的列表。Tensor 中的图像应该位于 [0, 1] 范围内,如果是其他范围,则需要通过 normalize 参数进行归一化。

nrow (int, optional): 网格中每行的图像数量。默认值为 8。

padding (int, optional): 图像之间的填充大小。默认值为 2。

normalize (bool, optional): 如果为 True,则将所有图像缩放到 [0, 1] 范围内,并假设输入图像位于指定的 range 内。默认值为 False。

range (tuple, optional): 用于归一化的值范围(min, max)。只有当 normalize=True 时才有效。默认值为 (0, 1)。

scale_each (bool, optional): 如果为 True,则单独对每张图像进行缩放,而不是统一对整个 mini-batch 进行缩放。默认值为 False。

pad_value (float, optional): 用于填充的值。默认值为 0。

返回值

一个 3D Tensor,形状为 (C, H*(nrow + (nrow-1)padding), W(ncol + (ncol-1)*padding)),其中 C 是通道数,H 和 W 是原始图像的高度和宽度,nrow 和 ncol 是通过输入 Tensor 的批量大小自动计算的列数(ncol = ceil(B / nrow))。

八、 arr[..., ::-1]

在Python中,表达式arr[..., ::-1]通常用于NumPy数组或兼容NumPy索引的数组(如PyTorch张量)上,用于对数组进行切片操作。这个表达式的目的是沿着数组的最后一个维度(或轴)反转元素的顺序。

这里的...是NumPy引入的省略号(ellipsis),它用于表示在指定位置选择多个未明确指出的维度。在这个上下文中,...意味着"选择所有前面的维度",而::-1是一个切片操作,表示"从末尾开始到开头,步长为-1",即反转当前维度的元素顺序。

示例

假设我们有一个形状为(3, 4, 3)的三维NumPy数组,代表三个颜色通道(RGB)的图像批次,每个图像的大小为4x3像素。如果我们想要将每个图像的RGB通道顺序更改为BGR,我们可以使用arr[..., ::-1]来实现这一点。

python 复制代码
python
import numpy as np  
  
# 创建一个形状为(3, 4, 3)的随机数组,模拟RGB图像批次  
arr = np.random.randint(0, 256, size=(3, 4, 3), dtype=np.uint8)  
  
# 打印原始数组的形状和一部分内容  
print("Original shape:", arr.shape)  
print("Original array (first image):\n", arr[0])  
  
# 使用arr[..., ::-1]反转最后一个维度(颜色通道)  
arr_bgr = arr[..., ::-1]  
  
# 打印反转后数组的形状和一部分内容  
print("Modified shape:", arr_bgr.shape)  # 形状保持不变  
print("Modified array (first image, now in BGR):\n", arr_bgr[0])

在这个例子中,arr[..., ::-1]会保持数组的形状不变(因为我们只反转了最后一个维度),但是会改变最后一个维度的元素顺序,从而将RGB通道更改为BGR通道。

这种操作在处理图像数据时非常有用,因为不同的图像处理库和框架可能期望不同的颜色通道顺序。例如,OpenCV默认使用BGR顺序,而PIL(Python Imaging Library)和matplotlib则使用RGB顺序。

九、 yield

在Python中,yield 关键字用于从函数中返回一个生成器(generator)。生成器是一个可以记住上一次返回位置的对象,并在下一次迭代时从该位置继续执行。这使得它们非常适合用于需要逐个处理大量数据的场景,因为它们可以按需生成数据,从而节省内存。

当你调用一个包含 yield 的函数时,该函数不会立即执行其代码,而是返回一个迭代器(即生成器)。然后,你可以通过迭代这个生成器来逐步执行函数中的代码。每次迭代时,yield 语句会"暂停"函数的执行,并返回紧随其后的值给迭代器的调用者。当迭代器再次请求下一个值时,函数会从上次暂停的位置继续执行,直到遇到下一个 yield 语句或函数结束。

示例

下面是一个简单的使用 yield 的例子,该函数生成一个斐波那契数列(Fibonacci sequence):

python 复制代码
python
def fibonacci(n):  
    a, b = 0, 1  
    count = 0  
    while count < n:  
        yield a  
        a, b = b, a + b  
        count += 1  
for num in fibonacci(10):  
  print(num)

使用生成器

在这个例子中,fibonacci 函数是一个生成器函数,它使用 yield 来逐个返回斐波那契数列中的数。当我们使用 for 循环迭代 fibonacci(10) 时,函数会在每次迭代时执行到下一个 yield 语句,并返回当前的 a 值。当函数内部的状态(即 a 和 b 的值以及 count)被保存起来,并在下一次迭代时恢复,直到生成了 n 个数为止。

注意事项

使用 yield 的函数会返回一个生成器对象,而不是一次性返回所有值。

生成器只能迭代一次。一旦生成器迭代完成,它就不能再次从头开始迭代。

生成器非常适合用于实现迭代器协议,因为它们提供了惰性求值(lazy evaluation)的能力,即只有在需要时才计算值。

在生成器中,return 语句会立即停止迭代,但可以通过 return 语句返回一个值给迭代器的 StopIteration 异常(在Python 3.3及以后版本中,如果生成器因为 return 语句而终止,则 return 语句后的值(如果有的话)会被用作 StopIteration 异常的 value 属性)。如果生成器中没有 return 语句,或者 return 语句没有值,则迭代会在自然结束时停止。

相关推荐
HPC_fac1305206781623 分钟前
科研深度学习:如何精选GPU以优化服务器性能
服务器·人工智能·深度学习·神经网络·机器学习·数据挖掘·gpu算力
猎嘤一号1 小时前
个人笔记本安装CUDA并配合Pytorch使用NVIDIA GPU训练神经网络的计算以及CPUvsGPU计算时间的测试代码
人工智能·pytorch·神经网络
湫ccc7 小时前
《Python基础》之字符串格式化输出
开发语言·python
mqiqe7 小时前
Python MySQL通过Binlog 获取变更记录 恢复数据
开发语言·python·mysql
AttackingLin7 小时前
2024强网杯--babyheap house of apple2解法
linux·开发语言·python
哭泣的眼泪4088 小时前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
湫ccc8 小时前
《Python基础》之基本数据类型
开发语言·python
余炜yw9 小时前
【LSTM实战】跨越千年,赋诗成文:用LSTM重现唐诗的韵律与情感
人工智能·rnn·深度学习
drebander9 小时前
使用 Java Stream 优雅实现List 转化为Map<key,Map<key,value>>
java·python·list
莫叫石榴姐9 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘