深度学习 (线性回归 简洁实现)

介绍:

在线性神经网络中,线性回归是一种常见的任务,用于预测一个连续的数值输出。其目标是根据输入特征来拟合一个线性函数,使得预测值与真实值之间的误差最小化。

线性回归的数学表达式为:

y = w1x1 + w2x2 + ... + wnxn + b

其中,y表示预测的输出值,x1, x2, ..., xn表示输入特征,w1, w2, ..., wn表示特征的权重,b表示偏置项。

训练线性回归模型的目标是找到最优的权重和偏置项,使得模型预测的输出与真实值之间的平方差(即损失函数)最小化。这一最优化问题可以通过梯度下降等优化算法来解决。

线性回归在深度学习中也被广泛应用,特别是在浅层神经网络中。在深度学习中,通过将多个线性回归模型组合在一起,可以构建更复杂的神经网络结构,以解决更复杂的问题。

深度学习 线性神经网络(线性回归 从零开始实现)-CSDN博客

生成数据集:

python 复制代码
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l

true_w = d2l.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

读取小批量数据集:

python 复制代码
#选取小批量样本
def load_array(data_arrays,batch_size,is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset,batch_size,shuffle=is_train)

定义模型:

python 复制代码
from torch import nn#"nn"是神经网络的缩写

net = nn.Sequential(nn.Linear(2,1))#输入维度2,输出维度1

定义损失函数:

python 复制代码
loss = nn.MSELoss()#均分误差函数

定义优化函数(实例化SGD):

python 复制代码
#实例化SGD
trainer = torch.optim.SGD(net.parameters(),lr=0.03)#参数、学习率

模型训练:

python 复制代码
num_epochs=8
for epoch in range(num_epochs):
    for X, y in data_iter:#拿出一批量x,y
        l = loss(net(X), y)  # X和y的小批量损失,实际的和预测的
        trainer.zero_grad()
        l.backward()
        trainer.step()  # 使用参数的梯度更新参数
    l = loss(net(features),labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

'''
epoch 1, loss 0.000175
epoch 2, loss 0.000096
epoch 3, loss 0.000095
epoch 4, loss 0.000095
epoch 5, loss 0.000095
epoch 6, loss 0.000095
epoch 7, loss 0.000095
epoch 8, loss 0.000096
'''

print(net[0].weight)
'''
Parameter containing:
tensor([[ 2.0004, -3.3990]], requires_grad=True)
'''

print(net[0].bias)
'''
Parameter containing:
tensor([4.2007], requires_grad=True)

'''
相关推荐
ting945200034 分钟前
Tornado 全栈技术深度指南:从原理到实战
人工智能·python·架构·tornado
果汁华1 小时前
Browserbase Skills:让 Claude Agent 真正“看见“网页世界
人工智能·python
ZhengEnCi1 小时前
04-缩放点积注意力代码实现 💻
人工智能·python
HackTwoHub2 小时前
AI大模型网关存在SQL注入、附 POC 复现、影响版本LiteLLM 1.81.16~1.83.7(CVE-2026-42208)
数据库·人工智能·sql·网络安全·系统安全·网络攻击模型·安全架构
段一凡-华北理工大学2 小时前
【高炉炼铁领域炉温监测、预警、调控智能体设计与应用】~系列文章08:多模态数据融合:让数据更聪明
人工智能·python·高炉炼铁·ai赋能·工业智能体·高炉炉温
网络工程小王2 小时前
【LangChain 大模型6大调用指南】调用大模型篇
linux·运维·服务器·人工智能·学习
HIT_Weston3 小时前
63、【Agent】【OpenCode】用户对话提示词(示例)
人工智能·agent·opencode
CV-杨帆3 小时前
Phi-4-mini-flash-reasoning 部署安装与推理测试完整记录
人工智能
MediaTea3 小时前
AI 术语通俗词典:C4.5 算法
人工智能·算法
海兰3 小时前
【第27篇】Micrometer + Zipkin
人工智能·spring boot·alibaba·spring ai