PyTorch_点积运算

点积运算要求第一个矩阵 shape:(n, m),第二个矩阵 shape: (m, p), 两个矩阵点积运算shape为:(n,p)

  1. 运算符 @ 用于进行两个矩阵的点乘运算
  2. torch.mm 用于进行两个矩阵点乘运算,要求输入的矩阵为3维 (mm 代表 mat, mul)
  3. torch.bmm 用于批量进行矩阵点乘运算,要求输入的矩阵为3维 (b 代表 batch)
  4. torch.matmul 对进行点乘运算的两矩阵形状没有限定。
    a. 对于输入都是二维的张量相当于 mm 运算
    b. 对于输入都是三维的张量相当于 bmm 运算
    c. 对数输入的shape不同的张量,对应的最后几个维度必须符合矩阵运算规则

代码

python 复制代码
import torch 
import numpy as np 

# 使用@运算符
def test01():
    # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = data1 @ data2
    print(data) 

# 使用 mm 函数
def test02():
    # 要求输入的张量形状都是二维的
     # 形状为:3行2列 
    data1 = torch.tensor([[1,2], [3,4], [5,6]])

    # 形状为:2行2列
    data2 = torch.tensor([[5,6], [7,8]])

    data = torch.mm(data1, data2)   
    print(data)
    print(data.shape)

# 使用 bmm 函数
def test03():
    # 第一个维度:表示批次
    # 第二个维度:多少行
    # 第三个维度:多少列
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)

    data = torch.bmm(data1, data2) 
    print(data.shape)


# 使用 matmul 函数
def test04():
    # 对二维进行计算
    data1 = torch.randn(4,5)
    data2 = torch.randn(5,8)
    print(torch.matmul(data1, data2).shape)

    # 对三维进行计算
    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(3, 5, 8)
    print(torch.matmul(data1, data2).shape)

    data1 = torch.randn(3, 4, 5)
    data2 = torch.randn(5, 8)
    print(torch.matmul(data1, data2).shape) 

if __name__ == "__main__":
    test04()
相关推荐
im_AMBER3 小时前
学习日志19 python
python·学习
白-胖-子4 小时前
深入剖析大模型在文本生成式 AI 产品架构中的核心地位
人工智能·架构
想要成为计算机高手5 小时前
11. isaacsim4.2教程-Transform 树与Odometry
人工智能·机器人·自动驾驶·ros·rviz·isaac sim·仿真环境
mortimer5 小时前
安装NVIDIA Parakeet时,我遇到的两个Pip“小插曲”
python·github
@昵称不存在6 小时前
Flask input 和datalist结合
后端·python·flask
静心问道6 小时前
InstructBLIP:通过指令微调迈向通用视觉-语言模型
人工智能·多模态·ai技术应用
宇称不守恒4.06 小时前
2025暑期—06神经网络-常见网络2
网络·人工智能·神经网络
赵英英俊6 小时前
Python day25
python
东林牧之6 小时前
Django+celery异步:拿来即用,可移植性高
后端·python·django
何双新7 小时前
基于Tornado的WebSocket实时聊天系统:从零到一构建与解析
python·websocket·tornado