pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
Hui_AI7207 小时前
基于RAG的农产品GEO溯源智能问答系统实现
开发语言·网络·人工智能·python·算法·创业创新
不知名的老吴7 小时前
后端知识点:Python处理加权点赞
开发语言·python
忡黑梨7 小时前
eNSP_从直连到BGP全网互通
c语言·网络·数据结构·python·算法·网络安全
Cyber4K7 小时前
【Python专项】基础语法(2)
开发语言·python
用AI赚一点7 小时前
AI落地不是造大模型:从概念到落地的核心差异
人工智能·深度学习·机器学习
2601_956139428 小时前
文旅行业品牌全案公司哪家强
大数据·人工智能·python
小超同学你好8 小时前
Transformer 30. MoCo:用「动量编码器 + 队列字典」把对比学习做成可扩展的“字典查找”
深度学习·学习·transformer
hrhcode8 小时前
【LangGraph】二.State 和 Node 的设计细节
python·ai·langchain·langgraph·ai框架
dfdfadffa8 小时前
如何创建仅在首次订阅时执行一次计算的 RxJS 懒加载 Observable
jvm·数据库·python
m0_624578598 小时前
SQL分组后如何计算移动平均值_利用窗口函数AVG配合ROWS
jvm·数据库·python