pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
夜思红尘3 小时前
算法--双指针
python·算法·剪枝
人工智能训练3 小时前
OpenEnler等Linux系统中安装git工具的方法
linux·运维·服务器·git·vscode·python·ubuntu
Tipriest_3 小时前
torch训练出的模型的组成以及模型训练后的使用和分析办法
人工智能·深度学习·torch·utils
QuiteCoder3 小时前
深度学习的范式演进、架构前沿与通用人工智能之路
人工智能·深度学习
智航GIS4 小时前
8.2 面向对象
开发语言·python
weixin_468466854 小时前
YOLOv13结合代码原理详细解析及模型安装与使用
人工智能·深度学习·yolo·计算机视觉·图像识别·目标识别·yolov13
蹦蹦跳跳真可爱5895 小时前
Python----大模型(GPT-2模型训练加速,训练策略)
人工智能·pytorch·python·gpt·embedding
xwill*5 小时前
π∗0.6: a VLA That Learns From Experience
人工智能·pytorch·python
还不秃顶的计科生5 小时前
LeetCode 热题 100第二题:字母易位词分组python版本
linux·python·leetcode
weixin_462446235 小时前
exo + tinygrad:Linux 节点设备能力自动探测(NVIDIA / AMD / CPU 安全兜底)
linux·运维·python·安全