Class9简洁实现

Class9简洁实现

python 复制代码
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
python 复制代码
# 初始化训练样本、测试样本、样本特征维度和批量大小
n_train,n_test,num_inputs,batch_size = 20,100,200,5
# 设置真实权重和偏置
true_w,true_b = torch.ones((num_inputs,1)) * 0.01,0.05
# 生成训练数据
# d2l.synthetic_data():函数生成模拟的训练数据
# synthetic_data()L返回三元组(features,labels)
train_data = d2l.synthetic_data(true_w,true_b,n_train)
# 数据封装为训练数据迭代器
# d2l.load_array():把数据打包成一个笑屁刘昂迭代器,便于后续训练
# batch_size=5:每次迭代返回5个样本
train_iter = d2l.load_array(train_data,batch_size)
# 生成测试数据
test_data = d2l.synthetic_data(true_w,true_b,n_test)
# 数据封装为测试数据迭代器
test_iter = d2l.load_array(test_data,batch_size,is_train=False)
python 复制代码
# 实现带权重衰减(L2正则)线性回归模型训练
# wd:L2正则化系数lambd
def train_concise(wd):
    # 构建一个全连接层,输入为num_inputs,输出为1
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        # 将参数用正态分布随机初始化
        param.data.normal_()
    # 样本的均方误差不求平均
    loss = nn.MSELoss(reduction='none')
    # 定义训练轮数和学习率
    num_epochs, lr = 100, 0.003
    # 使用随机梯度下降优化器
    trainer = torch.optim.SGD([
        # 权重参数,应用L2正则
        {"params":net[0].weight,'weight_decay': wd},
        # 偏置参数,不加正则
        {"params":net[0].bias}], lr=lr)
    # 定义可视化工具
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    # 循环训练
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 清空梯度,防止梯度累加
            trainer.zero_grad()
            # 计算每个样本的MSELoss
            l = loss(net(X), y)
            # 进行反向传播
            l.mean().backward()
            # 更新模型参数
            trainer.step()
        # 每5轮评估训练集和测试集的loss损失函数
        if (epoch + 1) % 5 == 0:
            # 将当前loss加入到动态图中
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    # 打印输出L2范数
    print('w的L2范数:', net[0].weight.norm().item())
python 复制代码
train_concise(0)
python 复制代码
train_concise(3)
相关推荐
AndrewHZ25 分钟前
【图像处理基石】什么是解析力?
图像处理·人工智能·深度学习·计算机视觉·过拟合·解析力·计算成像
你喜欢喝可乐吗?30 分钟前
深度学习模型开发部署全流程:以YOLOv11目标检测任务为例
深度学习·yolo·目标检测
Leo Chaw2 小时前
40 - ScConv卷积模块
pytorch·深度学习·神经网络·机器学习·cnn
枫落雁然2 小时前
深度学习零基础入门(3)-图像与神经网络
人工智能·深度学习·神经网络
苍何3 小时前
2025值得推荐的AI生产力工具一览表(上)
机器学习
马贡多在下雨3 小时前
一文读懂什么是逻辑回归
机器学习
Blossom.1183 小时前
基于深度学习的图像识别:从零构建卷积神经网络(CNN)
人工智能·深度学习·神经网络·机器学习·cnn·机器人·transformer
AI妈妈手把手3 小时前
【深度学习框架终极PK】TensorFlow/PyTorch/MindSpore深度解析!选对框架效率翻倍
人工智能·pytorch·python·深度学习·tensorflow·mindspore·ai选型指南
爱学习的茄子3 小时前
告别 useState 噩梦:useReducer 如何终结 React 状态管理混乱?
前端·深度学习·react.js
泡芙萝莉酱4 小时前
世界各国和地区ICRG政治经济金融综合风险指标数据(1984-2023年)-实证数据
大数据·人工智能·深度学习·数据挖掘·数据分析·数据统计·实证数据