Tensorflow2.0笔记 - 修改形状和维度

本次笔记主要使用reshape,transpose,expand_dim,和squeeze对tensor的形状和维度进行操作。

import tensorflow as tf
import numpy as np

tf.__version__

#tensor的shape和维数获取
#假设下面这个tensor表示4张28*28*3的图片
tensor = tf.random.uniform([4,28,28,3], minval=0, maxval=10, dtype=tf.int32)
print("tensor.shape:", tensor.shape)
print("tensor.ndim:", tensor.ndim)

#reshape成一个三维的tensor,将行和列的信息去掉,只保留pixel概念
print("=======reshape([4,28*28,3].shape=========\n", tf.reshape(tensor, [4,28*28,3]).shape)
#reshape里的参数中可以出现一个-1,表示自动计算省略掉的维度的大小
#还是上面的例子,将行和列的信息去掉,只保留pixel的概念
print("=======reshape([4,-1,3].shape=========\n", tf.reshape(tensor, [4,-1,3]).shape)
#将图片的行和列信息和RGB通道信息去掉,图片数据作为一个整体,等价于tf.reshape(tensor, [4, 28*28*3])
print("=======reshape([4,-1].shape=========\n", tf.reshape(tensor, [4,-1]).shape)

#transpose进行转置操作,会修改tensor的数据布局
tensor = tf.random.uniform([4,3,2,1], minval=0, maxval=9, dtype=tf.int32)
print(tensor.shape,tensor.ndim)
print(tensor)

#不带参数,表示整体转置,对所有维度进行转置
transpose = tf.transpose(tensor)
print("========Transpose without arg:", transpose.shape)
print(transpose)
#带参数,给出perm参数,表示原来的维度放到哪个位置
#第0个和第1个维度保留,交换最后两个维度
transpose = tf.transpose(tensor, perm=[0,1,3,2])
print("========Transpose by arg:", transpose.shape)
print(transpose)

#transpose的一个应用案例
#pytorch中,图片信息一般以[b,c,h,w]来表示,b表示batch数量,c表示像素通道数量,h,w表示图片的高度和宽度
#tensorflow中,图片信息一般以[b,h,w,c]来表示
#可以使用transpose进行pytorch和tensorflow格式的互转
#下面的tensor按照pytorch格式理解,两张5*5*3的图片
tensor = tf.random.uniform([2,3,5,5], minval=0, maxval=9, dtype=tf.int32)
print("=====PYTORCH data=====\n", tensor)
#通过transpose转换为tensorflow格式
transpose = tf.transpose(tensor, [0,2,3,1])
print("=====TENSORFLW data====\n", transpose)

#增加(expand)或减少(squeeze)维度
#假设下面的tensor表示4个班级,10个学生,5门科目的成绩
tensor = tf.random.normal([4,10,5])

#现在我们要增加一个学校的维度,使用expand_dims,会在指定axis的前面增加一个维度
#axis表示要在那个维度前面增加
expanded = tf.expand_dims(tensor, axis=0)
print("Expanded at dim0:", expanded.shape)

#在5门科目成绩维度前增加一个维度
expanded = tf.expand_dims(tensor, axis=2)
print("Expanded at dim2:", expanded.shape)

#在5门科目成绩维度后面增加一个维度
expanded = tf.expand_dims(tensor, axis=3)
print("Expanded at dim3:", expanded.shape)

#axis为负数的时候,和numpy索引给-1的情况是类似的,需要注意的是此时会在指定axis的后面增加一个维度
#在5门科目成绩维度前增加一个维度
expanded = tf.expand_dims(tensor, axis=-2)
print("Expanded at dim2:", expanded.shape)
#在最前面增加一个维度
expanded = tf.expand_dims(tensor, axis=-4)
print("Expanded at dim0:", expanded.shape)

#减少维度,仅用于去掉shape=1的维度,如果指定要去掉的维度shape大于1会报错
tensor = tf.zeros([1,2,1,1,3])
print("tensor.shape:", tensor.shape)
#上面的tensor,只有1个维度的位置可以去掉
squeezed = tf.squeeze(tensor)
print("Squeezed:", squeezed.shape)
#指定某个axis进行squeeze
squeezed = tf.squeeze(tensor, axis=0)
print("Squeezed:", squeezed.shape)
#axis为负数的情况
squeezed = tf.squeeze(tensor, axis=-2)
print("Squeezed:", squeezed.shape)

运行结果:

相关推荐
Hunter_pcx23 分钟前
[C++技能提升]类注册
c++·人工智能
东临碣石821 小时前
【重磅AI论文】DeepSeek-R1:通过强化学习激励大语言模型(LLMs)的推理能力
人工智能·深度学习·语言模型
SmallBambooCode1 小时前
【Flask】在Flask应用中使用Flask-Limiter进行简单CC攻击防御
后端·python·flask
Mr.L705171 小时前
Maui学习笔记- SQLite简单使用案例02添加详情页
笔记·学习·ios·sqlite·c#
抱抱宝1 小时前
Pyecharts之图表样式深度定制
python·信息可视化·数据分析
码界筑梦坊1 小时前
基于Flask的哔哩哔哩评论数据可视化分析系统的设计与实现
python·信息可视化·flask·毕业设计
大懒猫软件1 小时前
如何有效使用Python爬虫将网页数据存储到Word文档
爬虫·python·自动化·word
大数据魔法师1 小时前
1905电影网中国地区电影数据分析(二) - 数据分析与可视化
python·数据分析
&白帝&1 小时前
JAVA JDK7时间相关类
java·开发语言·python
星迹日2 小时前
数据结构:二叉树—面试题(二)
java·数据结构·笔记·二叉树·面试题