pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
superman超哥12 分钟前
仓颉语言中基本数据类型的深度剖析与工程实践
c语言·开发语言·python·算法·仓颉
Learner__Q42 分钟前
每天五分钟:滑动窗口-LeetCode高频题解析_day3
python·算法·leetcode
————A1 小时前
强化学习----->轨迹、回报、折扣因子和回合
人工智能·python
徐先生 @_@|||1 小时前
(Wheel 格式) Python 的标准分发格式的生成规则规范
开发语言·python
weixin_409383122 小时前
在kaggle训练Qwen/Qwen2.5-1.5B-Instruct 通过中二时期qq空间记录作为训练数据 训练出中二的模型为目标 第一次训练 好像太二了
人工智能·深度学习·机器学习·qwen
Mqh1807622 小时前
day45 简单CNN
python
学习者0072 小时前
python 下载离线库方法
python
声声codeGrandMaster2 小时前
AI之模型提升
人工智能·pytorch·python·算法·ai
魔镜前的帅比3 小时前
多 Agent 架构:Coordinator + Worker 模式
python·ai
路长冬3 小时前
深度学习评估指标:
深度学习