动手学深度学习(pytorch版):第四章节—多层感知机(5)权重衰减

前一节描述了过拟合的问题,本节将介绍一些正则化模型的技术。 可以通过去收集更多的训练数据来缓解过拟合。 但这可能成本很高,耗时颇多,或者完全超出控制,因而在短期内不可能做到。 假设已经拥有尽可能多的高质量数据,便可以将重点放在正则化技术上。

回想一下,在多项式回归的例子中, 可以通过调整拟合多项式的阶数来限制模型的容量。 实际上,限制特征的数量是缓解过拟合的一种常用技术。 然而,简单地丢弃特征对这项工作来说可能过于生硬。 继续思考多项式回归的例子,考虑高维输入可能发生的情况。 多项式对多变量数据的自然扩展称为单项式(monomials), 也可以说是变量幂的乘积。 单项式的阶数是幂的和。

注意,随着阶数的增长,带有阶数的项数迅速增加。 因此即使是阶数上的微小变化,比如从到,也会显著增加模型的复杂性。 仅仅通过简单的限制特征数量(在多项式回归中体现为限制阶数),可能仍然使模型在过简单和过复杂中徘徊, 需要一个更细粒度的工具来调整函数的复杂性,使其达到一个合适的平衡位置。

已经描述了范数和范数, 它们是更为一般的范数的特殊情况。

在训练参数化机器学习模型时, 权重衰减 (weight decay)是最广泛使用的正则化的技术之一, 它通常也被称为正则化。 这项技术通过函数与零的距离来衡量函数的复杂度, 因为在所有函数中,函数(所有输入都得到值) 在某种意义上是最简单的。 但是应该如何精确地测量一个函数和零之间的距离呢? 没有一个正确的答案。 事实上,函数分析和巴拿赫空间理论的研究,都在致力于回答这个问题。

一种简单的方法是通过线性函数中的权重向量的某个范数来度量其复杂性,

回想一下,是样本的特征, 是样本的标签, 是权重和偏置参数。 为了惩罚权重向量的大小, 必须以某种方式在损失函数中添加, 但是模型应该如何平衡这个新的额外惩罚的损失?

1. 高维线性回归

通过一个简单的例子来演示权重衰减。

python 复制代码
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

首先,像以前一样生成一些数据,生成公式如下:

选择标签是关于输入的线性函数。 标签同时被均值为0,标准差为0.01高斯噪声破坏。 为了使过拟合的效果更加明显,可以将问题的维数增加到, 并使用一个只包含20个样本的小训练集。

python 复制代码
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

2. 从零开始实现

下面将从头开始实现权重衰减,只需将的平方惩罚添加到原始目标函数中。

2.1. 初始化模型参数

首先,定义一个函数来随机初始化模型参数。

python 复制代码
def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

2.2. 定义范数惩罚

实现这一惩罚最方便的方法是对所有项求平方后并将它们求和。

python 复制代码
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

2.3. 定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

python 复制代码
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())

2.4. 忽略正则化直接训练

lambd=0禁用权重衰减后运行这个代码。 注意,这里训练误差有了减少,但测试误差没有减少, 这意味着出现了严重的过拟合。

python 复制代码
train(lambd=0)

2.5. 使用权重衰减

下面使用权重衰减来运行代码。 注意,在这里训练误差增大,但测试误差减小。 这正是我们期望从正则化中得到的效果。

python 复制代码
train(lambd=3)

3. 简洁实现

在下面的代码中,我们在实例化优化器时直接通过weight_decay指定weight decay超参数。 默认情况下,PyTorch同时衰减权重和偏移。 这里我们只为权重设置了weight_decay,所以偏置参数不会衰减。

python 复制代码
def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

然而,它们运行得更快,更容易实现。 对于更复杂的问题,这一好处将变得更加明显。

python 复制代码
train_concise(0)
python 复制代码
train_concise(3)
复制代码
相关推荐
格林威6 小时前
Baumer相机视野内微小缺陷增强检测:提升亚像素级瑕疵可见性的 7 个核心方法,附 OpenCV+Halcon 实战代码!
人工智能·数码相机·opencv·算法·计算机视觉·视觉检测·工业相机
老百姓懂点AI7 小时前
[WASM实战] 插件系统的安全性:智能体来了(西南总部)AI调度官的WebAssembly沙箱与AI agent指挥官的动态加载
人工智能·wasm
多米Domi0117 小时前
0x3f 第49天 面向实习的八股背诵第六天 过了一遍JVM的知识点,看了相关视频讲解JVM内存,垃圾清理,买了plus,稍微看了点确定一下方向
jvm·数据结构·python·算法·leetcode
人工智能训练13 小时前
【极速部署】Ubuntu24.04+CUDA13.0 玩转 VLLM 0.15.0:预编译 Wheel 包 GPU 版安装全攻略
运维·前端·人工智能·python·ai编程·cuda·vllm
yaoming16813 小时前
python性能优化方案研究
python·性能优化
源于花海13 小时前
迁移学习相关的期刊和会议
人工智能·机器学习·迁移学习·期刊会议
码云数智-大飞14 小时前
使用 Python 高效提取 PDF 中的表格数据并导出为 TXT 或 Excel
python
DisonTangor15 小时前
DeepSeek-OCR 2: 视觉因果流
人工智能·开源·aigc·ocr·deepseek
薛定谔的猫198215 小时前
二十一、基于 Hugging Face Transformers 实现中文情感分析情感分析
人工智能·自然语言处理·大模型 训练 调优
发哥来了15 小时前
《AI视频生成技术原理剖析及金管道·图生视频的应用实践》
人工智能