李沐动手学深度学习Pytorch-v2笔记【08线性回归+基础优化算法】2

文章目录

线性回归的简介实现

通过使用深度学习框架来简洁实现 线性回归模型 生成数据集

python 复制代码
import torch
import numpy as np
from torch.utils import data
from d2l import torch as d2l

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
python 复制代码
def load_array(data_arrays, batch_size, is_train = True):
	"构造一个Pytorch数据迭代器"
	dataset = data.TensorDataset(*data_arrays)
	return data.DataLoader(dataset, batch_size, shuffle = is_train)
batch_size = 10
data_iter = load_array((features, labels),batch_size)

next(iter(data_iter))

data.TensorDataset :是 PyTorch 提供的一个类,用于将多个张量封装为一个数据集。
data_arrays :是解包操作,假设 data_arrays(features, labels),则等价于 data.TensorDataset(features, labels)
DataLoader:提供批次加载、数据打乱和多线程支持。
next(iter(data_iter))
iter(data_iter)DataLoader 转换为迭代器。
next() 获取下一个批次(第一次调用时是第一批数据)。

使用框架的预定好的层

python 复制代码
#"nn"是神经网络的缩写
from torch import nn

net = nn.Sequential(nn.Linear(2, 1))

初始化模型参数

python 复制代码
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)


_:表示写入
normal : 表示正态分布

计算均方误差使用的是MESLoss类也称平方范数

python 复制代码
loss = nn.MSELoss()

实例化SGD示实例

python 复制代码
trainer = torch.optim.SGD(net.parameters(),lr = 0.03)

torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params(必须参数): 这是一个包含了需要优化的参数(张量)的迭代器,例如模型的参数 model.parameters()

lr(必须参数): 学习率(learning rate)。它是一个正数,控制每次参数更新的步长。较小的学习率会导致收敛较慢,较大的学习率可能导致震荡或无法收敛。

momentum(默认值为 0): 动量(momentum)是一个用于加速 SGD 收敛的参数。它引入了上一步梯度的指数加权平均。通常设置在 01 之间。当 momentum 大于 0 时,算法在更新时会考虑之前的梯度,有助于加速收敛。

dampening(默认值为 0): 阻尼项,用于减缓动量的速度。在某些情况下,为了防止动量项引起的震荡,可以设置一个小的 dampening 值。

weight_decay(默认值为 0): 权重衰减,也称为 L2 正则化项。它用于控制参数的幅度,以防止过拟合。通常设置为一个小的正数。

nesterov(默认值为 False): Nesterov 动量。当设置为 True 时,采用 Nesterov 动量更新规则。Nesterov 动量在梯度更新之前先进行一次预测,然后在计算梯度更新时使用这个预测。

训练过程

python 复制代码
num_epochs = 3
for epoch in rangepochs:
	for X, y in data_iter:
		l = loss(net(X), y)
		trainer.zero_grad()
		l.backward()
		train_step()
	l = loss(net(features), labels)
	print(f'epoch {epoch + 1}, loss {1:f}')
相关推荐
晨曦夜月25 分钟前
map与unordered_map区别
算法·哈希算法
图码1 小时前
如何用多种方法判断字符串是否为回文?
开发语言·数据结构·c++·算法·阿里云·线性回归·数字雕刻
handler011 小时前
Linux 内核剖析:进程优先级、上下文切换与 O(1) 调度算法
linux·运维·c语言·开发语言·c++·笔记·算法
minglie11 小时前
实数列的常用递推模式
算法
我是大聪明.1 小时前
CUDA矩阵乘法优化:共享内存分块与Warp级执行机制深度解析
人工智能·深度学习·线性代数·机器学习·矩阵
码云数智-大飞1 小时前
大模型幻觉:成因解析与有效避免策略
人工智能·深度学习
代码小书生1 小时前
math,一个基础的 Python 库!
人工智能·python·算法
AI科技星1 小时前
全域数学·数术本源·高维代数卷(72分册)【乖乖数学】
人工智能·算法·数学建模·数据挖掘·量子计算
生成论实验室2 小时前
《事件关系阴阳博弈动力学:识势应势之道》第一篇:生成正在发生——从《即事经》到事件-关系网络
人工智能·科技·算法·架构·创业创新
漂流瓶jz2 小时前
UVA-1152 和为0的4个值 题解答案代码 算法竞赛入门经典第二版
数据结构·算法·二分查找·题解·aoapc·算法竞赛入门经典·uva