pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
肾透侧视攻城狮几秒前
《工业级实战:TensorFlow房价预测模型开发、优化与问题排查指南》
人工智能·深度学习·tensorfl波士顿房价预测·调整网络结构·使用k折交叉验证·添加正则化防止过拟合·tensorflow之回归问题
喵手7 分钟前
Python爬虫实战:自动化构建 arXiv 本地知识库 - 从 PDF 下载到元数据索引!
爬虫·python·自动化·arxiv·本地知识库·pdf下载·元数据索引
百锦再7 分钟前
Java InputStream和OutputStream实现类完全指南
java·开发语言·spring boot·python·struts·spring cloud·kafka
闲人编程9 分钟前
Celery分布式任务队列
redis·分布式·python·celery·任务队列·异步化
是小蟹呀^10 分钟前
【论文阅读15】告别死板!ElasticFace 如何用“弹性边缘”提升人脸识别性能
论文阅读·深度学习·分类·elasticface
deephub13 分钟前
深入RAG架构:分块策略、混合检索与重排序的工程实现
人工智能·python·大语言模型·rag
danyang_Q19 分钟前
vscode python-u问题
开发语言·vscode·python
忘忧记24 分钟前
python QT sqlsite版本 图书管理系统
开发语言·python·qt
长安牧笛26 分钟前
车载模型白天晚上自动切换,自动切昼夜模型,颠覆统一模型,输出稳定识别。
python·编程语言
人工智能研究所26 分钟前
从 0 开始学习人工智能——什么是推理模型?
人工智能·深度学习·学习·机器学习·语言模型·自然语言处理