目录
根据这阶段的学习情况以及为了加深印象,所以将一些关键知识点以表格的形式提取出来,这样既可以突出重点,更方便自己复习回顾。
重点函数:
|----|--------------|------------------------------------|
| 序号 | 函数名称 | 功能说明 |
| 1 | shape() | 可以在保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状。 |
| 2 | squeze() | 删除形状为 1 的维度(降维)。 |
| 3 | unsqueze() | 添加形状为1的维度(升维)。 |
| 4 | transpose() | 可以实现交换张量形状的指定维度。 |
| 5 | permute() | 可以用于修改张量的形状,只能用于修改连续的张量。 |
| 6 | view() | 修改连续张量的形状, 操作等同于reshape()。 |
| 7 | contiguous() | 将张量转为连续张量。 |
一、张量索引操作
|----|-----------|---------------------------------------------|
| 序号 | 说明 | 简单示例 |
| 1 | 掌握行列索引的使用 | 第一行:data[0] 第一列:data[:,0] |
| 2 | 掌握列表索引的使用 | 第二行第三列的值和第四行第五列值 data[[1, 3], [2, 4]] |
| 3 | 掌握索引范围的使用 | 第一行第三行以及第二列第四列张量 data[::2, 1::2] |
| 4 | 知道布尔值的使用 | data[data[:, 1] > 6] |
| 5 | 知道多维索引的使用 | data2[0, :, :] |
具体练习代码如下:
python
import torch
# 下标从左到右从0开始
# data[行下标,列下标]
# data[0轴下标,1轴下标,2轴下标]
def demo1():
# 创建张量
torch.manual_seed(0)
data = torch.randint(low=0, high=10, size=(4, 5))
print(f'data: {data}')
# 根据下标获取对应位置的元素
# 行数据 第一行
print(f'第一行为: {data[0]}')
print(f'第一列为:{data[:,0]}')
# 获取第二行第三列的值和第四行第五列的值
# 特殊点:前面行,后面列匹配
print(f'data[[1,3],[2,4]为:{data[[1,3],[2,4]]}')
# 第二列大于6的所有行的数据
print(f'data[:,1]>6的行有:{data[:,1]>6}')
# 第三行大于6的所有列数据
print(f'data[:,data[2]>6]的列为:{data[:,data[2]>6]}')
# 根据范围取值:[::]从起始位置开始,每隔一个元素取一个从起始位置开始,每隔一个元素取一个
# 第一行第三行以及第二列第四列张量
print(f'data[::2,1::2]为:{data[::2,1::2]}')
# 创建三维张量
data2 = torch.randint(low=0, high=10, size=(3,4,5))
print(f'data2: {data2}')
print(f'0轴的第一个值为:{data2[0,:,:]}')
print(f'1轴的第一个值为:{data2[:,0,:]}')
print(f'2轴的第一个值为:{data2[:,:,0]}')
if __name__ == '__main__':
demo1()
二、张量形状操作
代码如下:
python
import torch
# reshape(shape=(行,列)): 修改连续或非连续张量的形状, 不改数据
# -1: 表示自动计算行或列 例如: (5, 6) -> (-1, 3) -1*3=5*6 -1=10 (10, 3)
def dm01():
torch.manual_seed(0)
t1 = torch.randint(0, 10, (5, 6))
print('t1->', t1)
print('t1的形状->', t1.shape)
# 形状修改为 (2, 15)
t2 = t1.reshape(shape=(2, 15))
t3 = t1.reshape(shape=(2, -1))
print('t2->', t2)
print('t2的形状->', t2.shape)
print('t3->', t3)
print('t3的形状->', t3.shape)
if __name__ == '__main__':
dm01()
三、张量升维与降维
代码如下:
python
# squeeze(dim=): 删除值为1的维度, dim->指定维度, 维度值不为1不生效 不设置dim,删除所有值为1的维度
# 例如: (3,1,2,1) -> squeeze()->(3,2) squeeze(dim=1)->(3,2,1)
# unqueeze(dim=): 在指定维度上增加值为1的维度 dim=-1:最后维度
def dm02():
torch.manual_seed(0)
# 四维
t1 = torch.randint(0, 10, (3, 1, 2, 1))
print('t1->', t1)
print('t1的形状->', t1.shape)
# squeeze: 降维
t2 = torch.squeeze(t1)
print('t2->', t2)
print('t2的形状->', t2.shape)
# dim: 指定维度
t3 = torch.squeeze(t1, dim=1)
print('t3->', t3)
print('t3的形状->', t3.shape)
# unsqueeze: 升维
# (3, 2)->(1, 3, 2)
# t4 = t2.unsqueeze(dim=0)
# 最后维度 (3, 2)->(3, 2, 1)
t4 = t2.unsqueeze(dim=-1)
print('t4->', t4)
print('t4的形状->', t4.shape)
if __name__ == '__main__':
dm02()
四、张量维度顺序变换
代码如下:
python
# 调换维度
# torch.permute(input=,dims=): 改变张量任意维度顺序
# input: 张量对象
# dims: 改变后的维度顺序, 传入轴下标值 (1,2,3)->(3,1,2)
# torch.transpose(input=,dim0=,dim1=): 改变张量两个维度顺序
# dim0: 轴下标值, 第一个维度
# dim1: 轴下标值, 第二个维度
# (1,2,3)->(2,1,3) 一次只能交换两个维度
def dm03():
torch.manual_seed(0)
t1 = torch.randint(low=0, high=10, size=(3, 4, 5))
print('t1->', t1)
print('t1形状->', t1.shape)
# 交换0维和1维数据
# t2 = t1.transpose(dim0=1, dim1=0)
t2 = t1.permute(dims=(1, 0, 2))
print('t2->', t2)
print('t2形状->', t2.shape)
# t1形状修改为 (5, 3, 4)
t3 = t1.permute(dims=(2, 0, 1))
print('t3->', t3)
print('t3形状->', t3.shape)
if __name__ == '__main__':
dm03()
五、张量的修改
代码如下:
python
# tensor.view(shape=): 修改连续张量的形状, 操作等同于reshape()
# tensor.is_contiugous(): 判断张量是否连续, 返回True/False 张量经过transpose/permute处理变成不连续
# tensor.contiugous(): 将张量转为连续张量
def dm04():
torch.manual_seed(0)
t1 = torch.randint(low=0, high=10, size=(3, 4))
print('t1->', t1)
print('t1形状->', t1.shape)
print('t1是否连续->', t1.is_contiguous())
# 修改张量形状
t2 = t1.view((4, 3))
print('t2->', t2)
print('t2形状->', t2.shape)
print('t2是否连续->', t2.is_contiguous())
# 张量经过transpose操作
t3 = t1.transpose(dim0=1, dim1=0)
print('t3->', t3)
print('t3形状->', t3.shape)
print('t3是否连续->', t3.is_contiguous())
# 修改张量形状
# view
# contiugous(): 转换成连续张量
t4 = t3.contiguous().view((3, 4))
print('t4->', t4)
t5 = t3.reshape(shape=(3, 4))
print('t5->', t5)
print('t5是否连续->', t5.is_contiguous())
if __name__ == '__main__':
dm04()
基本上上述函数使用较多,所以加深记忆即可。