tensor向量按任意维度进行切片、拆分、组合

torch.index_select(input_tensor, 切片维度, 切片索引)

注意:切完之后,转onnx时会生成Gather节点;

torch自带切片操作

start : end : step:

范围前闭后开,将其放在哪个维度上,就对那个维度起作用

torch.cat((a, b) , dim)

已有的轴上拼接 矩阵,默认轴为0,给定轴的维度可以不同,其余轴的维度必须相同

三个操作的组合使用例子如下:

python 复制代码
import torch

x = torch.randn(1, 18, 4, 4)

# print("x:",x)
print("x.shape:",x.shape)

indices_cls = torch.tensor([2, 5, 8, 11, 14, 17])
indices_point = torch.tensor([0,1, 3,4, 6,7, 9,10, 12,13, 15,16])

kpt_point = torch.index_select(x, 1, indices_point)
kpt_cls = torch.index_select(x, 1, indices_cls)

print("kpt_point.shape:",kpt_point.shape)
print("kpt_cls.shape:",kpt_cls.shape)


x_2 = torch.cat([kpt_point[:,0:2:1,],kpt_cls[:,0:1:1,],kpt_point[:,2:4:1,],kpt_cls[:,1:2:1,],kpt_point[:,4:6:1,],kpt_cls[:,2:3:1,],
            kpt_point[:,6:8:1,],kpt_cls[:,3:4:1,],kpt_point[:,8:10:1,],kpt_cls[:,4:5:1,],kpt_point[:,10:12:1,],kpt_cls[:,5:6:1,]],1)

# print("x_2:",x_2)
print("x_2.shape:",x_2.shape)

打印组合前后tensor的输出形状和内容发现,前后一致:

python 复制代码
x.shape: torch.Size([1, 18, 4, 4])
kpt_point.shape: torch.Size([1, 12, 4, 4])
kpt_cls.shape: torch.Size([1, 6, 4, 4])   
x_2.shape: torch.Size([1, 18, 4, 4])   
相关推荐
databook3 小时前
Manim实现闪光轨迹特效
后端·python·动效
Juchecar4 小时前
解惑:NumPy 中 ndarray.ndim 到底是什么?
python
用户8356290780514 小时前
Python 删除 Excel 工作表中的空白行列
后端·python
Json_4 小时前
使用python-fastApi框架开发一个学校宿舍管理系统-前后端分离项目
后端·python·fastapi
数据智能老司机11 小时前
精通 Python 设计模式——分布式系统模式
python·设计模式·架构
数据智能老司机12 小时前
精通 Python 设计模式——并发与异步模式
python·设计模式·编程语言
数据智能老司机12 小时前
精通 Python 设计模式——测试模式
python·设计模式·架构
数据智能老司机12 小时前
精通 Python 设计模式——性能模式
python·设计模式·架构
c8i12 小时前
drf初步梳理
python·django
每日AI新事件12 小时前
python的异步函数
python