【ShuQiHere】从零开始实现线性回归:深入理解反向传播与梯度下降

【ShuQiHere】

线性回归是一种简单但强大的回归分析方法,主要用于预测连续值。它在许多领域都有广泛的应用,尤其是当我们需要根据已有数据来预测未来的趋势时,线性回归显得尤为重要。虽然它是机器学习中最基础的算法之一,但理解其原理对掌握更复杂的算法至关重要。本文将带你一步步从零开始实现线性回归,并深入探讨反向传播与梯度下降这两个核心算法,帮助你打下扎实的基础。

线性回归的数学基础

线性回归的目标是找到一个线性函数,该函数能够尽可能准确地预测目标变量 ( Y ) 的值。这个线性函数的形式如下:

[
Y = W X + b Y = WX + b Y=WX+b

]

在这个公式中:

  • ( W ) 是权重向量,表示每个输入特征对输出的影响程度。
  • ( X ) 是输入特征向量,即我们用来进行预测的输入数据。
  • ( b ) 是偏置项,它帮助模型更好地拟合数据,特别是在输入特征的值很小时。

简单来说,线性回归试图找到一条直线(或者在多维情况下的一个超平面),使得这条线尽可能接近数据点。

损失函数

为了衡量模型的预测值与实际值之间的差异,我们使用均方误差(Mean Squared Error, MSE)作为损失函数。均方误差计算的是预测值与真实值之间的平均平方差异:

[
MSE = 1 2 m ∑ i = 1 m ( Y ( i ) − Y ^ ( i ) ) 2 \text{MSE} = \frac{1}{2m} \sum_{i=1}^{m} (Y^{(i)} - \hat{Y}^{(i)})^2 MSE=2m1i=1∑m(Y(i)−Y^(i))2

]

在这里:

  • ( m ) 是样本数量。
  • ( \hat{Y}^{(i)} ) 是第 ( i ) 个样本的预测值。
  • ( Y^{(i)} ) 是第 ( i ) 个样本的真实值。

损失函数的值越小,说明模型的预测值越接近真实值。我们希望通过调整模型参数来最小化这个损失。

反向传播(Backward Propagation)

反向传播是一种计算梯度的算法,用于指导模型如何更新参数以减小损失。在反向传播中,我们首先计算损失函数相对于模型参数(即 ( W ) 和 ( b ) )的导数(梯度)。这些梯度表示了损失函数相对于权重和偏置的变化速率,它们告诉我们如何调整这些参数才能更快地减小损失。

对于线性回归模型,损失函数对权重 ( W ) 和偏置 ( b ) 的梯度计算如下:

[
∂ Loss ∂ W = 1 m ∑ i = 1 m ( Y ( i ) − Y ^ ( i ) ) × X ( i ) \frac{\partial \text{Loss}}{\partial W} = \frac{1}{m} \sum_{i=1}^{m} (Y^{(i)} - \hat{Y}^{(i)}) \times X^{(i)} ∂W∂Loss=m1i=1∑m(Y(i)−Y^(i))×X(i)

]

[
∂ Loss ∂ b = 1 m ∑ i = 1 m ( Y ( i ) − Y ^ ( i ) ) \frac{\partial \text{Loss}}{\partial b} = \frac{1}{m} \sum_{i=1}^{m} (Y^{(i)} - \hat{Y}^{(i)}) ∂b∂Loss=m1i=1∑m(Y(i)−Y^(i))

]

这些公式的含义是:我们计算每个样本的预测误差,然后根据这些误差的方向和大小来更新模型的权重和偏置,以减少整体损失。

反向传播的实现

python 复制代码
def backward_propagation(X, A, Y):
    m = X.shape[1]
    dw = (1/m) * np.dot(X, (A - Y).T)
    db = (1/m) * np.sum(A - Y)
    return dw, db

在这个实现中,dw 是权重的梯度,db 是偏置的梯度。我们通过这些梯度来指导参数的更新。

梯度下降(Gradient Descent)

梯度下降是一种迭代优化算法,用于更新模型的参数,使损失函数达到最小值。其核心思想是沿着损失函数下降最快的方向(即梯度的反方向)更新参数。

梯度下降的更新规则如下:

[
W : = W − α × ∂ Loss ∂ W W := W - \alpha \times \frac{\partial \text{Loss}}{\partial W} W:=W−α×∂W∂Loss

]

[
b : = b − α × ∂ Loss ∂ b b := b - \alpha \times \frac{\partial \text{Loss}}{\partial b} b:=b−α×∂b∂Loss

]

在这里:

  • ( \alpha ) 是学习率,它决定了每次参数更新的步伐大小。如果学习率太大,模型可能会跳过最优解;如果学习率太小,收敛速度会很慢。

梯度下降的实现

python 复制代码
def update_parameters(w, b, dw, db, learning_rate):
    w = w - learning_rate * dw
    b = b - learning_rate * db
    return w, b

在这个实现中,我们使用计算得到的梯度 dwdb 来更新权重 w 和偏置 b ,以逐步减小损失。

线性回归的实现步骤

理解了反向传播和梯度下降之后,我们可以开始实现一个完整的线性回归模型。我们将按照以下步骤进行实现:

1. 初始化参数

首先,我们需要初始化模型的权重 ( W ) 和偏置 ( b )。在这个例子中,我们将权重初始化为零或小的随机值,偏置初始化为零。

python 复制代码
def initialize_parameters(dim):
    w = np.zeros((dim, 1))
    b = 0
    return w, b

2. 前向传播

前向传播用于计算模型的输出(预测值)。这是通过将输入特征与权重相乘并加上偏置来实现的。

python 复制代码
def forward_propagation(X, w, b):
    return np.dot(w.T, X) + b

3. 计算损失

我们使用均方误差来衡量预测值与真实值之间的差异。这一步骤非常关键,因为它告诉我们当前模型的性能如何。

python 复制代码
def compute_loss(A, Y):
    m = Y.shape[1]
    loss = (1/(2*m)) * np.sum((A - Y) ** 2)
    return loss

4. 模型训练

在模型训练过程中,我们通过多次迭代,不断进行前向传播、计算损失、反向传播以及参数更新,最终得到一个能够准确预测的模型。

python 复制代码
def train(X, Y, num_iterations, learning_rate):
    w, b = initialize_parameters(X.shape[0])
    for i in range(num_iterations):
        A = forward_propagation(X, w, b)
        loss = compute_loss(A, Y)
        dw, db = backward_propagation(X, A, Y)
        w, b = update_parameters(w, b, dw, db, learning_rate)
        if i % 100 == 0:
            print(f"Iteration {i}, Loss: {loss}")
    return w, b

在这个实现中,我们通过多次迭代来训练模型,并在每 100 次迭代时输出当前的损失值,以便跟踪模型的学习进度。

5. 模型评估

最后,我们可以使用均方误差来评估模型的性能,查看模型在测试集上的表现如何。

python 复制代码
def evaluate(X, Y, w, b):
    A = forward_propagation(X, w, b)
    return compute_loss(A, Y)

结论

线性回归虽然简单,但它是机器学习中至关重要的基础模型。通过深入理解其实现过程中的反向传播和梯度下降,我们可以更好地理解机器学习的核心思想。这些知识不仅有助于掌握线性回归的实现,还为学习更复杂的机器学习模型打下了坚实的基础。希望本文的讲解能帮助你更好地理解线性回归,并激发你对机器学习更深层次的探索欲望。

相关推荐
带多刺的玫瑰41 分钟前
Leecode刷题C语言之统计不是特殊数字的数字数量
java·c语言·算法
爱敲代码的憨仔42 分钟前
《线性代数的本质》
线性代数·算法·决策树
yigan_Eins1 小时前
【数论】莫比乌斯函数及其反演
c++·经验分享·算法
阿史大杯茶1 小时前
AtCoder Beginner Contest 381(ABCDEF 题)视频讲解
数据结构·c++·算法
დ旧言~2 小时前
【高阶数据结构】图论
算法·深度优先·广度优先·宽度优先·推荐算法
张彦峰ZYF2 小时前
投资策略规划最优决策分析
分布式·算法·金融
The_Ticker2 小时前
CFD平台如何接入实时行情源
java·大数据·数据库·人工智能·算法·区块链·软件工程
爪哇学长3 小时前
双指针算法详解:原理、应用场景及代码示例
java·数据结构·算法
Dola_Pan3 小时前
C语言:数组转换指针的时机
c语言·开发语言·算法
IT古董3 小时前
【人工智能】Python在机器学习与人工智能中的应用
开发语言·人工智能·python·机器学习