PyTorch学习(12):PyTorch的张量相乘(torch.matmul)

PyTorch学习(1):torch.meshgrid的使用-CSDN博客

PyTorch学习(2):torch.device-CSDN博客

PyTorch学习(9):torch.topk-CSDN博客

PyTorch学习(10):torch.where-CSDN博客

PyTorch学习(11):PyTorch的形状变换(view, reshape)与维度变换(transpose, permute)-CSDN博客


目录

[1. 写在前面](#1. 写在前面)

[2. 基本用法](#2. 基本用法)

[3. 高级用法](#3. 高级用法)

[4. 注意事项](#4. 注意事项)

[5. 例程](#5. 例程)


1. 写在前面

torch.matmul()是PyTorch库中用于执行矩阵乘法的函数。它可以处理不同尺寸的矩阵,包括批量矩阵和张量。该函数的特点在于能够利用Python的广播机制,处理维度不同的张量结构进行相乘操作。

torch.matmul也可以使用"@"符号来替代。

2. 基本用法

当两个张量都是一维的,torch.matmul()返回两个向量的点积。

当两个张量都是二维的,torch.matmul()返回矩阵乘积。

如果第一个参数是一维张量,第二个参数是二维张量,torch.matmul()在一维张量的前面增加一个维度,然后进行矩阵乘法,矩阵乘法结束后移除添加的维度。

如果第一个参数是二维张量,第二个参数是一维张量,torch.matmul()返回矩阵×向量的积。

如果两个参数至少为一维,且其中一个参数的维度大于等于2,torch.matmul()会进行批量矩阵乘法。

3. 高级用法

对于高维张量,torch.matmul()可以进行批量矩阵乘法。具体来说,如果输入是一个形状为(j × 1 × n × n)的张量,另一个是形状为(k × n × n)的张量,输出将是形状为(j × k × n × n)的张量。

torch.matmul()函数还支持在特定维度上进行广播,即在不匹配的维度上复制数据以使其尺寸一致,从而进行矩阵乘法。

4. 注意事项

在使用torch.matmul()时,需要注意矩阵乘法的基本规则,即第一个矩阵的列数必须等于第二个矩阵的行数。

如果遇到维度不匹配的情况,可以使用torch.Tensor.view()或torch.Tensor.reshape()函数来调整张量的形状。

在神经网络的训练和推理中,torch.matmul()函数是实现全连接层、卷积层等操作的关键组件。

5. 例程

python 复制代码
import torch


# 创建两个一维张量(向量)

vector1 = torch.tensor([1, 2, 3])

vector2 = torch.tensor([4, 5, 6])

# 使用torch.matmul()计算点积

dot_product = torch.matmul(vector1, vector2)

print("Dot product of two vectors:", dot_product)

# 创建两个二维张量(矩阵)

matrix1 = torch.tensor([[1, 2], [3, 4]])

matrix2 = torch.tensor([[5, 6], [7, 8]])

# 使用torch.matmul()进行矩阵乘法

matrix_product = torch.matmul(matrix1, matrix2)

print("Matrix multiplication result:\n", matrix_product)

# 创建一个一维张量和一个二维张量

vector = torch.tensor([1, 2, 3])

matrix = torch.tensor([[4, 5], [6, 7], [8, 9]])

# 使用torch.matmul()进行矩阵乘法,其中一维张量会被视为列向量

result = torch.matmul(vector, matrix)

print("Matrix multiplication with a vector and a matrix:\n", result)

# 创建两个三维张量(批量矩阵)

batch1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

batch2 = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])

# 使用torch.matmul()进行批量矩阵乘法

batch_product = torch.matmul(batch1, batch2)

print("Batch matrix multiplication result:\n", batch_product)
相关推荐
on_pluto_3 小时前
【debug】解决 conda 和 镜像下载pytorch太慢的问题
人工智能·pytorch·conda
nix.gnehc3 小时前
PyTorch基础概念
人工智能·pytorch·python
●VON9 小时前
开源 vs 商业:主流AI生态概览——从PyTorch到OpenAI的技术格局之争
人工智能·pytorch·开源
shayudiandian12 小时前
用PyTorch训练一个猫狗分类器
人工智能·pytorch·深度学习
xwill*16 小时前
RDT-1B: A DIFFUSION FOUNDATION MODEL FOR BIMANUAL MANIPULATION
人工智能·pytorch·python·深度学习
程序猿追16 小时前
PyTorch算子模板库技术解读:无缝衔接PyTorch模型与Ascend硬件的桥梁
人工智能·pytorch·python·深度学习·机器学习
操练起来1 天前
【昇腾CANN训练营·第八期】Ascend C生态兼容:基于PyTorch Adapter的自定义算子注册与自动微分实现
人工智能·pytorch·acl·昇腾·cann
AI即插即用4 天前
即插即用系列 | 2025 MambaNeXt-YOLO 炸裂登场!YOLO 激吻 Mamba,打造实时检测新霸主
人工智能·pytorch·深度学习·yolo·目标检测·计算机视觉·视觉检测
忘却的旋律dw5 天前
使用LLM模型的tokenizer报错AttributeError: ‘dict‘ object has no attribute ‘model_type‘
人工智能·pytorch·python
studytosky5 天前
深度学习理论与实战:MNIST 手写数字分类实战
人工智能·pytorch·python·深度学习·机器学习·分类·matplotlib