pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
加强洁西卡7 小时前
【框架】Pytorch和vLLMnull
深度学习
ting94520007 小时前
动手学深度学习(PyTorch版)深度详解(1)(含实操+避坑)
pytorch·深度学习·学习
牛大兵7 小时前
播放网络摄像头视频支持ONVIF/RTSP
网络·python·音视频
nervermore9907 小时前
3. 人工智能学习-PyTorch框架学习
人工智能·pytorch·学习
m0_495496417 小时前
SQL中如何获取前N个最大值并排除自己_利用窗口函数限制
jvm·数据库·python
m0_740653227 小时前
mysql如何提取日期中的年份_使用year函数从日期中截取
jvm·数据库·python
运气好好的7 小时前
mysql数据库日志文件过大如何清理_定期备份与重置日志文件
jvm·数据库·python
ATMQuant7 小时前
量化策略开发01:我让AI全权做交易决策 - 从提示词设计到决策执行
python·量化交易·vnpy·ai策略
站大爷IP7 小时前
如何在 Python 中使用 colorama 库来给输出添加颜色
python