PyTorch_点积运算

点积运算要求第一个矩阵 shape:(n, m),第二个矩阵 shape: (m, p), 两个矩阵点积运算shape为:(n,p)

  1. 运算符 @ 用于进行两个矩阵的点乘运算
  2. torch.mm 用于进行两个矩阵点乘运算,要求输入的矩阵为3维 (mm 代表 mat, mul)
  3. torch.bmm 用于批量进行矩阵点乘运算,要求输入的矩阵为3维 (b 代表 batch)
  4. torch.matmul 对进行点乘运算的两矩阵形状没有限定。
    a. 对于输入都是二维的张量相当于 mm 运算
    b. 对于输入都是三维的张量相当于 bmm 运算
    c. 对数输入的shape不同的张量,对应的最后几个维度必须符合矩阵运算规则

代码

python 复制代码
import torch 
import numpy as np 

# 使用@运算符
def test01():
    # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = data1 @ data2
    print(data) 

# 使用 mm 函数
def test02():
    # 要求输入的张量形状都是二维的
     # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = torch.mm(data1, data2)   
    print(data)
    print(data.shape)

# 使用 bmm 函数
def test03():
    # 第一个维度:表示批次
    # 第二个维度:多少行
    # 第三个维度:多少列
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)

    data = torch.bmm(data1, data2) 
    print(data.shape)


# 使用 matmul 函数
def test04():
    # 对二维进行计算
    data1 = torch.randn(4,5)
    data2 = torch.randn(5,8)
    print(torch.matmul(data1, data2).shape)

    # 对三维进行计算
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)
    print(torch.matmul(data1, data2).shape)

    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(5, 8)
    print(torch.matmul(data1, data2).shape) 

if __name__ == "__main__":
    test04()
相关推荐
胖头鱼的鱼缸(尹海文)19 分钟前
数据库管理-第376期 Oracle AI DB 23.26新特性一览(20251016)
数据库·人工智能·oracle
瑞禧生物ruixibio23 分钟前
4-ARM-PEG-Pyrene(2)/Biotin(2),多功能化聚乙二醇修饰荧光标记生物分子的设计与应用探索
arm开发·人工智能
大千AI助手27 分钟前
Huber损失函数:稳健回归的智慧之选
人工智能·数据挖掘·回归·损失函数·mse·mae·huber损失函数
墨利昂38 分钟前
10.17RNN情感分析实验:加载预训练词向量模块整理
人工智能·rnn·深度学习
【建模先锋】42 分钟前
一区直接写!CEEMDAN分解 + Informer-LSTM +XGBoost组合预测模型
人工智能·lstm·ceemdan·预测模型·风速预测·时间序列预测模型
fsnine1 小时前
YOLOv2原理介绍
人工智能·计算机视觉·目标跟踪
倔强的石头1061 小时前
AI修图革命:IOPaint+cpolar让废片拯救触手可及
人工智能·cpolar·iopaint
文火冰糖的硅基工坊1 小时前
[人工智能-大模型-15]:大模型典型产品对比 - 数字人
人工智能·大模型·大语言模型
这里有鱼汤1 小时前
📊量化实战篇:如何计算RSI指标的“拥挤度指标”?
后端·python
JJJJ_iii1 小时前
【机器学习05】神经网络、模型表示、前向传播、TensorFlow实现
人工智能·pytorch·python·深度学习·神经网络·机器学习·tensorflow