Day 48

Day 48

  • PyTorch 随机张量生成与广播机制详解
    • 一、随机张量的生成:"
      • 1.1 torch.randn 函数:"
      • 1.2 其他随机函数:"
      • 1.3 输出维度测试:"
    • 二、广播机制:"
      • 2.1 广播机制原理:"
      • 2.2 广播实例:"
      • 2.3 乘法的广播机制:矩阵运算中的 "巧妙搭配"

PyTorch 随机张量生成与广播机制详解

在深度学习的探索之旅中,PyTorch 作为强大的开源机器学习库,为我们提供了诸多便捷的功能。今天,就让我们一同深入探究 PyTorch 中随机张量生成以及广播机制的奥秘,解锁高效张量操作的新技能。

一、随机张量的生成:"

在深度学习的诸多场景里,随机生成张量都有着举足轻重的地位。无论是模型权重的初始化,还是计算输入维度经过模块后的输出维度,在开发和测试阶段,随机张量都能让我们摆脱对真实数据的依赖,快速推进实验进程。

1.1 torch.randn 函数:"

torch.randn() 凭借其简洁的语法与强大的功能,成为生成随机张量的常用之选。它能依据标准正态分布(均值 0,标准差 1)生成填充的张量,为模型参数初始化、测试数据生成以及模拟输入特征等场景提供了极大便利。

  • 函数签名与参数剖析torch.randn(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

    • size :必选参数,精准定义输出张量的形状,例如 (3, 4) 即构建 3 行 4 列的矩阵。
    • dtype :可选参数,用于指定张量的数据类型,如 torch.float32torch.int64 等。
    • device :可选参数,明确张量存储的设备,是 cpu 还是 cuda
    • requires_grad :可选参数,决定张量是否需要计算梯度,在模型训练时尤为关键。
  • 示例演绎

python 复制代码
import torch

# 标量(0 维张量)生成
scalar = torch.randn(())
print(f"标量: {scalar}, 形状: {scalar.shape}")  

# 向量(1 维张量)打造
vector = torch.randn(5)
print(f"向量: {vector}, 形状: {vector.shape}")  

# 矩阵(2 维张量)构造
matrix = torch.randn(3, 4)
print(f"矩阵:{matrix},矩阵形状: {matrix.shape}")  

# 3 维张量(常用于图像数据)塑造
tensor_3d = torch.randn(3, 224, 224)
print(f"3 维张量形状: {tensor_3d.shape}")  

# 4 维张量(批量图像数据)雕琢
tensor_4d = torch.randn(2, 3, 224, 224)
print(f"4 维张量形状: {tensor_4d.shape}")  

1.2 其他随机函数:"

除了 torch.randn(),PyTorch 还提供了其他生成不同分布随机数的函数,满足多样化的应用场景:

  • torch.rand() :在 [0, 1) 范围内均匀分布的随机数生成。
python 复制代码
x = torch.rand(3, 2)
print(f"均匀分布随机数: {x}, 形状: {x.shape}")
  • torch.randint() :生成指定范围内的随机整数。
python 复制代码
x = torch.randint(low=0, high=10, size=(3,))
print(f"随机整数: {x}, 形状: {x.shape}")
  • torch.normal() :生成指定均值和标准差的正态分布随机数。
python 复制代码
mean = torch.tensor([0.0, 0.0])
std = torch.tensor([1.0, 2.0])
x = torch.normal(mean, std)
print(f"正态分布随机数: {x}, 形状: {x.shape}")

1.3 输出维度测试:"

在深度学习模型搭建过程中,精准把握每一层输出张量的形状至关重要。借助随机张量生成与打印输出,我们能轻松追踪尺寸变化:

python 复制代码
import torch
import torch.nn as nn

# 生成输入张量 (批量大小, 通道数, 高度, 宽度)
input_tensor = torch.randn(1, 3, 32, 32)
print(f"输入尺寸: {input_tensor.shape}")  

# 1. 卷积层操作
conv1 = nn.Conv2d(
    in_channels=3,        
    out_channels=16,      
    kernel_size=3,        
    stride=1,             
    padding=1             
)
conv_output = conv1(input_tensor)
print(f"卷积后尺寸: {conv_output.shape}")  

# 2. 池化层操作 (减小空间尺寸)
pool = nn.MaxPool2d(kernel_size=2, stride=2)
pool_output = pool(conv_output)
print(f"池化后尺寸: {pool_output.shape}")  

# 3. 将多维张量展平为向量
flattened = pool_output.view(pool_output.size(0), -1)
print(f"展平后尺寸: {flattened.shape}")  

# 4. 线性层操作
fc1 = nn.Linear(
    in_features=4096,     
    out_features=128      
)
fc_output = fc1(flattened)
print(f"线性层后尺寸: {fc_output.shape}")  

# 5. 再经过一个线性层(例如分类器)
fc2 = nn.Linear(128, 10)  
final_output = fc2(fc_output)
print(f"最终输出尺寸: {final_output.shape}")  
print(final_output)

# 使用Softmax替代Sigmoid
softmax = nn.Softmax(dim=1)  
class_probs = softmax(final_output)
print(f"Softmax输出: {class_probs}")  
print(f"Softmax输出总和: {class_probs.sum():.4f}")

这种维度测试方式,在非交互式环境如 PyCharm 中,配合断点与调试控制台,能有效避免维度不匹配的报错,保障模型搭建的顺利推进。

二、广播机制:"

PyTorch 的广播机制,赋予了不同形状张量间进行算术运算的 "超能力",无需显式扩展或复制数据,极大简化了代码,提升运算效率。

2.1 广播机制原理:"

当对形状不同的张量进行运算时,PyTorch 依照以下规则巧妙处理维度兼容性:

  • 从右向左比较维度 :从张量的最后一个维度(最右侧)开始向前逐维比较。

  • 维度扩展条件

    • 相等维度 :若两个张量在某一维度上大小相同,则继续比较下一维度。
    • 一维扩展 :若其中一个张量的某个维度大小为 1,则该维度会被扩展为另一个张量对应维度的大小。
    • 不兼容错误 :若某一维度大小既不相同也不为 1,则抛出 RuntimeError
  • 维度补全规则 :若一个张量的维度少于另一个,则在其左侧补 1 直至维度数匹配。

2.2 广播实例:"

  • 二维张量与一维向量相加
python 复制代码
import torch

# 创建原始张量
a = torch.tensor([[10], [20], [30]])  # 形状: (3, 1)
b = torch.tensor([1, 2, 3])          # 形状: (3,)

result = a + b

print("原始张量a:")
print(a)

print("\n原始张量b:")
print(b)

print("\n加法结果:")
print(result)
  • 三维张量与二维张量相加
python 复制代码
# 创建原始张量
a = torch.tensor([[[1], [2]], [[3], [4]]])  # 形状: (2, 2, 1)
b = torch.tensor([[10, 20]])               # 形状: (1, 2)

# 广播过程
result = a + b
print("原始张量a:")
print(a)

print("\n原始张量b:")
print(b)

print("\n加法结果:")
print(result)
  • 二维张量与标量相加
python 复制代码
# 创建原始张量
a = torch.tensor([[1, 2], [3, 4]])  # 形状: (2, 2)
b = 10                              # 标量,形状视为 ()

# 广播过程
result = a + b
print("原始张量a:")
print(a)

print("\n标量b:")
print(b)

print("\n加法结果:")
print(result)
  • 高维张量与低维张量相加
python 复制代码
# 创建原始张量
a = torch.tensor([[[1, 2], [3, 4]]])  # 形状: (1, 2, 2)
b = torch.tensor([[5, 6]])            # 形状: (1, 2)

# 广播过程
result = a + b
print("原始张量a:")
print(a)

print("\n原始张量b:")
print(b)

print("\n加法结果:")
print(result)

2.3 乘法的广播机制:矩阵运算中的 "巧妙搭配"

矩阵乘法除了遵循通用广播规则,在维度约束上还有特殊要求:

  • 最后两个维度的适配 :A.shape[-1] == B.shape[-2],即 A 的列数等于 B 的行数。

  • 其他维度(批量维度)的广播 :遵循通用广播规则。

  • 批量矩阵与单个矩阵相乘

python 复制代码
# A: 批量大小为2,每个是3×4的矩阵
A = torch.randn(2, 3, 4)  # 形状: (2, 3, 4)

# B: 单个4×5的矩阵
B = torch.randn(4, 5)     # 形状: (4, 5)

# 广播过程:
result = A @ B            # 结果形状: (2, 3, 5)

print("A形状:", A.shape)
print("B形状:", B.shape)
print("结果形状:", result.shape)
  • 批量矩阵与批量矩阵相乘(部分广播)
python 复制代码
# A: 批量大小为3,每个是2×4的矩阵
A = torch.randn(3, 2, 4)  # 形状: (3, 2, 4)

# B: 批量大小为1,每个是4×5的矩阵
B = torch.randn(1, 4, 5)  # 形状: (1, 4, 5)

# 广播过程:
result = A @ B            # 结果形状: (3, 2, 5)

print("A形状:", A.shape)
print("B形状:", B.shape)
print("结果形状:", result.shape)
  • 三维张量与二维张量相乘(高维广播)
python 复制代码
# A: 批量大小为2,通道数为3,每个是4×5的矩阵
A = torch.randn(2, 3, 4, 5)  # 形状: (2, 3, 4, 5)

# B: 单个5×6的矩阵
B = torch.randn(5, 6)        # 形状: (5, 6)

# 广播过程:
result = A @ B               # 结果形状: (2, 3, 4, 6)

print("A形状:", A.shape)
print("B形状:", B.shape)
print("结果形状:", result.shape)

@浙大疏锦行

相关推荐
鹏码纵横2 小时前
已解决:java.lang.ClassNotFoundException: com.mysql.jdbc.Driver 异常的正确解决方法,亲测有效!!!
java·python·mysql
仙人掌_lz2 小时前
Qwen-3 微调实战:用 Python 和 Unsloth 打造专属 AI 模型
人工智能·python·ai·lora·llm·微调·qwen3
猎人everest2 小时前
快速搭建运行Django第一个应用—投票
后端·python·django
猎人everest2 小时前
Django的HelloWorld程序
开发语言·python·django
chusheng18403 小时前
2025最新版!Windows Python3 超详细安装图文教程(支持 Python3 全版本)
windows·python·python3下载·python 安装教程·python3 安装教程
别勉.3 小时前
Python Day50
开发语言·python
xiaohanbao094 小时前
day54 python对抗生成网络
网络·python·深度学习·学习
爬虫程序猿4 小时前
利用 Python 爬虫按关键字搜索 1688 商品
开发语言·爬虫·python
英杰.王4 小时前
深入 Java 泛型:基础应用与实战技巧
java·windows·python
安替-AnTi4 小时前
基于Django的购物系统
python·sql·django·毕设·购物系统