-
torch.squeeze(input, dim=None)将给定的
input这个 tensor 中,大小为 1 的 dim 全部压缩。如下例子:
pyimport torch t = torch.tensor([[1], [2], [3]]) print(t) # tensor([[1], [2], [3]]) shape=(3,1) t = torch.squeeze(t) print(t) # tensor([1, 2, 3]) shape=(3,) -
torch.unsqueeze(input, dim)将给定的
input这个 tensor 中,指定的 dim 扩充一维如下例子:
pyimport torch t = torch.tensor([1, 2, 3]) print(torch.unsqueeze(t, 0)) # tensor([[1, 2, 3]]) shape=(1,3) print(torch.unsqueeze(t, 1)) # tensor([[1], [2], [3]]) shape=(3,1) -
torch.index_select(input, dim, index, *, out=None)在给定的
input这个 tensor 中,选择维度 dim ,然后在这个维度中选择索引 index 的部分返回。如下例子:
pyimport torch t = torch.arange(1, 13).view(3, 4) print(t) """ tensor([[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12]]) shape=(3, 4) """ indices = torch.tensor([0, 2]) print(torch.index_select(t, 0, indices)) """ tensor([[ 1, 2, 3, 4], [ 9, 10, 11, 12]]) 选择 dim=0 ,有 (0, 1, 2) 三个,选择第 0 行和第 2 行 """ print(torch.index_select(t, 1, indices)) """ tensor([[ 1, 3], [ 5, 7], [ 9, 11]]) 选择 dim=1 ,有 (0, 1, 2, 3) 四个,选择第 0 列和第 4 列 """ -
torch.norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)将给定的
input这个 tensor 按照给定的 dim 计算范数,具体计算的是什么范数由 p 决定p=1 表示第一范数,即 tensor 中每个元素绝对值之和
p=2 表示第二范数,即 tensor 中每个元素平方和的和,再开根号
其他表示无穷范数,即 tensor 中绝对值最大的元素
pyimport torch """ inputs.shape = (3, 3, 4) """ inputs = torch.tensor([[[ 1., 2., 3., 4.], [ 2., 4., 6., 8.], [ 3., 6., 9., 12.]], [[ 1., 2., 3., 4.], [ 2., 4., 6., 8.], [ 3., 6., 9., 12.]], [[ 1., 2., 3., 4.], [ 2., 4., 6., 8.], [ 3., 6., 9., 12.]]]) """ inputs1.shape = (1, 3, 4) 对于 dim=0 进行 L2 范数的计算,就是考虑将 (i, j, k) 其中所有的 i 的元素平方和加起来再开根号 这里 sqrt((0, 0, 0)^2 + (1, 0, 0)^2 + (2, 0, 0)^2) = sqrt(3) = 1.7321 tensor([[[ 1.7321, 3.4641, 5.1962, 6.9282], [ 3.4641, 6.9282, 10.3923, 13.8564], [ 5.1962, 10.3923, 15.5885, 20.7846]]]) """ inputs1 = torch.norm(inputs, p=2, dim=0, keepdim=True) print(inputs1) """ inputs2.shape = (3, 1, 4) 对于 dim=1 进行 L2 范数的计算,就是考虑将 (i, j, k) 其中所有的 j 的元素平方和加起来再开根号 这里 sqrt((0, 0, 0)^2 + (0, 1, 0)^2 + (0, 2, 0)^2) = sqrt(1+4+9) = 3.7417 tensor([[[ 3.7417, 7.4833, 11.2250, 14.9666]], [[ 3.7417, 7.4833, 11.2250, 14.9666]], [[ 3.7417, 7.4833, 11.2250, 14.9666]]]) """ inputs2 = torch.norm(inputs, p=2, dim=1, keepdim=True) print(inputs2) """ inputs3.shape = (3, 3, 1) 对于 dim=2 进行 L2 范数的计算,就是考虑将 (i, j, k) 其中所有的 k 的元素平方和加起来再开根号 这里 sqrt((0, 0, 0)^2+(0, 0, 1)^2+(0, 0, 2)^2+(0, 0, 3)^2) = sqrt(1+4+9+16) = 5.4772 tensor([[[ 5.4772], [10.9545], [16.4317]], [[ 5.4772], [10.9545], [16.4317]], [[ 5.4772], [10.9545], [16.4317]]]) """ inputs3 = torch.norm(inputs, p=2, dim=2, keepdim=True) print(inputs3) -
torch.chunk(input, chunks, dim=0) → List of Tensors将
input这个tensor分成 chunks 个 tensors ,按照 dim 来划分。pyimport torch t = torch.arange(1, 28).view(3, 3, 3) """ tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[10, 11, 12], [13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]) shape = (3, 3, 3) """ print(t) """ 按照 dim=0 划分,那么划分结果为 (tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]), tensor([[[10, 11, 12], [13, 14, 15], [16, 17, 18]]]), tensor([[[19, 20, 21], [22, 23, 24], [25, 26, 27]]])) """ print(torch.chunk(t, chunks=3, dim=0)) """ 按照 dim=1 划分,那么划分结果为 (tensor([[[ 1, 2, 3]], [[10, 11, 12]], [[19, 20, 21]]]), tensor([[[ 4, 5, 6]], [[13, 14, 15]], [[22, 23, 24]]]), tensor([[[ 7, 8, 9]], [[16, 17, 18]], [[25, 26, 27]]])) """ print(torch.chunk(t, chunks=3, dim=1)) """ 按照 dim=2 划分,那么划分结果为 (tensor([[[ 1], [ 4], [ 7]], [[10], [13], [16]], [[19], [22], [25]]]), tensor([[[ 2], [ 5], [ 8]], [[11], [14], [17]], [[20], [23], [26]]]), tensor([[[ 3], [ 6], [ 9]], [[12], [15], [18]], [[21], [24], [27]]])) """ print(torch.chunk(t, chunks=3, dim=2))
Pytorch API
solego2023-10-08 23:11
相关推荐
nn在炼金10 分钟前
FlashAttention 1 深度解读:原理、价值、应用与实战沐雪轻挽萤10 分钟前
pytorch模型部署基础知识极客BIM工作室15 分钟前
从GAN到Sora:生成式AI在图像与视频领域的技术演进全景nix.gnehc16 分钟前
PyTorch数据加载与预处理skywalk816318 分钟前
用Trae的sole模式来模拟文心快码comate的Spec Mode模式来做一个esp32操作系统的项目*星星之火*24 分钟前
【大白话 AI 答疑】第5篇 从 “窄域专精” 到 “广谱通用”:传统机器学习与大模型的 6 大核心区别roman_日积跬步-终至千里24 分钟前
【模式识别与机器学习(7)】主要算法与技术(下篇:高级模型与集成方法)之 扩展线性模型(Extending Linear Models)张飞签名上架25 分钟前
苹果TF签名:革新应用分发的解决方案xcLeigh27 分钟前
AI 绘制图表专栏:用豆包轻松实现 HTML 柱状图、折线图与饼图玖日大大30 分钟前
LongCat-Flash-Omni:5600 亿参数开源全模态模型的技术革命与产业实践