pytorch(3d、4d张量转换)

维度转换

import torch

from einops import rearrange

print(torch.cuda.is_available())

to_3d

把四维的张量转换为三维的张量,输入形状(b,c,h,w),输出形状(b,hw,c)

def to_3d(x):

return rearrange(x, 'b c h w -> b (h w) c')

to_4d

把三维的张量转换为四维的张量,输入形状(b,hw,c),输出形状(b,c,h,w)

def to_4d(x,h,w):

return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

测试

if name == 'main':

创建一个四维张量

tensor_4d = torch.randn(2, 3, 4, 5) # 形状(批大小2, 通道数3, 高度4, 宽度5)

转换为三维张量

tensor_3d = to_3d(tensor_4d) # 形状(批大小2, 20, 3)

转换为四维张量

height, width = 4, 5

tensor_4d_back = to_4d(tensor_3d, height, width) # 形状(批大小2, 通道数3, 高度4, 宽度5)

print(tensor_4d.shape)

print(tensor_3d.shape) # 输出:torch.Size([2, 20, 3])

print(tensor_4d_back.shape) # 输出:torch.Size([2, 3, 4, 5])

相关推荐
Adorable老犀牛4 小时前
可遇不可求的自动化运维工具 | 2 | 实施阶段一:基础准备
运维·git·vscode·python·node.js·自动化
盼小辉丶4 小时前
Transformer实战(17)——微调Transformer语言模型进行多标签文本分类
深度学习·分类·transformer
xchenhao5 小时前
SciKit-Learn 全面分析 digits 手写数据集
python·机器学习·分类·数据集·scikit-learn·svm·手写
胡耀超5 小时前
7、Matplotlib、Seaborn、Plotly数据可视化与探索性分析(探索性数据分析(EDA)方法论)
python·信息可视化·plotly·数据挖掘·数据分析·matplotlib·seaborn
tangweiguo030519875 小时前
Django REST Framework 构建安卓应用后端API:从开发到部署的完整实战指南
服务器·后端·python·django
Dfreedom.5 小时前
在Windows上搭建GPU版本PyTorch运行环境的详细步骤
c++·人工智能·pytorch·python·深度学习
confiself6 小时前
AndroidWorld+mobileRL
人工智能·深度学习
兴科Sinco6 小时前
[leetcode 1]给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出和为目标值 target 的那两个整数[力扣]
python·算法·leetcode
程序员奈斯6 小时前
Python深度学习:NumPy数组库
python·深度学习·numpy
yongche_shi6 小时前
第二篇:Python“装包”与“拆包”的艺术:可迭代对象、迭代器、生成器
开发语言·python·面试·面试宝典·生成器·拆包·装包