torch.matmul() VS torch.einsum()

torch.matmul():标准的矩阵乘法

  • 向量-向量(点积)

    python 复制代码
    a = torch.randn(3)  # [3]
    b = torch.randn(3)  # [3]
    c = torch.matmul(a, b)  # 点积,标量输出
  • 矩阵-向量

    python 复制代码
    A = torch.randn(3, 4)  # [3, 4]
    x = torch.randn(4)     # [4]
    y = torch.matmul(A, x) # [3]
  • 矩阵-矩阵

    python 复制代码
    A = torch.randn(3, 4)  # [3, 4]
    B = torch.randn(4, 5)  # [4, 5]
    C = torch.matmul(A, B) # [3, 5]
  • 批量矩阵乘法(更高维张量)

    python 复制代码
    A = torch.randn(2, 3, 4)  # [B, M, K]
    B = torch.randn(2, 4, 5)  # [B, K, N]
    C = torch.matmul(A, B)     # [B, M, N]

    torch.einsum:爱因斯坦求和约定(更通用的张量运算工具)

  • 矩阵乘法

    python 复制代码
    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    C = torch.einsum("ik,kj->ij", A, B)  # 等价于 A @ B
    
    A = torch.randn(2, 3, 4)  # [B, M, K]
    B = torch.randn(2, 4, 5)  # [B, K, N]
    C = torch.einsum("bik,bkj->bij", A, B)  # [B, M, N]
    
    a = torch.randn(3)
    b = torch.randn(3)
    c = torch.einsum("i,i->", a, b)  # 点积,标量输出
  • 转置

    python 复制代码
    A = torch.randn(3, 4)
    B = torch.einsum("ij->ji", A)  # 等价于 A.T
  • 对角线提取

  • 张量收缩(Tensor Contraction)(高阶张量乘法)

    python 复制代码
    A = torch.randn(2, 3, 4, 5)
    B = torch.randn(2, 4, 5, 6)
    C = torch.einsum("abcd,abde->abce", A, B)  # 对 d 维度收缩
  • 广播运算

torch.matmul torch.einsum
灵活性 仅支持矩阵乘法类操作 支持任意张量运算(转置、收缩等)
可读性 直观(A @ B 需要熟悉爱因斯坦求和约定
性能 高度优化(推荐用于标准矩阵乘法) 灵活但可能稍慢
广播支持
批量处理 自动支持 需显式指定批量维度
相关推荐
王锋(oxwangfeng)1 分钟前
基于 DINO 与 Chinese-CLIP 的自动驾驶语义检索系统架构
人工智能·机器学习·自动驾驶
巫婆理发2222 分钟前
自然语言处理与词嵌入
人工智能·自然语言处理
共享家95276 分钟前
基于 Coze 工作流搭建历史主题图片生成器
前端·人工智能·js
IT研究所10 分钟前
信创浪潮下 ITSM 的价值重构与实践赋能
大数据·运维·人工智能·安全·低代码·重构·自动化
AI职业加油站11 分钟前
Python技术应用工程师:互联网行业技能赋能者
大数据·开发语言·人工智能·python·数据分析
I'mChloe11 分钟前
机器学习核心分支:深入解析非监督学习
人工智能·学习·机器学习
J_Xiong011716 分钟前
【Agents篇】06:Agent 的感知模块——多模态输入处理
人工智能·ai agent·视觉感知
深蓝海域知识库19 分钟前
深蓝海域中标大型机电企业大模型知识工程平台项目
大数据·人工智能
爱吃泡芙的小白白19 分钟前
机器学习中的“隐形之手”:偏置项深入探讨与资源全导航
人工智能·机器学习
爱打代码的小林25 分钟前
用 PyTorch 实现 CBOW 模型
人工智能·pytorch·python