Day43 随机张量与广播机制

随机张量是深度学习中生成初始权重、噪声、随机采样的核心工具,PyTorch 提供了丰富的随机张量生成函数,先看核心用法:

核心生成函数
函数 作用 常用场景
torch.rand() 生成 [0,1) 均匀分布的随机张量 权重初始化、随机掩码
torch.randn() 生成标准正态分布(μ=0, σ=1)随机张量 模型噪声、梯度扰动
torch.randint(low, high, size) 生成整数随机张量 随机索引、类别采样
torch.ones_like()/zeros_like() 生成和输入形状相同的全 1 / 全 0 张量 掩码初始化、基准值计算
torch.randperm(n) 生成 0~n-1 的随机排列 数据集随机打乱
python 复制代码
import torch

# ===================== 1. 基础随机张量 =====================
# 1.1 均匀分布 [0,1)
rand_tensor = torch.rand(2, 3)  # shape=(2,3)
print("均匀分布随机张量:\n", rand_tensor)
# 输出示例:
# tensor([[0.4567, 0.1234, 0.7890],
#         [0.3456, 0.8765, 0.2345]])

# 1.2 标准正态分布
randn_tensor = torch.randn(2, 3)
print("\n标准正态分布随机张量:\n", randn_tensor)
# 输出示例:
# tensor([[ 0.1234, -0.5678,  0.9012],
#         [-0.3456,  0.7890, -0.2345]])

# 1.3 整数随机张量(low=0, high=10)
randint_tensor = torch.randint(0, 10, (2, 3))
print("\n整数随机张量:\n", randint_tensor)
# 输出示例:
# tensor([[5, 8, 2],
#         [7, 1, 9]])

# 1.4 随机排列(0~4的随机顺序)
randperm = torch.randperm(5)
print("\n随机排列:\n", randperm)  # 输出示例:tensor([3, 0, 4, 1, 2])

# ===================== 2. 实用技巧 =====================
# 2.1 固定随机种子(复现结果)
torch.manual_seed(42)  # 固定全局种子
print("\n固定种子后的随机张量:\n", torch.rand(2, 3))
# 每次运行结果都一样:tensor([[0.8823, 0.9150, 0.3829], [0.9593, 0.3904, 0.6009]])

# 2.2 生成指定形状的全1/全0张量(基于已有张量)
x = torch.rand(2, 3)
ones_like_x = torch.ones_like(x)
zeros_like_x = torch.zeros_like(x)
print("\n和x形状相同的全1张量:\n", ones_like_x)

# ===================== 3. 深度学习常用场景 =====================
# 3.1 生成模型权重(比如SE注意力的全连接层权重)
in_channels = 32
reduction = 16
# 均匀分布初始化权重
fc_weight = torch.rand(in_channels//reduction, in_channels) * 0.01  # 缩小权重范围
print("\nSE注意力权重张量shape:", fc_weight.shape)  # (2, 32)

# 3.2 生成批量噪声(数据增强)
batch_size = 4
img_shape = (3, 224, 224)
noise = torch.randn(batch_size, *img_shape) * 0.1  # 小幅度噪声
print("\n图像噪声张量shape:", noise.shape)  # (4, 3, 224, 224)

广播机制是 PyTorch 中不同形状张量自动适配计算 的核心规则,简单说:小形状张量 "扩展" 成大形状,和大张量逐元素计算,且不占用额外内存

判断两个张量能否广播,需从最后一个维度开始对比:

  • 规则 1:维度大小相等 ,或其中一个为1
  • 规则 2:维度少的张量,自动在前面补 1,再按规则 1 判断。
张量 A 形状 张量 B 形状 能否广播 广播后形状 核心逻辑
(2, 3) (3,) (2, 3) B 补 1 成 (1,3) → 扩展为 (2,3)
(2, 1, 3) (4, 3) (2, 4, 3) B 补 1 成 (1,4,3) → 扩展为 (2,4,3)
(2, 3) (2, 4) - 最后一维 3≠4,不满足规则 1
python 复制代码
# ===================== 1. 基础广播 =====================
# 示例1:(2,3) + (3,) → 广播后都是(2,3)
a = torch.tensor([[1, 2, 3], [4, 5, 6]])  # (2,3)
b = torch.tensor([10, 20, 30])            # (3,)
c = a + b
print("广播加法结果:\n", c)
# 输出:
# tensor([[11, 22, 33],
#         [14, 25, 36]])

# 示例2:(2,1,3) * (4,3) → 广播后都是(2,4,3)
d = torch.rand(2, 1, 3)  # (2,1,3)
e = torch.rand(4, 3)     # (4,3)
f = d * e
print("\n广播乘法结果shape:", f.shape)  # (2,4,3)

# ===================== 2. 注意力机制中的广播(重点) =====================
# 模拟SE通道注意力的权重缩放:(B,C,1,1) × (B,C,H,W)
batch_size = 2
C = 32  # 通道数
H, W = 14, 14

# 1. 生成通道权重(SE模块输出:(B,C,1,1))
channel_weight = torch.rand(batch_size, C, 1, 1)  # 每个通道一个权重
print("\n通道权重shape:", channel_weight.shape)  # (2,32,1,1)

# 2. 生成特征图(卷积输出:(B,C,H,W))
feature_map = torch.rand(batch_size, C, H, W)
print("特征图shape:", feature_map.shape)  # (2,32,14,14)

# 3. 广播乘法:权重 × 特征图(核心!无需手动扩展维度)
weighted_feature = feature_map * channel_weight
print("加权后特征图shape:", weighted_feature.shape)  # (2,32,14,14)

# ===================== 3. 常见错误(反例) =====================
# 错误1:最后一维不匹配((2,3) 和 (2,4))
try:
    torch.rand(2,3) + torch.rand(2,4)
except RuntimeError as e:
    print("\n广播错误示例1:", e)  # 报错:The size of tensor a (3) must match the size of tensor b (4)

# 错误2:维度顺序不匹配((3,2) 和 (3,) → 最后一维2≠3)
try:
    torch.rand(3,2) + torch.rand(3)
except RuntimeError as e:
    print("广播错误示例2:", e)  # 报错:The size of tensor a (2) must match the size of tensor b (3)

@浙大疏锦行

相关推荐
hixiong1232 小时前
C# OpenvinoSharp使用RAD进行缺陷检测
开发语言·人工智能·c#·openvino
小和尚同志2 小时前
还有人在问 Skills 是啥?感觉和 prompt 一样
人工智能·aigc
星和月2 小时前
人工智能与神经网络
人工智能
田里的水稻2 小时前
ubuntu22.04_构建openclaw开发框架
运维·人工智能·python
Trisyp2 小时前
Word2vec核心模型精讲:CBOW与Skip-gram
人工智能·自然语言处理·word2vec
liuccn2 小时前
技能管理工具npx skills 跟openskills的关系以及区别
人工智能
新缸中之脑2 小时前
AI Harness 工程的崛起
人工智能
大写-凌祁2 小时前
[2026年03月15日] AI 深度早报
人工智能·深度学习·机器学习·计算机视觉·agi
Lw中2 小时前
RAG如何科学调节切片长度与滑动窗口?
人工智能·大模型应用基础·rag检索