pytorch小记(十):pytorch中torch.tril 和 torch.triu 详解
- [PyTorch `torch.tril` 和 `torch.triu` 详解](#PyTorch
torch.tril
和torch.triu
详解) -
- [1. `torch.tril`(计算下三角矩阵)](#1.
torch.tril
(计算下三角矩阵)) -
- [📌 作用](#📌 作用)
- [🔍 语法](#🔍 语法)
- [🔹 参数](#🔹 参数)
- [📌 示例](#📌 示例)
- [🔍 `diagonal` 参数](#🔍
diagonal
参数) - [🔍 `torch.tril` 的应用](#🔍
torch.tril
的应用)
- [2. `torch.triu`(计算上三角矩阵)](#2.
torch.triu
(计算上三角矩阵)) -
- [📌 作用](#📌 作用)
- [🔍 语法](#🔍 语法)
- [🔹 参数](#🔹 参数)
- [📌 示例](#📌 示例)
- [🔍 `diagonal` 参数](#🔍
diagonal
参数)
- [3. `torch.tril` vs `torch.triu` 对比](#3.
torch.tril
vstorch.triu
对比) - 总结
- [1. `torch.tril`(计算下三角矩阵)](#1.
PyTorch torch.tril
和 torch.triu
详解
在数值计算和深度学习中,下三角矩阵(Lower Triangular Matrix) 和 上三角矩阵(Upper Triangular Matrix) 是非常常见的矩阵操作。PyTorch 提供了 torch.tril()
和 torch.triu()
这两个函数,分别用于计算下三角矩阵和上三角矩阵。
1. torch.tril
(计算下三角矩阵)
📌 作用
torch.tril
返回输入张量的 下三角部分,即:
- 保留 主对角线及其以下的元素。
- 主对角线以上的元素全部变为 0。
🔍 语法
python
torch.tril(input, diagonal=0)
🔹 参数
参数 | 说明 |
---|---|
input |
输入张量 |
diagonal |
控制对角线位置 (默认 0 ) |
diagonal=0 |
保留主对角线 及其以下的元素 |
diagonal>0 |
向上偏移 ,保留主对角线以上 diagonal 行 |
diagonal<0 |
向下偏移 ,移除 -diagonal 行的主对角线元素 |
📌 示例
python
import torch
# 创建一个 4×4 的矩阵
A = torch.tensor([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]
])
print("原始矩阵 A:")
print(A)
# 计算 A 的下三角矩阵
L = torch.tril(A)
print("\nA 的下三角矩阵(diagonal=0):")
print(L)
输出:
原始矩阵 A:
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
A 的下三角矩阵(diagonal=0):
tensor([[ 1, 0, 0, 0],
[ 5, 6, 0, 0],
[ 9, 10, 11, 0],
[13, 14, 15, 16]])
💡 说明 :主对角线上的元素保留,其上的元素变为
0
。
🔍 diagonal
参数
python
print(torch.tril(A, diagonal=1)) # 保留主对角线以上 1 行
print(torch.tril(A, diagonal=-1)) # 移除主对角线
输出:
A 的下三角矩阵(diagonal=1):
tensor([[ 1, 2, 0, 0],
[ 5, 6, 7, 0],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
A 的下三角矩阵(diagonal=-1):
tensor([[ 0, 0, 0, 0],
[ 5, 0, 0, 0],
[ 9, 10, 0, 0],
[13, 14, 15, 0]])
🔺 diagonal=1 :向上偏移 ,保留
1
行主对角线以上的元素。
🔻 diagonal=-1 :向下偏移,移除主对角线。
🔍 torch.tril
的应用
📌 用于 Masking(掩码)
python
seq_length = 5
mask = torch.tril(torch.ones(seq_length, seq_length)) # 创建一个下三角 Mask
print(mask)
输出:
tensor([[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]])
💡 Transformer 中,这种 Mask 用于防止模型在训练时提前看到未来的信息。
2. torch.triu
(计算上三角矩阵)
📌 作用
torch.triu
返回输入张量的 上三角部分,即:
- 保留 主对角线及其以上的元素。
- 主对角线以下的元素全部变为 0。
🔍 语法
python
torch.triu(input, diagonal=0)
🔹 参数
参数 | 说明 |
---|---|
input |
输入张量 |
diagonal=0 |
保留主对角线及其以上的元素 |
diagonal>0 |
移除 diagonal 行的主对角线元素 |
diagonal<0 |
保留主对角线以下 -diagonal 行 |
📌 示例
python
U = torch.triu(A)
print("A 的上三角矩阵(diagonal=0):")
print(U)
输出:
A 的上三角矩阵(diagonal=0):
tensor([[ 1, 2, 3, 4],
[ 0, 6, 7, 8],
[ 0, 0, 11, 12],
[ 0, 0, 0, 16]])
💡 说明 :主对角线上的元素及其上的元素保留,下面的元素变为
0
。
🔍 diagonal
参数
python
print(torch.triu(A, diagonal=1)) # 移除主对角线元素
print(torch.triu(A, diagonal=-1)) # 保留主对角线以下 1 行
输出:
A 的上三角矩阵(diagonal=1):
tensor([[ 0, 2, 3, 4],
[ 0, 0, 7, 8],
[ 0, 0, 0, 12],
[ 0, 0, 0, 0]])
A 的上三角矩阵(diagonal=-1):
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 0, 10, 11, 12],
[ 0, 0, 15, 16]])
🔺 diagonal=1 :移除主对角线的元素,仅保留主对角线以上的元素。
🔻 diagonal=-1 :允许保留主对角线以下1
行的元素。
3. torch.tril
vs torch.triu
对比
作用 | torch.tril(A) |
torch.triu(A) |
---|---|---|
计算结果 | 取下三角部分 | 取上三角部分 |
对角线控制 | diagonal=0 保留主对角线 |
diagonal=0 保留主对角线 |
diagonal>0 |
保留主对角线以上元素 | 移除主对角线部分元素 |
diagonal<0 |
移除主对角线部分元素 | 保留主对角线以下部分 |
总结
torch.tril()
取 下三角矩阵 ,可以用于 Cholesky 分解 、Transformer Masking。torch.triu()
取 上三角矩阵 ,常用于 线性代数计算 和 矩阵变换。
🚀 你可以根据不同的需求选择合适的函数,在 PyTorch 中高效处理矩阵运算!