pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
ljuncong5 分钟前
python的装饰器怎么使用
开发语言·python
A7bert7776 分钟前
【YOLOv5seg部署RK3588】模型训练→转换RKNN→开发板部署
linux·c++·人工智能·深度学习·yolo·目标检测
该用户已不存在18 分钟前
没有这7款工具,难怪你的Python这么慢
后端·python
serve the people23 分钟前
tensorflow 零基础吃透:RaggedTensor 的不规则形状与广播机制 2
人工智能·python·tensorflow
donkey_199323 分钟前
ShiftwiseConv: Small Convolutional Kernel with Large Kernel Effect
人工智能·深度学习·目标检测·计算机视觉·语义分割·实例分割
Hello.Reader24 分钟前
Flink ML 基本概念Table API、Stage、Pipeline 与 Graph
大数据·python·flink
chen_note26 分钟前
Python面向对象、并发编程、网络编程
开发语言·python·网络编程·面向对象·并发编程
信看28 分钟前
树莓派CAN(FD) 测试&&RS232 RS485 CAN Board 测试
开发语言·python
brent42328 分钟前
DAY24推断聚类后簇的类型
python
测试199832 分钟前
一个只能通过压测发现Bug
自动化测试·软件测试·python·selenium·测试工具·bug·压力测试