pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
计算机毕设匠心工作室3 分钟前
【python大数据毕设实战】全国健康老龄化数据分析系统、Hadoop、计算机毕业设计、包括数据爬取、数据分析、数据可视化、机器学习
后端·python
Dxy123931021622 分钟前
Python的PIL对象crop函数详解
开发语言·python
一条破秋裤24 分钟前
零样本学习指标
深度学习·学习·机器学习
翔云 OCR API30 分钟前
护照NFC识读鉴伪接口集成-让身份核验更加智能与高效
开发语言·人工智能·python·计算机视觉·ocr
三好kiii36 分钟前
海康威视热成像摄像头温度矩阵提取实战:ISAPI + Python 实现无 SDK 读取
图像处理·python
logocode_li39 分钟前
面试 LoRA 被问懵?B 矩阵初始化为 0 的原因,大多数人拿目标来回答
人工智能·python·面试·职场和发展·矩阵
零日失眠者1 小时前
【网络工具系列】002:网站可用性监控脚本
python·代码规范
茶色岛^1 小时前
解析CLIP:从“看标签”到“读描述”
人工智能·深度学习·机器学习
MrSYJ1 小时前
pyenv管理多个版本的python,你造吗?我才造
python·llm·ai编程