深度学习与神经网络Pytorch版 3.2 线性回归从零开始实现 1.生成数据集

3.2 线性回归从零开始实现

目录

[3.2 线性回归从零开始实现](#3.2 线性回归从零开始实现)

[一 ,简介](#一 ,简介)

[1. 原理](#1. 原理)

[2. 步骤](#2. 步骤)

[3. 优缺点](#3. 优缺点)

[4. 应用场景](#4. 应用场景)

[二 ,代码展现](#二 ,代码展现)

[1. 生成数据集(完整代码)](#1. 生成数据集(完整代码))

[2. 各个函数解析](#2. 各个函数解析)

[2.1 torch.normal()函数](#2.1 torch.normal()函数)

[2.2 torch.matmul()函数](#2.2 torch.matmul()函数)

[2.3 d2l.plt.scatter()函数](#2.3 d2l.plt.scatter()函数)

[三 ,总结](#三 ,总结)


一 ,简介

1. 原理

深度学习线性回归的原理是基于神经网络和线性回归的结合。它使用神经网络来构建一个复杂的非线性模型,同时保持线性回归的简单性和可解释性。

在深度学习线性回归中,通常使用全连接神经网络(Fully Connected Neural Network)作为基础结构。输入数据经过一系列的线性变换和非线性激活函数,最终输出预测结果。与传统的线性回归不同,深度学习线性回归可以自动学习特征之间的复杂交互和组合,而不需要手动选择或设计特征。

深度学习线性回归的训练过程与传统神经网络的训练过程类似,使用梯度下降算法优化模型的参数,以最小化预测误差(如均方误差)。在训练过程中,通过反向传播算法计算梯度,并使用优化器(如Adam、SGD等)更新权重和偏置项。

深度学习线性回归的优点是可以处理高维、复杂的非线性数据,并且具有自动特征选择和组合的能力。然而,与传统的线性回归相比,深度学习线性回归需要更多的参数和计算资源,并且可能更容易过拟合。因此,在选择是否使用深度学习线性回归时,需要根据具体问题和数据集的特点进行权衡。

2. 步骤

线性回归从零开始实现步骤包括以下内容:

  1. 导入必要的库:在Python中,需要导入numpy库来处理数据和计算,以及matplotlib库来绘制数据和结果。
  2. 生成数据集:根据实际问题,可以使用随机数生成器生成一组训练数据集,包括输入特征和对应的标签。也可以使用真实数据集进行训练和测试。
  3. 初始化模型参数:为模型权重和偏置项设置初始值,这些初始值可以是随机数或基于先验知识的值。
  4. 定义模型:根据线性回归模型的公式,可以使用numpy的矩阵运算来构建模型。模型可以表示为y = w * x + b,其中x是输入特征,y是对应的标签,w是权重,b是偏置项。
  5. 计算损失函数:损失函数用于衡量模型的预测值与真实值之间的差距。对于线性回归问题,常用的损失函数是均方误差(MSE)。
  6. 执行梯度下降算法:梯度下降算法用于更新模型的参数以最小化损失函数。在每一步迭代中,根据梯度下降公式计算参数的更新方向和步长,并更新参数的值。
  7. 训练:重复执行步骤5和6,直到达到预设的迭代次数或损失函数达到可接受的值。
  8. 评估模型:使用测试数据集评估模型的性能,计算模型的预测值与真实值之间的误差或准确率等指标。
  9. 优化和调整:根据评估结果对模型进行调整,例如调整参数、增加特征或使用正则化等方法来提高模型的性能。
  10. 应用模型进行预测:将新数据输入到模型中进行预测,得到预测结果。

以上是线性回归从零开始实现的基本步骤,具体实现细节可能会根据问题和数据集的不同而有所差异。

3. 优缺点

线性回归的优点:

  1. 简单易行:线性回归模型简单易懂,实现起来也相对容易。
  2. 计算效率高:由于模型简单,计算复杂度较低,因此在线性回归中,无论是训练还是预测,计算速度都比较快。
  3. 可解释性强:线性回归模型可以给出每个特征的权重,这有助于理解特征对目标变量的影响程度。
  4. 适合处理线性关系:线性回归适合处理因变量和自变量之间存在线性关系的情况。
  5. 模型稳定性好:线性回归模型相对稳定,对异常值和噪声的鲁棒性较好。

然而,线性回归也存在一些缺点:

  1. 假设限制:线性回归基于一些假设,如误差项的独立性、同方差性、无序列相关性和常数方差等。在实际应用中,这些假设可能不成立,导致模型误判。
  2. 欠拟合与过拟合:如果线性模型过于简单(即过于欠拟合),它可能无法捕获数据的复杂模式;而如果模型过于复杂(即过拟合),它可能会捕获到数据中的噪声和无关紧要的信息。
  3. 无法处理非线性关系:对于非线性关系的数据,线性回归可能无法给出很好的预测。
  4. 对异常值敏感:如果数据集中存在异常值,线性回归模型的预测结果可能会受到影响。
  5. 特征选择困难:对于特征之间的交互和特征选择,线性回归模型可能会遇到困难。

4. 应用场景

线性回归的应用场景包括但不限于:

  1. 预测:当因变量是连续变量,并且与其影响因素有线性关系时,可以用线性回归进行建模。例如,预测信用卡用户的生命周期价值,可以基于用户所在小区的平均收入、年龄、学历、收入等因素进行线性回归建模。
  2. 模型解释:当需要理解自变量与因变量之间的关系时,可以通过建立线性回归模型,例如决策树、线性回归等模型,以自变量作为输入变量,以因变量作为目标变量进行建模,以此了解黑盒模型的运作机制,并对其作出解释。
  3. 全量实验效果评估:全量实验评估是指当在时间点时,对全量用户加入干预策略,然后评估策略所带来的影响。进行评估时,核心是要剥离其他因素,对实验效果进行评估,线性回归就能解决这个问题。
  4. AB实验:在AB实验中,假定有两组无差异的用户群体和,以作为实验组对其施加策略干预,作为对照组不采取施加任何策略,来评估实验对观测变量的影响。可以通过t或z检验来得到结果,当然也可以建立线性回归模型 ,为是否为实验组的哑变量(当策略变多时,也可为分类变量),通过检验参数的显著性即可得到策略的效果。
  5. 预测疾病发生概率:医院可以根据患者的病历数据(如体检指标、药物复用情况、平时的饮食习惯等)来预测某种疾病发生的概率。
  6. 预测用户支付转化率:网站可以根据访问的历史数据(包括新用户的注册量、老用户的活跃度、网站内容的更新频率等)来预测用户的支付转化率。

以上只是部分应用场景,线性回归模型的应用非常广泛,具体应用取决于数据的特征和业务需求。

二 ,代码展现

1. 生成数据集(完整代码)

python 复制代码
# 线性回归从零开始实现
# 生成数据集

# 导入必要的库
import matplotlib.pyplot as plt
import random
import torch
from d2l import torch as d2l


# 定义一个生成合成数据的函数
def synthetic_data(w, b, num_examples):    # 函数参数包括权重w、偏置b和数据点数量num_examples
    # 生成y=Xw+b+噪声满足线性关系y=Xw+b的数据,并添加噪声
    X = torch.normal(0, 1, (num_examples, len(w)))  # 创建一个形状为(num_examples, len(w))的张量X,元素值为从标准正态分布中抽取的随机数
    y = torch.matmul(X, w) + b  # 使用矩阵乘法计算y的值,y = X * w + b
    y += torch.normal(0, 0.01, y.shape)  # 在y的值上添加从标准正态分布中抽取的随机噪声,噪声的标准差为0.01
    return X, y.reshape((-1, 1))  # 返回X和y。y被重新整形为(-1, 1)的形状,这是因为matplotlib在绘图时需要这样的形状


# 定义真实的权重和偏置值
true_w = torch.tensor([2, -3.4])  # 真实的权重w为[2, -3.4]的张量
true_b = 4.2  # 真实的偏置b为4.2的标量

# 使用上面定义的函数生成数据集
features, labels = synthetic_data(true_w, true_b, 1000)  # 生成1000个数据点作为训练或测试样本,特征为X,标签为y(即labels)

print('features:', features[0],'\nlabel:', labels[0])
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1) 
# 这行代码也是从d2l库中调用的。它使用散点图来可视化特征和标签。#
# features[:, (1)].detach().numpy()选取了所有数据点的第二个特征(索引为1,因为索引是从0开始的)并转换为NumPy数组。
# #.detach()是PyTorch中的方法,用于从计算图中分离张量,这样张量就不会追踪其历史计算,这在进行绘图等操作时是很有用的。
# labels.detach().numpy()将标签转换为NumPy数组。这里的1表示散点的大小。
plt.show()

2. 各个函数解析

2.1 torch.normal()函数
python 复制代码
normal(mean, std, *, generator=None, out=None)

参数说明

  • mean (Tensor): 每个输出元素的均值。它是一个张量,其中包含各个分布的均值。
  • std (Tensor): 每个输出元素的标准差。它也是一个张量,其中包含各个分布的标准差。
  • *: 表示后面的参数是关键字参数。
  • generator: 可选参数,一个伪随机数生成器。
  • out: 可选参数,输出张量。

注意事项

  1. meanstd的形状不必匹配,但它们的元素总数必须相同。如果形状不匹配,将使用mean的形状作为返回输出张量的形状。
  2. 如果std是一个CUDA tensor,该函数将同步其设备与CPU。
2.2 torch.matmul()函数
python 复制代码
matmul(input, other, *, out=None) -> Tensor

参数说明:

  • input (Tensor): 输入张量。
  • other (Tensor): 另一个张量。
  • *: 表示后面的参数是关键字参数。
  • out (Tensor, optional): 可选参数,输出张量。

行为取决于张量的维度如下:

  • 如果两个张量都是一维的,返回点积(标量)。
  • 如果两个参数都是二维的,返回矩阵-矩阵乘积。
  • 如果第一个参数是一维的,而第二个参数是二维的,为了矩阵乘法,向其维度添加一个1。矩阵乘法之后,添加的维度被移除。
  • 如果第一个参数是二维的,而第二个参数是一维的,返回矩阵-向量乘积。
  • 如果两个参数都至少是一维的,并且至少有一个参数是N维的(其中N>2),则返回批处理矩阵乘法。如果第一个参数是一维的,为了批处理矩阵乘法,向其维度添加一个1,然后在批处理矩阵乘法之后移除它。如果第二个参数是一维的,为了批处理矩阵乘法,向其维度添加一个1,然后在批处理矩阵乘法之后移除它。非矩阵(即批处理)维度是广播的(因此必须可广播)。例如,如果input是一个(j × 1 × n × n)张量,而other是一个(k × n × n)张量,则out将是一个(j × k × n × n)张量。
2.3 d2l.plt.scatter()函数
python 复制代码
scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs)

参数说明:

  • x, y:这些是您要在散点图中表示的数据点的x和y坐标。

  • s:散点的面积,以像素为单位。这通常用于根据数据点的值进行大小调整。

  • c:用于颜色映射的单个值或数组,通常表示颜色或数据点的值。

  • marker :散点的形状。例如,'o'表示圆形,'.'表示点,','表示像素等。

  • cmap :颜色映射对象或名称。这决定了如何根据c参数的值映射颜色。

  • norm :用于映射到给定范围的归一化对象。这通常与cmap一起使用,以控制颜色映射的范围。

  • vmin, vmax :这些参数指定了归一化对象的下限和上限。它们与norm一起使用来控制颜色映射的范围。

  • alpha:散点的透明度。值范围从0(完全透明)到1(完全不透明)。

  • linewidths:用于绘制边框线的宽度。当不为None时,这会使散点变为带边框的圆圈。

  • edgecolors:用于边框线的颜色。这可以是单一的颜色或颜色数组,与数据点一一对应。

  • plotnonfinite:如果为True,则非有限数值的数据点将被绘制。默认为False。

  • data:提供给所有数据的原始数据的字典。这通常在传递给函数的数据不是直接参数时使用。

  • kwargs :其他关键字参数将传递给collections.PathCollection的构造函数,允许您自定义散点图的其他方面。例如,您可以指定label来在图例中标识这些点等。

三 ,总结

这段代码的主要目的是生成数据集,并使用散点图可视化其特征和标签。通过这种方式,可以直观地观察到数据分布和特征之间的关系。此外,代码还演示了如何使用PyTorch进行矩阵运算和NumPy数组转换,以及如何使用d2l库中的函数进行绘图操作。

之后我会更新,线性回归的读取数据集,初始化模型参数,定义模型,定义模型,定义损失函数,定义优化算法,训练等步骤。

相关推荐
WeeJot嵌入式几秒前
线性代数与数据挖掘:人工智能中的核心工具
人工智能·线性代数·数据挖掘
明明真系叻23 分钟前
第二十二周机器学习笔记:动手深度学习之——线性代数
笔记·深度学习·线性代数·机器学习·1024程序员节
星光樱梦39 分钟前
02. Python基础知识
python
亚图跨际42 分钟前
MATLAB和C++及Python流式细胞术
c++·python·matlab·流式细胞术
steamedobun1 小时前
【爬虫】Firecrawl对京东热卖网信息爬取(仅供学习)
爬虫·python
右恩1 小时前
Docker 实践与应用举例
python·docker
凤枭香1 小时前
Python Scikit-learn简介(二)
开发语言·python·机器学习·scikit-learn
AI小白龙*1 小时前
Windows环境下搭建Qwen开发环境
人工智能·windows·自然语言处理·llm·llama·ai大模型·ollama
cetcht88881 小时前
光伏电站项目-视频监控、微气象及安全警卫系统
运维·人工智能·物联网
惯师科技1 小时前
TDK推出第二代用于汽车安全应用的6轴IMU
人工智能·安全·机器人·汽车·imu