PyTorch_点积运算

点积运算要求第一个矩阵 shape:(n, m),第二个矩阵 shape: (m, p), 两个矩阵点积运算shape为:(n,p)

  1. 运算符 @ 用于进行两个矩阵的点乘运算
  2. torch.mm 用于进行两个矩阵点乘运算,要求输入的矩阵为3维 (mm 代表 mat, mul)
  3. torch.bmm 用于批量进行矩阵点乘运算,要求输入的矩阵为3维 (b 代表 batch)
  4. torch.matmul 对进行点乘运算的两矩阵形状没有限定。
    a. 对于输入都是二维的张量相当于 mm 运算
    b. 对于输入都是三维的张量相当于 bmm 运算
    c. 对数输入的shape不同的张量,对应的最后几个维度必须符合矩阵运算规则

代码

python 复制代码
import torch 
import numpy as np 

# 使用@运算符
def test01():
    # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = data1 @ data2
    print(data) 

# 使用 mm 函数
def test02():
    # 要求输入的张量形状都是二维的
     # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = torch.mm(data1, data2)   
    print(data)
    print(data.shape)

# 使用 bmm 函数
def test03():
    # 第一个维度:表示批次
    # 第二个维度:多少行
    # 第三个维度:多少列
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)

    data = torch.bmm(data1, data2) 
    print(data.shape)


# 使用 matmul 函数
def test04():
    # 对二维进行计算
    data1 = torch.randn(4,5)
    data2 = torch.randn(5,8)
    print(torch.matmul(data1, data2).shape)

    # 对三维进行计算
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)
    print(torch.matmul(data1, data2).shape)

    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(5, 8)
    print(torch.matmul(data1, data2).shape) 

if __name__ == "__main__":
    test04()
相关推荐
苏苏susuus6 分钟前
机器学习:欠拟合、过拟合、正则化
人工智能·机器学习
爱写代码的小朋友7 分钟前
数字化浪潮下:信息化教学模式与人工智能的协同创新发展研究
人工智能·数字化
geneculture7 分钟前
融智学内涵、数学定义和跨学科应用的四个核心公式
人工智能·数学建模·课程设计·融智学的重要应用·融智学应用场景
whaosoft-14313 分钟前
51c自动驾驶~合集56
人工智能
量子-Alex13 分钟前
【目标检测】【AAAI-2022】Anchor DETR
人工智能·目标检测·计算机视觉
开开心心就好42 分钟前
高效合并 Excel 表格实用工具
开发语言·javascript·python·qt·r语言·ocr·excel
fish_study_csdn1 小时前
PyCharm接入DeepSeek,实现高效AI编程
python·pycharm·ai编程
MindTechBuilder1 小时前
安全架构的深度技术剖析
人工智能
云卓SKYDROID1 小时前
无人机报警器探测模块技术解析!
人工智能·无人机·科普·高科技·报警器
西猫雷婶1 小时前
深度学习|pytorch基本运算
人工智能·pytorch·深度学习