pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
m0_678485452 分钟前
CSS如何控制表格单元格边框合并_通过border-collapse实现
jvm·数据库·python
m0_748839495 分钟前
如何用组合继承模式实现父类方法复用与子类属性独立
jvm·数据库·python
qq_3345635514 分钟前
PHP源码是否依赖特定芯片组_Intel与AMD平台差异【操作】
jvm·数据库·python
qq_206901391 小时前
如何使用C#调用Oracle存储过程_OracleCommand配置CommandType.StoredProcedure
jvm·数据库·python
m0_748839491 小时前
CSS如何实现元素平滑滚动_使用scroll-behavior属性设置
jvm·数据库·python
橙露1 小时前
特征选择实战:方差、卡方、互信息法筛选有效特征
人工智能·深度学习·机器学习
Victoria.a1 小时前
python基础语法
开发语言·python
xiaotao1312 小时前
01-编程基础与数学基石: Python核心数据结构完全指南
数据结构·人工智能·windows·python
青苔猿猿2 小时前
【1】JupyterLab安装
python·jupyter
xiaoyaohou112 小时前
023、数据增强改进(二):自适应数据增强与AutoAugment策略
开发语言·python