pytorch小记(十):pytorch中torch.tril 和 torch.triu 详解

pytorch小记(十):pytorch中torch.tril 和 torch.triu 详解

  • [PyTorch `torch.tril` 和 `torch.triu` 详解](#PyTorch torch.triltorch.triu 详解)
    • [1. `torch.tril`(计算下三角矩阵)](#1. torch.tril(计算下三角矩阵))
      • [📌 作用](#📌 作用)
      • [🔍 语法](#🔍 语法)
      • [🔹 参数](#🔹 参数)
      • [📌 示例](#📌 示例)
      • [🔍 `diagonal` 参数](#🔍 diagonal 参数)
      • [🔍 `torch.tril` 的应用](#🔍 torch.tril 的应用)
    • [2. `torch.triu`(计算上三角矩阵)](#2. torch.triu(计算上三角矩阵))
      • [📌 作用](#📌 作用)
      • [🔍 语法](#🔍 语法)
      • [🔹 参数](#🔹 参数)
      • [📌 示例](#📌 示例)
      • [🔍 `diagonal` 参数](#🔍 diagonal 参数)
    • [3. `torch.tril` vs `torch.triu` 对比](#3. torch.tril vs torch.triu 对比)
    • 总结

PyTorch torch.triltorch.triu 详解

在数值计算和深度学习中,下三角矩阵(Lower Triangular Matrix)上三角矩阵(Upper Triangular Matrix) 是非常常见的矩阵操作。PyTorch 提供了 torch.tril()torch.triu() 这两个函数,分别用于计算下三角矩阵和上三角矩阵。


1. torch.tril(计算下三角矩阵)

📌 作用

torch.tril 返回输入张量的 下三角部分,即:

  • 保留 主对角线及其以下的元素
  • 主对角线以上的元素全部变为 0

🔍 语法

python 复制代码
torch.tril(input, diagonal=0)

🔹 参数

参数 说明
input 输入张量
diagonal 控制对角线位置 (默认 0
diagonal=0 保留主对角线 及其以下的元素
diagonal>0 向上偏移 ,保留主对角线以上 diagonal
diagonal<0 向下偏移 ,移除 -diagonal 行的主对角线元素

📌 示例

python 复制代码
import torch

# 创建一个 4×4 的矩阵
A = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
])

print("原始矩阵 A:")
print(A)

# 计算 A 的下三角矩阵
L = torch.tril(A)
print("\nA 的下三角矩阵(diagonal=0):")
print(L)

输出:

复制代码
原始矩阵 A:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

A 的下三角矩阵(diagonal=0):
tensor([[ 1,  0,  0,  0],
        [ 5,  6,  0,  0],
        [ 9, 10, 11,  0],
        [13, 14, 15, 16]])

💡 说明 :主对角线上的元素保留,其上的元素变为 0


🔍 diagonal 参数

python 复制代码
print(torch.tril(A, diagonal=1))  # 保留主对角线以上 1 行
print(torch.tril(A, diagonal=-1)) # 移除主对角线

输出:

复制代码
A 的下三角矩阵(diagonal=1):
tensor([[ 1,  2,  0,  0],
        [ 5,  6,  7,  0],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

A 的下三角矩阵(diagonal=-1):
tensor([[ 0,  0,  0,  0],
        [ 5,  0,  0,  0],
        [ 9, 10,  0,  0],
        [13, 14, 15,  0]])

🔺 diagonal=1向上偏移 ,保留 1 行主对角线以上的元素。
🔻 diagonal=-1向下偏移,移除主对角线。


🔍 torch.tril 的应用

📌 用于 Masking(掩码)

python 复制代码
seq_length = 5
mask = torch.tril(torch.ones(seq_length, seq_length))  # 创建一个下三角 Mask
print(mask)

输出:

复制代码
tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

💡 Transformer 中,这种 Mask 用于防止模型在训练时提前看到未来的信息。


2. torch.triu(计算上三角矩阵)

📌 作用

torch.triu 返回输入张量的 上三角部分,即:

  • 保留 主对角线及其以上的元素
  • 主对角线以下的元素全部变为 0

🔍 语法

python 复制代码
torch.triu(input, diagonal=0)

🔹 参数

参数 说明
input 输入张量
diagonal=0 保留主对角线及其以上的元素
diagonal>0 移除 diagonal 行的主对角线元素
diagonal<0 保留主对角线以下 -diagonal

📌 示例

python 复制代码
U = torch.triu(A)
print("A 的上三角矩阵(diagonal=0):")
print(U)

输出:

复制代码
A 的上三角矩阵(diagonal=0):
tensor([[ 1,  2,  3,  4],
        [ 0,  6,  7,  8],
        [ 0,  0, 11, 12],
        [ 0,  0,  0, 16]])

💡 说明 :主对角线上的元素及其上的元素保留,下面的元素变为 0


🔍 diagonal 参数

python 复制代码
print(torch.triu(A, diagonal=1))  # 移除主对角线元素
print(torch.triu(A, diagonal=-1)) # 保留主对角线以下 1 行

输出:

复制代码
A 的上三角矩阵(diagonal=1):
tensor([[ 0,  2,  3,  4],
        [ 0,  0,  7,  8],
        [ 0,  0,  0, 12],
        [ 0,  0,  0,  0]])

A 的上三角矩阵(diagonal=-1):
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 0, 10, 11, 12],
        [ 0,  0, 15, 16]])

🔺 diagonal=1 :移除主对角线的元素,仅保留主对角线以上的元素。
🔻 diagonal=-1 :允许保留主对角线以下 1 行的元素。


3. torch.tril vs torch.triu 对比

作用 torch.tril(A) torch.triu(A)
计算结果 取下三角部分 取上三角部分
对角线控制 diagonal=0 保留主对角线 diagonal=0 保留主对角线
diagonal>0 保留主对角线以上元素 移除主对角线部分元素
diagonal<0 移除主对角线部分元素 保留主对角线以下部分

总结

  • torch.tril()下三角矩阵 ,可以用于 Cholesky 分解Transformer Masking
  • torch.triu()上三角矩阵 ,常用于 线性代数计算矩阵变换

🚀 你可以根据不同的需求选择合适的函数,在 PyTorch 中高效处理矩阵运算!

相关推荐
LeeZhao@1 小时前
【数据挖掘】时间序列预测-常用序列预测模型
人工智能·自然语言处理·数据挖掘·agi
没有梦想的咸鱼185-1037-16631 小时前
解锁空间数据新质生产力暨:AI(DeepSeek、ChatGPT)、Python、ArcGIS Pro多技术融合下的空间数据分析、建模与科研绘图及论文写作
人工智能·python·深度学习·机器学习·arcgis·chatgpt·数据分析
z_mazin2 小时前
用户行为检测技术解析:从请求头到流量模式的对抗与防御
python·网络爬虫
狐凄3 小时前
Python实例题:使用Pvthon3编写系列实用脚本
java·网络·python
乌旭4 小时前
量子计算与GPU的异构加速:基于CUDA Quantum的混合编程实践
人工智能·pytorch·分布式·深度学习·ai·gpu算力·量子计算
deephub5 小时前
CLIMB自举框架:基于语义聚类的迭代数据混合优化及其在LLM预训练中的应用
人工智能·深度学习·大语言模型·聚类
思通数科AI全行业智能NLP系统6 小时前
AI视频技术赋能幼儿园安全——教师离岗报警系统的智慧守护
大数据·人工智能·安全·目标检测·目标跟踪·自然语言处理·ocr
struggle20257 小时前
deepseek-cli开源的强大命令行界面,用于与 DeepSeek 的 AI 模型进行交互
人工智能·开源·自动化·交互·deepseek
ocr_sinosecu18 小时前
OCR定制识别:解锁文字识别的无限可能
人工智能·机器学习·ocr
fish_study_csdn8 小时前
pytest 技术总结
开发语言·python·pytest