Class9简洁实现

Class9简洁实现

python 复制代码
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
python 复制代码
# 初始化训练样本、测试样本、样本特征维度和批量大小
n_train,n_test,num_inputs,batch_size = 20,100,200,5
# 设置真实权重和偏置
true_w,true_b = torch.ones((num_inputs,1)) * 0.01,0.05
# 生成训练数据
# d2l.synthetic_data():函数生成模拟的训练数据
# synthetic_data()L返回三元组(features,labels)
train_data = d2l.synthetic_data(true_w,true_b,n_train)
# 数据封装为训练数据迭代器
# d2l.load_array():把数据打包成一个笑屁刘昂迭代器,便于后续训练
# batch_size=5:每次迭代返回5个样本
train_iter = d2l.load_array(train_data,batch_size)
# 生成测试数据
test_data = d2l.synthetic_data(true_w,true_b,n_test)
# 数据封装为测试数据迭代器
test_iter = d2l.load_array(test_data,batch_size,is_train=False)
python 复制代码
# 实现带权重衰减(L2正则)线性回归模型训练
# wd:L2正则化系数lambd
def train_concise(wd):
    # 构建一个全连接层,输入为num_inputs,输出为1
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        # 将参数用正态分布随机初始化
        param.data.normal_()
    # 样本的均方误差不求平均
    loss = nn.MSELoss(reduction='none')
    # 定义训练轮数和学习率
    num_epochs, lr = 100, 0.003
    # 使用随机梯度下降优化器
    trainer = torch.optim.SGD([
        # 权重参数,应用L2正则
        {"params":net[0].weight,'weight_decay': wd},
        # 偏置参数,不加正则
        {"params":net[0].bias}], lr=lr)
    # 定义可视化工具
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    # 循环训练
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 清空梯度,防止梯度累加
            trainer.zero_grad()
            # 计算每个样本的MSELoss
            l = loss(net(X), y)
            # 进行反向传播
            l.mean().backward()
            # 更新模型参数
            trainer.step()
        # 每5轮评估训练集和测试集的loss损失函数
        if (epoch + 1) % 5 == 0:
            # 将当前loss加入到动态图中
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    # 打印输出L2范数
    print('w的L2范数:', net[0].weight.norm().item())
python 复制代码
train_concise(0)
python 复制代码
train_concise(3)
相关推荐
大写-凌祁26 分钟前
零基础入门深度学习:从理论到实战,GitHub+开源资源全指南(2025最新版)
人工智能·深度学习·开源·github
焦耳加热1 小时前
阿德莱德大学Nat. Commun.:盐模板策略实现废弃塑料到单原子催化剂的高值转化,推动环境与能源催化应用
人工智能·算法·机器学习·能源·材料工程
wan5555cn1 小时前
多张图片生成视频模型技术深度解析
人工智能·笔记·深度学习·算法·音视频
格林威2 小时前
机器视觉检测的光源基础知识及光源选型
人工智能·深度学习·数码相机·yolo·计算机视觉·视觉检测
THMAIL3 小时前
量化股票从贫穷到财务自由之路 - 零基础搭建Python量化环境:Anaconda、Jupyter实战指南
linux·人工智能·python·深度学习·机器学习·金融
~-~%%4 小时前
从PyTorch到ONNX:模型部署性能提升
人工智能·pytorch·python
xcnn_4 小时前
深度学习基础概念回顾(Pytorch架构)
人工智能·pytorch·深度学习
attitude.x4 小时前
PyTorch 动态图的灵活性与实用技巧
前端·人工智能·深度学习
Ven%4 小时前
第一章 神经网络的复习
人工智能·深度学习·神经网络
研梦非凡5 小时前
CVPR 2025|基于视觉语言模型的零样本3D视觉定位
人工智能·深度学习·计算机视觉·3d·ai·语言模型·自然语言处理