PyTorch中matmul函数使用详解和示例代码

torch.matmul 是 PyTorch 中用于执行矩阵乘法的函数,它根据输入张量的维度自动选择适当的矩阵乘法方式,包括:

  • 向量内积(1D @ 1D)
  • 矩阵乘向量(2D @ 1D)
  • 向量乘矩阵(1D @ 2D)
  • 矩阵乘矩阵(2D @ 2D)
  • 批量矩阵乘法(>2D)

函数原型

python 复制代码
torch.matmul(input, other, *, out=None) → Tensor
  • input:第一个张量
  • other:第二个张量
  • out(可选):指定输出张量

详细说明

torch.matmul(a, b) 根据 ab 的维度规则如下:

a 维度 b 维度 操作类型
1D 1D 向量点积
2D 1D 矩阵和向量相乘
1D 2D 向量和矩阵相乘
2D 2D 标准矩阵乘法
≥3D ≥3D 批量矩阵乘法(batch)

示例代码

1. 向量点积(1D @ 1D)

python 复制代码
import torch
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
result = torch.matmul(a, b)
print(result)  # 输出:32.0

2. 矩阵乘向量(2D @ 1D)

python 复制代码
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([5.0, 6.0])
result = torch.matmul(a, b)
print(result)  # 输出:[17.0, 39.0]

3. 向量乘矩阵(1D @ 2D)

python 复制代码
a = torch.tensor([5.0, 6.0])
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
result = torch.matmul(a, b)
print(result)  # 输出:[23.0, 34.0]

4. 矩阵乘矩阵(2D @ 2D)

python 复制代码
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
result = torch.matmul(a, b)
print(result)
# 输出:
# [[19.0, 22.0],
#  [43.0, 50.0]]

5. 批量矩阵乘法(3D @ 3D)

python 复制代码
a = torch.randn(10, 3, 4)
b = torch.randn(10, 4, 5)
result = torch.matmul(a, b)
print(result.shape)  # 输出:torch.Size([10, 3, 5])

综合示例:自定义线性层(类似 nn.Linear

下面是一个使用 torch.matmul 构建自定义线性层的完整示例,适合理解如何手动定义一个具有权重、偏置、支持自动求导的神经网络层,适合自定义网络结构或深入理解 PyTorch 的底层机制。

功能描述

  • 实现线性变换:y = x @ W^T + b
  • 使用 torch.matmul 执行矩阵乘法
  • 权重和偏置作为可训练参数
  • 支持 GPU 和自动求导

代码实现

python 复制代码
import torch
import torch.nn as nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))  # shape: [out, in]
        self.bias = nn.Parameter(torch.zeros(out_features))                # shape: [out]

    def forward(self, x):
        # x: shape [batch_size, in_features]
        # weight: shape [out_features, in_features]
        # transpose weight -> shape [in_features, out_features], then matmul
        out = torch.matmul(x, self.weight.t()) + self.bias
        return out

使用示例

python 复制代码
batch_size = 4
in_dim = 6
out_dim = 3

x = torch.randn(batch_size, in_dim)
layer = MyLinear(in_dim, out_dim)

output = layer(x)
print(output.shape)  # torch.Size([4, 3])

与官方 nn.Linear 等效性验证(可选)

python 复制代码
# 官方线性层
torch.manual_seed(0)
official = nn.Linear(in_dim, out_dim)

# 自定义线性层,使用相同参数初始化
custom = MyLinear(in_dim, out_dim)
custom.weight.data.copy_(official.weight.data)
custom.bias.data.copy_(official.bias.data)

# 比较输出
x = torch.randn(2, in_dim)
out1 = official(x)
out2 = custom(x)
print(torch.allclose(out1, out2))  # True

说明

内容
torch.matmul 用于实现 x @ W.T 矩阵乘法
nn.Parameter 注册为可训练参数,自动加入 .parameters()
Module.forward() 用于定义前向传播逻辑

注意事项

  • 输入张量必须满足矩阵乘法的维度匹配规则。
  • 对于 >2D 的张量,PyTorch 会自动按 batch size 广播执行多组矩阵乘法。
  • torch.matmul 不支持标量乘法(标量乘张量可用 * 运算符)。

相关推荐
_waylau7 分钟前
【HarmonyOS NEXT+AI】问答08:仓颉编程语言是中文编程语言吗?
人工智能·华为·harmonyos·鸿蒙·仓颉编程语言·鸿蒙生态·鸿蒙6
攻城狮7号19 分钟前
Kimi 发布并开源 K2.5 模型:开始在逻辑和干活上卷你了
人工智能·ai编程·视觉理解·kimi code·kimi k2.5·agent 集群
szxinmai主板定制专家22 分钟前
基于 PC 的控制技术+ethercat+linux实时系统,助力追踪标签规模化生产,支持国产化
arm开发·人工智能·嵌入式硬件·yolo·fpga开发
测试开发Kevin26 分钟前
小tip:换行符CRLF 和 LF 的区别以及二者在实际项目中的影响
java·开发语言·python
爱学习的阿磊35 分钟前
使用PyTorch构建你的第一个神经网络
jvm·数据库·python
阿狸OKay36 分钟前
einops 库和 PyTorch 的 einsum 的语法
人工智能·pytorch·python
低调小一41 分钟前
Google AI Agent 白皮书拆解(1):从《Introduction to Agents》看清 Agent 的工程底座
人工智能
feasibility.44 分钟前
混元3D-dit-v2-mv-turbo生成3D模型初体验(ComfyUI)
人工智能·3d·aigc·三维建模·comfyui
极智-9961 小时前
GitHub 热榜项目-日榜精选(2026-02-02)| AI智能体、终端工具、视频生成等 | openclaw、99、Maestro等
人工智能·github·视频生成·终端工具·ai智能体·电子书管理·rust工具
悟纤1 小时前
AI 音乐创作中的音乐织体(Texture)完整指南 | Suno高级篇 | 第30篇
人工智能·suno·suno ai·suno api·ai music