pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
大、男人3 分钟前
python之contextmanager
android·python·adb
毕设源码-钟学长26 分钟前
【开题答辩全过程】以 基于Python的车辆管理系统为例,包含答辩的问题和答案
开发语言·python
koo36427 分钟前
pytorch深度学习笔记12
pytorch·笔记·深度学习
CCPC不拿奖不改名1 小时前
数据处理与分析:数据可视化的面试习题
开发语言·python·信息可视化·面试·职场和发展
液态不合群1 小时前
线程池和高并发
开发语言·python
旦莫1 小时前
Pytest教程:Pytest与主流测试框架对比
人工智能·python·pytest
数据大魔方1 小时前
【期货量化实战】螺纹钢量化交易指南:品种特性与策略实战(TqSdk完整方案)
python·算法·github·程序员创富·期货程序化·期货量化·交易策略实战
旻璿gg2 小时前
paddleocr、paddleocrvl、ppocrv5
python
清水白石0082 小时前
手写超速 CSV 解析器:利用 multiprocessing 与 mmap 实现 10 倍 Pandas 加速
python·pandas
Corleo2 小时前
记录一次复杂的 ONNX 到 TensorRT 动态 Shape 转换排错过程
python·ai