Tensorflow2.0笔记 - 修改形状和维度

本次笔记主要使用reshape,transpose,expand_dim,和squeeze对tensor的形状和维度进行操作。

复制代码
import tensorflow as tf
import numpy as np

tf.__version__

#tensor的shape和维数获取
#假设下面这个tensor表示4张28*28*3的图片
tensor = tf.random.uniform([4,28,28,3], minval=0, maxval=10, dtype=tf.int32)
print("tensor.shape:", tensor.shape)
print("tensor.ndim:", tensor.ndim)

#reshape成一个三维的tensor,将行和列的信息去掉,只保留pixel概念
print("=======reshape([4,28*28,3].shape=========\n", tf.reshape(tensor, [4,28*28,3]).shape)
#reshape里的参数中可以出现一个-1,表示自动计算省略掉的维度的大小
#还是上面的例子,将行和列的信息去掉,只保留pixel的概念
print("=======reshape([4,-1,3].shape=========\n", tf.reshape(tensor, [4,-1,3]).shape)
#将图片的行和列信息和RGB通道信息去掉,图片数据作为一个整体,等价于tf.reshape(tensor, [4, 28*28*3])
print("=======reshape([4,-1].shape=========\n", tf.reshape(tensor, [4,-1]).shape)

#transpose进行转置操作,会修改tensor的数据布局
tensor = tf.random.uniform([4,3,2,1], minval=0, maxval=9, dtype=tf.int32)
print(tensor.shape,tensor.ndim)
print(tensor)

#不带参数,表示整体转置,对所有维度进行转置
transpose = tf.transpose(tensor)
print("========Transpose without arg:", transpose.shape)
print(transpose)
#带参数,给出perm参数,表示原来的维度放到哪个位置
#第0个和第1个维度保留,交换最后两个维度
transpose = tf.transpose(tensor, perm=[0,1,3,2])
print("========Transpose by arg:", transpose.shape)
print(transpose)

#transpose的一个应用案例
#pytorch中,图片信息一般以[b,c,h,w]来表示,b表示batch数量,c表示像素通道数量,h,w表示图片的高度和宽度
#tensorflow中,图片信息一般以[b,h,w,c]来表示
#可以使用transpose进行pytorch和tensorflow格式的互转
#下面的tensor按照pytorch格式理解,两张5*5*3的图片
tensor = tf.random.uniform([2,3,5,5], minval=0, maxval=9, dtype=tf.int32)
print("=====PYTORCH data=====\n", tensor)
#通过transpose转换为tensorflow格式
transpose = tf.transpose(tensor, [0,2,3,1])
print("=====TENSORFLW data====\n", transpose)

#增加(expand)或减少(squeeze)维度
#假设下面的tensor表示4个班级,10个学生,5门科目的成绩
tensor = tf.random.normal([4,10,5])

#现在我们要增加一个学校的维度,使用expand_dims,会在指定axis的前面增加一个维度
#axis表示要在那个维度前面增加
expanded = tf.expand_dims(tensor, axis=0)
print("Expanded at dim0:", expanded.shape)

#在5门科目成绩维度前增加一个维度
expanded = tf.expand_dims(tensor, axis=2)
print("Expanded at dim2:", expanded.shape)

#在5门科目成绩维度后面增加一个维度
expanded = tf.expand_dims(tensor, axis=3)
print("Expanded at dim3:", expanded.shape)

#axis为负数的时候,和numpy索引给-1的情况是类似的,需要注意的是此时会在指定axis的后面增加一个维度
#在5门科目成绩维度前增加一个维度
expanded = tf.expand_dims(tensor, axis=-2)
print("Expanded at dim2:", expanded.shape)
#在最前面增加一个维度
expanded = tf.expand_dims(tensor, axis=-4)
print("Expanded at dim0:", expanded.shape)

#减少维度,仅用于去掉shape=1的维度,如果指定要去掉的维度shape大于1会报错
tensor = tf.zeros([1,2,1,1,3])
print("tensor.shape:", tensor.shape)
#上面的tensor,只有1个维度的位置可以去掉
squeezed = tf.squeeze(tensor)
print("Squeezed:", squeezed.shape)
#指定某个axis进行squeeze
squeezed = tf.squeeze(tensor, axis=0)
print("Squeezed:", squeezed.shape)
#axis为负数的情况
squeezed = tf.squeeze(tensor, axis=-2)
print("Squeezed:", squeezed.shape)

运行结果:

相关推荐
m0_74675230几秒前
bootstrap怎么给表格添加固定表头效果
jvm·数据库·python
源码之家几秒前
计算机毕业设计:Python基金股票数据分析与可视化平台 Django框架 数据分析 可视化 爬虫 大数据 大模型(建议收藏)✅
爬虫·python·信息可视化·数据分析·django·flask·课程设计
justjinji几秒前
JavaScript 数组引用陷阱与“破纪录”问题的正确解法
jvm·数据库·python
小小王app小程序开发1 分钟前
AI 智能体小程序玩法分析:2026 千亿 AI 风口,冠品科技赋能低门槛落地
人工智能·科技·小程序
Dotrust东信创智1 分钟前
告别脚本依赖:AI 具身智能重构智能座舱 HMI 测试新范式
人工智能·重构
m0_674294641 分钟前
mysql如何通过yum源快速安装_mysql官方yum安装教程
jvm·数据库·python
生信研究猿2 分钟前
#P3492.第1题-基于决策树预判资源调配优先级
python·算法·决策树
justjinji3 分钟前
如何在Node.js中封装通用的MongoDB CRUD操作层_基于原生驱动的DAO层设计模式
jvm·数据库·python
是上好佳佳佳呀3 分钟前
【前端(九)】CSS Transform 2D/3D 变换笔记:分清两个原点,搞懂多重变换
前端·css·笔记
yyk的萌3 分钟前
Spring AI + 智谱大模型实战:打造有记忆功能的智能天气助手
java·人工智能·spring·agent·spring ai