Class9简洁实现

Class9简洁实现

python 复制代码
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
python 复制代码
# 初始化训练样本、测试样本、样本特征维度和批量大小
n_train,n_test,num_inputs,batch_size = 20,100,200,5
# 设置真实权重和偏置
true_w,true_b = torch.ones((num_inputs,1)) * 0.01,0.05
# 生成训练数据
# d2l.synthetic_data():函数生成模拟的训练数据
# synthetic_data()L返回三元组(features,labels)
train_data = d2l.synthetic_data(true_w,true_b,n_train)
# 数据封装为训练数据迭代器
# d2l.load_array():把数据打包成一个笑屁刘昂迭代器,便于后续训练
# batch_size=5:每次迭代返回5个样本
train_iter = d2l.load_array(train_data,batch_size)
# 生成测试数据
test_data = d2l.synthetic_data(true_w,true_b,n_test)
# 数据封装为测试数据迭代器
test_iter = d2l.load_array(test_data,batch_size,is_train=False)
python 复制代码
# 实现带权重衰减(L2正则)线性回归模型训练
# wd:L2正则化系数lambd
def train_concise(wd):
    # 构建一个全连接层,输入为num_inputs,输出为1
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        # 将参数用正态分布随机初始化
        param.data.normal_()
    # 样本的均方误差不求平均
    loss = nn.MSELoss(reduction='none')
    # 定义训练轮数和学习率
    num_epochs, lr = 100, 0.003
    # 使用随机梯度下降优化器
    trainer = torch.optim.SGD([
        # 权重参数,应用L2正则
        {"params":net[0].weight,'weight_decay': wd},
        # 偏置参数,不加正则
        {"params":net[0].bias}], lr=lr)
    # 定义可视化工具
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    # 循环训练
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 清空梯度,防止梯度累加
            trainer.zero_grad()
            # 计算每个样本的MSELoss
            l = loss(net(X), y)
            # 进行反向传播
            l.mean().backward()
            # 更新模型参数
            trainer.step()
        # 每5轮评估训练集和测试集的loss损失函数
        if (epoch + 1) % 5 == 0:
            # 将当前loss加入到动态图中
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    # 打印输出L2范数
    print('w的L2范数:', net[0].weight.norm().item())
python 复制代码
train_concise(0)
python 复制代码
train_concise(3)
相关推荐
大千AI助手2 小时前
代价复杂度剪枝(CCP)详解:原理、实现与应用
人工智能·决策树·机器学习·剪枝·大千ai助手·代价复杂度剪枝·ccp
九年义务漏网鲨鱼9 小时前
【大模型面经】千问系列专题面经
人工智能·深度学习·算法·大模型·强化学习
源码之家9 小时前
机器学习:基于大数据二手房房价预测与分析系统 可视化 线性回归预测算法 Django框架 链家网站 二手房 计算机毕业设计✅
大数据·算法·机器学习·数据分析·spark·线性回归·推荐算法
WWZZ202510 小时前
快速上手大模型:深度学习7(实践:卷积层)
人工智能·深度学习·算法·机器人·大模型·卷积神经网络·具身智能
强盛小灵通专卖员12 小时前
煤矿传送带异物检测:深度学习如何提升煤矿安全?
人工智能·深度学习·sci·小论文·大论文·延毕·研究生辅导
编程小白_正在努力中13 小时前
第七章深度解析:从零构建智能体框架——模块化设计与全流程落地指南
人工智能·深度学习·大语言模型·agent·智能体
化作星辰13 小时前
深度学习_三层神经网络传播案例(L0->L1->L2)
人工智能·深度学习·神经网络
_codemonster14 小时前
深度学习实战(基于pytroch)系列(十五)模型构造
人工智能·深度学习
xuehaikj14 小时前
【深度学习】YOLOv10n-MAN-Faster实现包装盒flap状态识别与分类,提高生产效率
深度学习·yolo·分类
sponge'14 小时前
opencv学习笔记9:基于CNN的mnist分类任务
深度学习·神经网络·cnn