PyTorch_点积运算

点积运算要求第一个矩阵 shape:(n, m),第二个矩阵 shape: (m, p), 两个矩阵点积运算shape为:(n,p)

  1. 运算符 @ 用于进行两个矩阵的点乘运算
  2. torch.mm 用于进行两个矩阵点乘运算,要求输入的矩阵为3维 (mm 代表 mat, mul)
  3. torch.bmm 用于批量进行矩阵点乘运算,要求输入的矩阵为3维 (b 代表 batch)
  4. torch.matmul 对进行点乘运算的两矩阵形状没有限定。
    a. 对于输入都是二维的张量相当于 mm 运算
    b. 对于输入都是三维的张量相当于 bmm 运算
    c. 对数输入的shape不同的张量,对应的最后几个维度必须符合矩阵运算规则

代码

python 复制代码
import torch 
import numpy as np 

# 使用@运算符
def test01():
    # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = data1 @ data2
    print(data) 

# 使用 mm 函数
def test02():
    # 要求输入的张量形状都是二维的
     # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = torch.mm(data1, data2)   
    print(data)
    print(data.shape)

# 使用 bmm 函数
def test03():
    # 第一个维度:表示批次
    # 第二个维度:多少行
    # 第三个维度:多少列
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)

    data = torch.bmm(data1, data2) 
    print(data.shape)


# 使用 matmul 函数
def test04():
    # 对二维进行计算
    data1 = torch.randn(4,5)
    data2 = torch.randn(5,8)
    print(torch.matmul(data1, data2).shape)

    # 对三维进行计算
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)
    print(torch.matmul(data1, data2).shape)

    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(5, 8)
    print(torch.matmul(data1, data2).shape) 

if __name__ == "__main__":
    test04()
相关推荐
宁大小白4 分钟前
pythonstudy Day40
python·机器学习
湘-枫叶情缘6 分钟前
“智律提效”AI数字化运营落地项目可行性方案
大数据·人工智能·产品运营
却道天凉_好个秋7 分钟前
OpenCV(四十二):图像分割原理
人工智能·opencv·计算机视觉·图像分割
Coding茶水间10 分钟前
基于深度学习的水下海洋生物检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
KOYUELEC光与电子请努力拼搏~15 分钟前
AMAZINGIC晶焱科技:AI 驱动的车载革命:高速通信下的保护设计你准备好了吗?
人工智能·科技
禾从道15 分钟前
「杂想」未来的AI电子设备和胡思乱想。
人工智能·智能手机·创业创新·小米·豆包手机
HuggingFace19 分钟前
Codex 正在推动开源 AI 模型的训练与发布
人工智能
深蓝海拓22 分钟前
PySide6从0开始学习的笔记(十三) IDE的选择
笔记·python·qt·学习·pyqt
HuggingFace31 分钟前
经同意的语音克隆
人工智能
智算菩萨37 分钟前
实战:用 Python + 传统NLP 自动总结长文章
开发语言·人工智能·python