pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
肾透侧视攻城狮2 小时前
《从fit()到分布式训练:深度解锁TensorFlow模型训练全栈技能》
人工智能·深度学习·tensorflow 模型训练·模型训练中的fit方法·自定义训练循环·回调函数使用·混合精度/分布式训练
索木木2 小时前
大模型训练CP切分(与TP、SP结合)
人工智能·深度学习·机器学习·大模型·训练·cp·切分
喵手2 小时前
Python爬虫实战:增量爬虫实战 - 利用 HTTP 缓存机制实现“极致减负”(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·增量爬虫·http缓存机制·极致减负
一个处女座的程序猿O(∩_∩)O3 小时前
Python异常处理完全指南:KeyError、TypeError、ValueError深度解析
开发语言·python
was1723 小时前
使用 Python 脚本一键上传图片到兰空图床并自动复制链接
python·api上传·自建图床·一键脚本
好学且牛逼的马3 小时前
从“Oak”到“虚拟线程”:JDK 1.0到25演进全记录与核心知识点详解a
java·开发语言·python
shangjian0073 小时前
Python基础-环境安装-Anaconda配置虚拟环境
开发语言·python
codeJinger3 小时前
【Python】函数
开发语言·python
量子-Alex3 小时前
【大模型思维链】COT、COT-SC、TOT和RAP四篇经典工作对比分析
人工智能·深度学习·机器学习
geovindu4 小时前
python: Command Pattern
开发语言·python·命令模式