pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
一点晖光29 分钟前
Docker 作图咒语生成器搭建指南
python·docker
smj2302_796826521 小时前
解决leetcode第3768题.固定长度子数组中的最小逆序对数目
python·算法·leetcode
木头左1 小时前
位置编码增强法在量化交易策略中的应用基于短期记忆敏感度提升
python
Acc1oFl4g1 小时前
详解Java反射
java·开发语言·python
Mr.Lee jack1 小时前
【torch.compile】TorchDynamo动态图编译
pytorch
F_D_Z1 小时前
简明 | Yolo-v3结构理解摘要
深度学习·神经网络·yolo·计算机视觉·resnet
java1234_小锋2 小时前
Transformer 大语言模型(LLM)基石 - Transformer架构详解 - 自注意力机制(Self-Attention)原理介绍
深度学习·语言模型·transformer
ney187819024742 小时前
分类网络LeNet + FashionMNIST 准确率92.9%
python·深度学习·分类
Coding茶水间2 小时前
基于深度学习的无人机视角检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·计算机视觉
Data_agent3 小时前
1688获得1688店铺列表API,python请求示例
开发语言·python·算法