PyTorch_点积运算

点积运算要求第一个矩阵 shape:(n, m),第二个矩阵 shape: (m, p), 两个矩阵点积运算shape为:(n,p)

  1. 运算符 @ 用于进行两个矩阵的点乘运算
  2. torch.mm 用于进行两个矩阵点乘运算,要求输入的矩阵为3维 (mm 代表 mat, mul)
  3. torch.bmm 用于批量进行矩阵点乘运算,要求输入的矩阵为3维 (b 代表 batch)
  4. torch.matmul 对进行点乘运算的两矩阵形状没有限定。
    a. 对于输入都是二维的张量相当于 mm 运算
    b. 对于输入都是三维的张量相当于 bmm 运算
    c. 对数输入的shape不同的张量,对应的最后几个维度必须符合矩阵运算规则

代码

python 复制代码
import torch 
import numpy as np 

# 使用@运算符
def test01():
    # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = data1 @ data2
    print(data) 

# 使用 mm 函数
def test02():
    # 要求输入的张量形状都是二维的
     # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = torch.mm(data1, data2)   
    print(data)
    print(data.shape)

# 使用 bmm 函数
def test03():
    # 第一个维度:表示批次
    # 第二个维度:多少行
    # 第三个维度:多少列
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)

    data = torch.bmm(data1, data2) 
    print(data.shape)


# 使用 matmul 函数
def test04():
    # 对二维进行计算
    data1 = torch.randn(4,5)
    data2 = torch.randn(5,8)
    print(torch.matmul(data1, data2).shape)

    # 对三维进行计算
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)
    print(torch.matmul(data1, data2).shape)

    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(5, 8)
    print(torch.matmul(data1, data2).shape) 

if __name__ == "__main__":
    test04()
相关推荐
xiaohanbao0928 分钟前
day16 numpy和shap深入理解
python·学习·机器学习·信息可视化·numpy·pandas
eqwaak033 分钟前
基于DrissionPage的高效爬虫开发:以小说网站数据抓取为例
爬虫·python·语言模型·性能优化·交互·drissionpage
VI8664956I2634 分钟前
海外社交软件技术深潜:实时互动系统与边缘计算的极限优化
人工智能·实时互动·边缘计算
(・Д・)ノ37 分钟前
python打卡day16
开发语言·python
每天都要写算法(努力版)42 分钟前
【神经网络与深度学习】生成模型-单位高斯分布 Generating Models-unit Gaussian distribution
人工智能·深度学习·神经网络·生成模型
何似在人间5751 小时前
LangChain4j +DeepSeek大模型应用开发——7 项目实战 创建硅谷小鹿
java·人工智能·ai·大模型开发
小小爬虾1 小时前
在pycharm profession 2020.3上离线安装.whl类型的包(以PySimpleGUI为例)
ide·python·pycharm
Timmer丿1 小时前
Spring AI开发跃迁指南(第二章:急速上手3——Advisor核心原理、源码讲解及使用实例)
java·人工智能·spring
xrgs_shz2 小时前
基于MATLAB图像中的圆形目标识别和标记
图像处理·人工智能·计算机视觉·matlab
pen-ai2 小时前
【NLP】32. Transformers (HuggingFace Pipelines 实战)
人工智能·自然语言处理