pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size(2, 20, 3)

print(tensor_4d_back.shape) # 输出:torch.Size(2, 3, 4, 5)

相关推荐
月疯几秒前
PyTorch 中定义了一个 LeakyReLU 激活函数层
人工智能·pytorch·python
小白学大数据6 分钟前
AI 智能爬虫实战:Selenium+Python 自动绕反爬、一键提取数据
爬虫·python·selenium·数据分析
DreamLife☼7 分钟前
OpenBCI-实战二:脑波控制小游戏开发
python·pygame·openbci·cyton·ganglion
smj2302_796826528 分钟前
解决leetcode第3948题字典序最大的MEX数组
python·算法·leetcode
程序大视界23 分钟前
【Python系列课程】Pandas(六):数据读写——CSV与Excel文件操作
python·excel·pandas
lqqjuly40 分钟前
推荐系统技术解析(Recommendation Systems)
深度学习·推荐算法
weixin_407443871 小时前
OCR材料信息提取工具(附件中含代码和数据)
人工智能·python·计算机视觉·ocr
码农阿强1 小时前
PixVerse 全系列视频生成模型技术架构详解 + Python 基于 StartAPI.top 接口实战调用
python·ai·架构·音视频·ai编程
Smilecoc1 小时前
风控评分卡模型原理与应用(四):WOE编码的单调性
python
老鱼说AI1 小时前
统计学习方法第八章:Boosting
人工智能·深度学习·神经网络·机器学习·学习方法·集成学习·boosting