PyTorch深度学习总结
第三章 PyTorch中张量(Tensor)切片操作
文章目录
一、前言
上文介绍了PyTorch中改变张量(Tensor)形状的操作,本文主要介绍张量切片操作。
二、获取张量中的元素
1、切片(行、列数)方法
①
python# 引入库 import torch # 生成张量 A = torch.arange(9).reshape(3, 3) print(A)生成张量A:
tensor(
\[0, 1, 2,
3, 4, 5,
6, 7, 8])
现截取A0:
pythonprint(A[0]) # 截取最外围括号内第一个括号的内容,第一个维度第一行的内容结果为:
tensor(0, 1, 2)
②
python# 引入库 import torch # 生成张量 B = torch.arange(9).reshape(1, 3, 3) print(B)生成张量B:
tensor(
\[\[0, 1, 2,
3, 4, 5,
6, 7, 8]])
现截取B0:
pythonprint(B[0]) # 截取最外围括号内第括号的内容,第一个维度第一行的内容结果为:
tensor(
\[0, 1, 2,
3, 4, 5,
6, 7, 8])
③根据上文张量B进行截取
pythonprint(B[0, 1:2, 1:2])结果为:
tensor(\[4])
注意此时[1:2]指第2个元素开头到第三个元素为至,且不包含第三个元素。(属于包含左边不包含右边,先行后列)
pythonprint(B[0, 1:3, 1:2])结果为:
tensor(\[4, 7])
pythonprint(B[0, -1, -2])结果为:
tensor(7)
pythonprint(B[0, -3:-1, -2]) # 第一个维度,倒数第三行到倒数二行,倒数第二列的元素结果为:
tensor(1, 4)
④通过比较关系输出元素
pythonprint(B[B>=3])结果为:
tensor(3, 4, 5, 6, 7, 8)
注意此处为获取元素组成1维张量
2、torch.where()函数
pythonC = -B D = torch.where(B>4, B, C) print(D)输出结果为:
tensor(
\[\[ 0, -1, -2,
-3, -4, 5,
6, 7, 8]])
3、使元素置零的操作
| 函数 | 描述 |
|---|---|
| torch.tril(A, diagonal=0) | 将A以第一个元素为对角线的直线,将上三角置零 |
| torch.triu(A, diagonal=0) | 将A以第一个元素为对角线的直线,将下三角置零 |
| torch.diag(A) | 保留对角线,将其他元素全部置零,输入必须是二维张量 |
示例:
torch.tril():
pythonE1 = torch.tril(B, diagonal=0) print(E1)输出结果为:
tensor(
\[\[0, 0, 0,
3, 4, 0,
6, 7, 8]])
pythonE2 = torch.tril(B, diagonal=1) print(E2)输出结果为:
tensor(
\[\[0, 1, 0,
3, 4, 5,
6, 7, 8]])
torch.triu():
pythonF = torch.triu(B, diagonal=0) print(F)输出结果为:
tensor(
\[\[0, 1, 2,
0, 4, 5,
0, 0, 8]])
torch.diag():
pythonH = torch.diag(A) print(H)输出结果为:
tensor(0, 4, 8)