在 PyTorch 中,归一化是一种重要的操作,用于调整数据分布或模型参数,以提高模型的训练效率和性能。以下是常见的归一化方式及其应用场景:
1. 数据归一化
(1)torch.nn.functional.normalize
对输入张量沿指定维度进行 L2 范数归一化,使得张量的范数为 1。
代码示例:
python
import torch
import torch.nn.functional as F
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
normalized_x = F.normalize(x, p=2, dim=1) # 每行进行归一化
print(normalized_x)
(2)自定义归一化
将输入数据缩放到特定范围(如 [0, 1]
或 [-1, 1]
)。
代码示例:
python
x = torch.tensor([1.0, 2.0, 3.0])
x_min, x_max = x.min(), x.max()
normalized_x = (x - x_min) / (x_max - x_min) # 归一化到 [0, 1]
2. 批归一化 (Batch Normalization)
(1)torch.nn.BatchNorm1d/2d/3d
对多维输入(如图像、序列数据)进行批归一化,主要用于神经网络的隐藏层。
BatchNorm1d
:用于 1D 输入(如序列或全连接层的输出)。BatchNorm2d
:用于 2D 输入(如卷积层的输出,(N, C, H, W)
)。BatchNorm3d
:用于 3D 输入(如 3D 卷积的输出,(N, C, D, H, W)
)。
代码示例:
python
import torch
import torch.nn as nn
batch_norm = nn.BatchNorm2d(num_features=3) # 通道数为 3
x = torch.randn(4, 3, 8, 8) # (N, C, H, W)
normalized_x = batch_norm(x)
3. 层归一化 (Layer Normalization)
(1)torch.nn.LayerNorm
对每个样本的特定维度进行归一化,常用于 RNN 或 Transformer。
代码示例:
python
import torch
import torch.nn as nn
layer_norm = nn.LayerNorm(normalized_shape=10) # 归一化的维度大小
x = torch.randn(5, 10) # (batch_size, features)
normalized_x = layer_norm(x)
4. 实例归一化 (Instance Normalization)
(1)torch.nn.InstanceNorm1d/2d/3d
对每个样本的特征图进行归一化,适用于风格迁移或生成模型。
代码示例:
python
import torch
import torch.nn as nn
instance_norm = nn.InstanceNorm2d(num_features=3)
x = torch.randn(4, 3, 8, 8) # (N, C, H, W)
normalized_x = instance_norm(x)
5. 局部响应归一化 (Local Response Normalization, LRN)
(1)torch.nn.LocalResponseNorm
模仿生物神经元的抑制机制,主要在早期 CNN(如 AlexNet)中使用。
代码示例:
python
import torch
import torch.nn as nn
lrn = nn.LocalResponseNorm(size=5)
x = torch.randn(1, 10, 8, 8) # (N, C, H, W)
normalized_x = lrn(x)
6. 权值归一化 (Weight Normalization)
(1)torch.nn.utils.weight_norm
对权值进行归一化,常用于加速收敛。
代码示例:
python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
linear = nn.Linear(10, 5)
linear = weight_norm(linear) # 对权值进行归一化
7. 谱归一化 (Spectral Normalization)
(1)torch.nn.utils.spectral_norm
通过对权值矩阵进行奇异值分解,约束最大奇异值,常用于生成对抗网络(GAN)。
代码示例:
python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm
conv = nn.Conv2d(3, 16, 3)
conv = spectral_norm(conv) # 对卷积核进行谱归一化
8. 正则化归一化
(1)梯度裁剪(Grad Clipping)
通过裁剪梯度的范数来实现归一化,主要用于防止梯度爆炸。