梯度下降算法的计算过程

1 小批量梯度下降(Mini-Batch Gradient Descent, MBGD)

  • 1.1划分数据集为多个小批量。
  • 1.2前向传播:对于每个小批量中的所有样本进行一次前向传播,得到预测输出。
  • 1.3计算损失:然后计算这些预测输出相对于真实标签的总损失。通常是累加每个样本的损失来完成。
  • 1.4反向传播:执行反向传播以计算当前小批量上损失函数关于模型参数的梯度,这是通过自动微分工具自动完成,它会为每一个参数计算出一个梯度值。
  • 1.5计算平均梯度
    • 前向传播:对于一个给定的小批量(mini-batch),假设包含m个样本。对于每个样本 x i {x}{i} xi,通过前向传播计算出预测值 y i ^ = f ( x i ; θ ) \hat{{y}{i}}=f({x}{i};\theta) yi^=f(xi;θ)。 y i ^ \hat{{y}{i}} yi^是关于样本值和模型参数的函数。
    • 计算损失:基于预定义的损失函数计算预测值和标签值的差异,即损失。损失函数形式为: J ( x i , y i ; θ ) = L ( y i ^ , y i ) J({x}{i},{y}{i};\theta)=L(\hat{{y}{i}}, {y}{i}) J(xi,yi;θ)=L(yi^,yi)。 J J J是关于 ( y i ^ , y i ) (\hat{{y}{i}}, {y}{i}) (yi^,yi)的函数。
    • 反向传播:基于链式法则,从输出层开始,逐层向后计算梯度。具体来说,对于每一层的参数 θ j \theta_{j} θj,计算该参数的梯度 ∇ θ j J ( x i , y i ; θ j ) \nabla_{\theta_{j}}J({x}{i},{y}{i};\theta_{j}) ∇θjJ(xi,yi;θj)
      ∂ L ∂ θ j = ∂ L ∂ y ^ ⋅ ∂ y ^ ∂ θ j \frac{\partial L}{\partial \theta_{j}}=\frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial \theta_{j}} ∂θj∂L=∂y^∂L⋅∂θj∂y^
      由于每个小批量有多个样本,反向传播会得到一组梯度值,最终结果取梯度的平均值。
      ∇ θ j J ˉ = 1 m ∑ i = 1 m ∇ θ j J ( x i , y i ; θ j ) \nabla_{\theta_{j}}\bar{J}=\frac{1}{m}\sum_{i=1}^{m}\nabla_{\theta_{j}}J({x}{i},{y}{i};\theta_{j}) ∇θjJˉ=m1∑i=1m∇θjJ(xi,yi;θj)
    • 参数更新:基于上述计算出的平均梯度更新模型参数。对于每个参数 θ j \theta_{j} θj,按照以下公式进行更新:
      θ j : = θ j − ϵ ∇ θ j J ˉ \theta_{j} :=\theta_{j} - \epsilon\nabla_{\theta_{j}}\bar{J} θj:=θj−ϵ∇θjJˉ,其中 ϵ \epsilon ϵ是模型学习率。

2 带动量的梯度下降

  • 2.1设置学习率 ϵ \epsilon ϵ和动量参数 α \alpha α。
  • 2.2 计算当前小批量的平均梯度
    g = 1 m ∑ i = 1 m ∇ θ j J ( x i , y i ; θ j ) g=\frac{1}{m}\sum_{i=1}^{m}\nabla_{\theta_{j}}J({x}{i},{y}{i};\theta_{j}) g=m1∑i=1m∇θjJ(xi,yi;θj)
  • 2.3 计算速度更新
    ν ← α ν − ϵ g \nu \gets \alpha\nu - \epsilon g ν←αν−ϵg
  • 2.4更新参数
    θ ← θ + ν \theta \gets \theta + \nu θ←θ+ν
相关推荐
旺仔.29112 分钟前
STL排序算法详解
数据结构·算法·排序算法
美狐美颜sdk19 分钟前
美颜SDK是什么?直播/短视频美颜SDK技术详解
人工智能·算法·美颜sdk·直播美颜sdk·美颜api
章鱼丸-28 分钟前
DAY 42 Grad-CAM 与 Hook 函数
pytorch·深度学习·计算机视觉
剑穗挂着新流苏31228 分钟前
207_深度学习调优:透彻理解权重衰退(L2 正则化)
人工智能·机器学习
机器学习之心30 分钟前
多工况车速数据集训练BiGRU双向门控循环单元用于车速预测,输出未来多个时间步车速,MATLAB代码
深度学习·matlab·双向门控循环单元·gru·bigru·车速预测
这张生成的图像能检测吗34 分钟前
(论文速读)FDGLM:面向多场景工业故障诊断的深度数字双动力大视觉语言模型
人工智能·深度学习·计算机视觉·故障诊断·视觉语言大模型·问答模型
华农DrLai39 分钟前
什么是远程监督?怎么自动生成训练数据?
人工智能·算法·llm·prompt·知识图谱
计算机安禾41 分钟前
【数据结构与算法】第16篇:串(String)的定长顺序存储与朴素模式匹配
c语言·数据结构·c++·学习·算法·visual studio code·visual studio
Roselind_Yi41 分钟前
【吴恩达2026 Agentic AI】面试向+项目实战(含面试题+项目案例)-2
人工智能·python·机器学习·面试·职场和发展·langchain·agent
AI科技星1 小时前
基于v≡c公设的理论优化方案
c语言·开发语言·算法·机器学习·数据挖掘