pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
Jay_Franklin10 分钟前
SRIM通过python计算dap
开发语言·python
是一个Bug19 分钟前
Java基础50道经典面试题(四)
java·windows·python
吴佳浩33 分钟前
Python入门指南(七) - YOLO检测API进阶实战
人工智能·后端·python
liliangcsdn1 小时前
python下载并转存http文件链接的示例
开发语言·python
大、男人2 小时前
python之Starlette
python·uvicorn
Coding茶水间2 小时前
基于深度学习的水面垃圾检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
小智RE0-走在路上4 小时前
Python学习笔记(11) --数据可视化
笔记·python·学习
历程里程碑4 小时前
hot 206
java·开发语言·数据结构·c++·python·算法·排序算法
Coder_Boy_4 小时前
Java+Proteus仿真Arduino控制LED问题排查全记录(含交互过程)
java·人工智能·python
qq_356196954 小时前
day47_预训练模型与迁移学习@浙大疏锦行
python