pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
用户12039112947261 分钟前
AIGC 时代,数据库终于可以“听懂人话”了:从零打造自然语言操作 SQLite 的完整实战
python·sqlite·aigc
Q_Q5110082852 分钟前
python+django/flask+vue农业电商服务系统
spring boot·python·pycharm·django·flask
帕巴啦3 分钟前
Python计算累积频率——Origin绘制累积频率图
python·绘图·origin·累积频率·python计算累积频率·origin绘制累积频率图
Q_Q51100828512 分钟前
python+django/flask+vue的基于疫情防控管理系统的数据可视化分析系统
spring boot·python·django·flask·node.js
生信大表哥34 分钟前
Claude Code / Gemini CLI / Codex CLI 安装大全(Linux 服务器版)
linux·python·ai·r语言·数信院生信服务器
databook38 分钟前
用样本猜总体的秘密武器,4大抽样分布总结
后端·python·数据分析
Jacob程序员1 小时前
欧几里得距离算法-相似度
开发语言·python·算法
a man of sadness1 小时前
GPS轨迹抽稀:降频、滑动窗口、RDP
python·gps·轨迹·抽稀·rdp算法
网安老伯1 小时前
什么是网络安全?网络安全包括哪几个方面?学完能做一名黑客吗?
linux·数据库·python·web安全·网络安全·php·xss
天才测试猿1 小时前
Postman接口测试:如何导入swagger接口文档?
自动化测试·软件测试·python·测试工具·职场和发展·接口测试·postman