PyTorch_点积运算

点积运算要求第一个矩阵 shape:(n, m),第二个矩阵 shape: (m, p), 两个矩阵点积运算shape为:(n,p)

  1. 运算符 @ 用于进行两个矩阵的点乘运算
  2. torch.mm 用于进行两个矩阵点乘运算,要求输入的矩阵为3维 (mm 代表 mat, mul)
  3. torch.bmm 用于批量进行矩阵点乘运算,要求输入的矩阵为3维 (b 代表 batch)
  4. torch.matmul 对进行点乘运算的两矩阵形状没有限定。
    a. 对于输入都是二维的张量相当于 mm 运算
    b. 对于输入都是三维的张量相当于 bmm 运算
    c. 对数输入的shape不同的张量,对应的最后几个维度必须符合矩阵运算规则

代码

python 复制代码
import torch 
import numpy as np 

# 使用@运算符
def test01():
    # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = data1 @ data2
    print(data) 

# 使用 mm 函数
def test02():
    # 要求输入的张量形状都是二维的
     # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = torch.mm(data1, data2)   
    print(data)
    print(data.shape)

# 使用 bmm 函数
def test03():
    # 第一个维度:表示批次
    # 第二个维度:多少行
    # 第三个维度:多少列
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)

    data = torch.bmm(data1, data2) 
    print(data.shape)


# 使用 matmul 函数
def test04():
    # 对二维进行计算
    data1 = torch.randn(4,5)
    data2 = torch.randn(5,8)
    print(torch.matmul(data1, data2).shape)

    # 对三维进行计算
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)
    print(torch.matmul(data1, data2).shape)

    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(5, 8)
    print(torch.matmul(data1, data2).shape) 

if __name__ == "__main__":
    test04()
相关推荐
计算机小手21 小时前
AI截图解答工具,可自定义设置多模态模型和提示词
人工智能·经验分享·开源软件
闲人编程1 天前
深入理解Python的`if __name__ == ‘__main__‘`:它到底做了什么?
服务器·数据库·python·main·name·魔法语句
资讯全球1 天前
2025年用户体验佳的大型企业报销系统
人工智能·百度·ux
毕设源码-郭学长1 天前
【开题答辩全过程】以 Python基于大数据的四川旅游景点数据分析与可视化为例,包含答辩的问题和答案
大数据·python·数据分析
海底的星星fly1 天前
【Prompt学习技能树地图】单一思维链优化-自我一致性提示工程原理、实践与代码实现
人工智能·语言模型·prompt
无妄无望1 天前
解码器系列(1)BERT
人工智能·深度学习·bert
葡萄与www1 天前
模块化神经网络
人工智能·深度学习·神经网络·机器学习
Lin_Aries_04211 天前
容器化 Flask 应用程序
linux·后端·python·docker·容器·flask
MediaTea1 天前
Jupyter Notebook:基于 Web 的交互式编程环境
前端·ide·人工智能·python·jupyter