【PyTorch基础】PyTorch还支持线性代数运算?PyTorch的内置线性代数运算示例

目录

引言

[1. trace - 对角线元素之和(矩阵的迹)](#1. trace - 对角线元素之和(矩阵的迹))

[2. diag - 提取对角线元素](#2. diag - 提取对角线元素)

[3. triu / tril - 上三角/下三角矩阵](#3. triu / tril - 上三角/下三角矩阵)

[4. mm / bmm - 矩阵乘法 / 批量矩阵乘法](#4. mm / bmm - 矩阵乘法 / 批量矩阵乘法)

[5. addmm / addbmm / addmv / addr / baddbmm - 矩阵运算](#5. addmm / addbmm / addmv / addr / baddbmm - 矩阵运算)

[6. t - 转置矩阵](#6. t - 转置矩阵)

[7. dot / cross - 内积 / 外积](#7. dot / cross - 内积 / 外积)

[8. inverse - 求逆矩阵](#8. inverse - 求逆矩阵)

[9. svd - 奇异值分解](#9. svd - 奇异值分解)

结语


引言

PyTorch 作为一个强大的深度学习框架,提供了丰富的线性代数功能,使得开发者能够高效地进行各种矩阵运算。通过使用 PyTorch 中的相关库,我们可以避免重复"造轮子",从而专注于实现更高层次的算法和模型。本文将介绍几个常用的 PyTorch 中的线性代数功能,帮助您更轻松地进行矩阵操作,从而提升您的开发效率。

另外,若有进一步研究需求,可以参考官方文档

1. trace - 对角线元素之和(矩阵的迹)

python 复制代码
# 1. trace - 对角线元素之和(矩阵的迹)
import torch

# 创建一个 3x3 矩阵
tensor = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
# 计算矩阵的迹
trace_value = torch.trace(tensor)
print(f"矩阵的迹: {trace_value}")

2. diag - 提取对角线元素

python 复制代码
# 2. diag - 提取对角线元素
# 提取对角线元素
diagonal_elements = torch.diag(tensor)
print(f"对角线元素: {diagonal_elements}")

3. triu / tril - 上三角/下三角矩阵

python 复制代码
# 3. triu / tril - 上三角/下三角矩阵
# 创建一个 3x3 矩阵
tensor = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
# 获取上三角矩阵
upper_triangle = torch.triu(tensor)
print(f"上三角矩阵:\n{upper_triangle}")

# 获取下三角矩阵
lower_triangle = torch.tril(tensor)
print(f"下三角矩阵:\n{lower_triangle}")

4. mm / bmm - 矩阵乘法 / 批量矩阵乘法

python 复制代码
# 4. mm / bmm - 矩阵乘法 / 批量矩阵乘法
# 创建两个 3x3 矩阵
A = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
B = torch.tensor([[9., 8., 7.], [6., 5., 4.], [3., 2., 1.]])
# 矩阵乘法
result_mm = torch.mm(A, B)
print(f"矩阵乘法结果:\n{result_mm}")

# 批量矩阵乘法 (batch matrix multiplication)
batch_A = torch.rand(10, 3, 3)  # 10 个 3x3 矩阵
batch_B = torch.rand(10, 3, 3)
result_bmm = torch.bmm(batch_A, batch_B)
print(f"批量矩阵乘法结果:\n{result_bmm[0]}")  # 显示第一个批次的结果

5. addmm / addbmm / addmv / addr / baddbmm - 矩阵运算

python 复制代码
# 5. addmm / addbmm / addmv / addr / baddbmm - 矩阵运算
# 使用 addmm 函数,矩阵加法加乘法
C = torch.tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]])
result_addmm = torch.addmm(C, A, B)
print(f"addmm 矩阵加法加乘法结果:\n{result_addmm}")

# 使用 addbmm 批量矩阵运算
# 创建批量矩阵
batch_A = torch.rand(10, 3, 4)  # 10 个 3x4 矩阵
batch_B = torch.rand(10, 4, 3)  # 10 个 4x3 矩阵
batch_C = torch.rand(3, 3)

# 使用 addbmm 进行批量矩阵运算 C = A @ B + C
result_addbmm = torch.addbmm(batch_C, batch_A, batch_B)
print(f"addbmm 批量矩阵运算结果:\n{result_addbmm[0]}")  # 显示第一个批次的结果

6. t - 转置矩阵

python 复制代码
# 6. t - 转置矩阵
# 创建一个 2x3 矩阵
tensor = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
# 转置矩阵
transposed_tensor = tensor.t()
print(f"转置后的矩阵:\n{transposed_tensor}")

7. dot / cross - 内积 / 外积

python 复制代码
# 7. dot / cross - 内积 / 外积
# 创建两个向量
vector1 = torch.tensor([1., 2., 3.])
vector2 = torch.tensor([4., 5., 6.])

# 计算内积
dot_product = torch.dot(vector1, vector2)
print(f"内积: {dot_product}")

# 计算外积
cross_product = torch.cross(vector1, vector2)
print(f"外积: {cross_product}")

8. inverse - 求逆矩阵

python 复制代码
# 8. inverse - 求逆矩阵
# 创建一个 2x2 矩阵
matrix = torch.tensor([[4., 7.], [2., 6.]])
# 求矩阵的逆
inverse_matrix = torch.inverse(matrix)
print(f"矩阵的逆:\n{inverse_matrix}")

9. svd - 奇异值分解

python 复制代码
# 9. svd - 奇异值分解
# 创建一个 3x3 矩阵
matrix = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
# 进行奇异值分解
U, S, V = torch.svd(matrix)
print(f"U 矩阵:\n{U}")
print(f"S 奇异值:\n{S}")
print(f"V 矩阵:\n{V}")

结语

在这篇博客中,我们探讨了如何使用 PyTorch 中的线性代数功能来简化矩阵操作。通过利用这些内置的高效函数,不仅可以节省时间和精力,还能减少潜在的错误,提升代码的可读性和可维护性。希望大家在使用 PyTorch 进行矩阵运算时,能够充分利用这些强大的工具,而不再需要自己重复实现基础功能。

相关推荐
audyxiao0019 小时前
人工智能顶级期刊PR论文解读|HCRT:基于相关性感知区域的混合网络,用于DCE-MRI图像中的乳腺肿瘤分割
网络·人工智能·智慧医疗·肿瘤分割
零售ERP菜鸟9 小时前
IT价值证明:从“成本中心”到“增长引擎”的确定性度量
大数据·人工智能·职场和发展·创业创新·学习方法·业界资讯
叫我:松哥9 小时前
基于大数据和深度学习的智能空气质量监测与预测平台,采用Spark数据预处理,利用TensorFlow构建LSTM深度学习模型
大数据·python·深度学习·机器学习·spark·flask·lstm
童话名剑10 小时前
目标检测(吴恩达深度学习笔记)
人工智能·目标检测·滑动窗口·目标定位·yolo算法·特征点检测
木卫四科技10 小时前
【木卫四 CES 2026】观察:融合智能体与联邦数据湖的安全数据运营成为趋势
人工智能·安全·汽车
吃茄子的猫15 小时前
quecpython中&的具体含义和使用场景
开发语言·python
珠海西格电力15 小时前
零碳园区有哪些政策支持?
大数据·数据库·人工智能·物联网·能源
じ☆冷颜〃15 小时前
黎曼几何驱动的算法与系统设计:理论、实践与跨领域应用
笔记·python·深度学习·网络协议·算法·机器学习
数据大魔方15 小时前
【期货量化实战】日内动量策略:顺势而为的短线交易法(Python源码)
开发语言·数据库·python·mysql·算法·github·程序员创富
启途AI15 小时前
2026免费好用的AIPPT工具榜:智能演示文稿制作新纪元
人工智能·powerpoint·ppt