基于python开发用于深度学习模型训练过程loss值曲线的平滑处理模块

深度学习网络模型的loss曲线是训练过程中非常重要的一个监控指标,它能够直观地反映模型的学习状态以及可能存在的问题。以下是对深度学习网络模型loss曲线的详细介绍:

一、loss曲线的基本概念

在深度学习的训练过程中,loss函数用于衡量模型预测结果与实际标签之间的差异。loss曲线则是通过记录每个epoch(或者迭代步数)的loss值,并将其以图形化的方式展现出来,以便我们更好地理解和分析模型的训练过程。

二、loss曲线的解读

  1. loss值的变化趋势:
    • 如果loss值随着训练的进行而逐渐降低,说明模型正在学习并优化,这是一个正常的训练过程。
    • 如果loss值在训练初期迅速下降,但随后趋于稳定或波动较小,可能意味着模型已经收敛,或者陷入了局部最优解。
    • 如果loss值在训练过程中出现剧烈波动,可能是学习率设置不当、模型结构复杂度过高等原因导致的。
  2. 训练和验证loss的对比:
    • 训练loss和验证loss的差距可以反映模型的过拟合程度。如果训练loss持续下降而验证loss却开始上升,说明模型可能出现了过拟合现象。
    • 理想的训练过程应该是训练loss和验证loss都逐渐下降,且两者之间的差距较小。
  3. 不同阶段的loss变化:
    • 在训练初期,由于模型参数是随机初始化的,因此loss值通常会比较大。随着训练的进行,loss值会逐渐降低并趋于稳定。
    • 在训练后期,如果模型没有出现过拟合现象,loss值应该能够稳定在一个较低的水平上。

三、loss曲线的绘制与监控

在深度学习框架(如TensorFlow、PyTorch等)中,通常都提供了绘制loss曲线的功能。通过调用相应的API或库(如matplotlib、Visdom等),我们可以方便地绘制出训练过程中的loss曲线,并对其进行实时监控和分析。

四、loss曲线的优化策略

针对loss曲线反映出的问题,我们可以采取以下优化策略:

  1. 调整学习率:学习率是影响loss曲线变化的重要因素之一。如果学习率设置得过大,可能会导致loss值在训练过程中出现剧烈波动;如果学习率设置得过小,则可能会导致训练过程过于缓慢。因此,我们需要根据loss曲线的变化情况来适时调整学习率的大小。
  2. 添加正则化项:正则化项可以有效地防止模型过拟合。通过向损失函数中添加正则化项(如L1正则化、L2正则化等),我们可以限制模型参数的复杂度,从而降低过拟合的风险。
  3. 使用更复杂的模型结构:如果模型的复杂度不够高,可能无法充分拟合训练数据中的复杂模式。在这种情况下,我们可以尝试使用更复杂的模型结构(如增加网络层数、使用更复杂的激活函数等)来提高模型的拟合能力。
  4. 增加训练数据:增加训练数据可以提供更多的信息供模型学习,从而降低过拟合的风险。如果条件允许的话,我们可以尝试增加训练数据的数量或多样性来提高模型的性能。

实际工作中,经常会需要训练构建深度学习模型,相信做这块工作的同学对于loss曲线一定不会陌生的,大家肯定也都经常在模型开发过程中实际去绘制模型的loss曲线,在一些特殊的场景下需要对原始产生的loss曲线进行平滑处理,这里主要是记录实践这块的内容。

这里我以经常使用的keras框架,介绍下我常用的讲模型训练过程日志进行记录存储的方式,核心代码实现如下:

python 复制代码
#记录日志
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    #传入回调
    callbacks=[checkpoint],
    epochs=nepoch,
    batch_size=32,
)
print(history.history.keys())
# loss提取
lossdata, vallossdata = history.history["loss"], history.history["val_loss"]
# 绘制loss曲线
plot_both_loss_acc_pic(
    lossdata, vallossdata, picpath=saveDir + "train_val_loss.png"
)
history = {}
#提取训练过程对应的log
history["loss"], history["val_loss"] = lossdata, vallossdata
#存储日志数据
with open(saveDir + "history.json", "w") as f:
    f.write(json.dumps(history))

这里我给出history.json的样例数据,如下所示:

python 复制代码
{"loss": [0.0239631230070447, 0.0075705342770514186, 0.004838935030165967, 0.0037340148873459002, 0.002886130001751231, 0.0024534663854011295, 0.0023201104651917924, 0.002976924244579323, 0.002085769131966776, 0.0018753843622715224, 0.0019806173175960172, 0.002174197305382795, 0.001658159012761194, 0.001545024904081888, 0.001667008826952705, 0.0013947380403409929, 0.0012537746476829388, 0.0014786866023657216, 0.0016623390785946131, 0.0016191040555174983, 0.0014966548395261134, 0.001477120676483648, 0.0016280919435364668, 0.0017182350213351422, 0.0038554738028685545, 0.0027464564262130392, 0.0017087835722348526, 0.0014510032096478255, 0.001268975875018749, 0.001481830868523139, 0.001604654047318627, 0.0011948789410770326, 0.001490574051798416, 0.001524109376014187, 0.0015062743931394868, 0.0013054789145924908, 0.0011241542828178905, 0.0010764475793075279, 0.0011480460991939996, 0.0012678029520214276, 0.0012396599495106504, 0.0011639709738934618, 0.0012134075943700145, 0.0012499850020485322, 0.001329989843023205, 0.0011846670753724083, 0.001357856133803473, 0.0015265580290890642, 0.0012421558107066537, 0.001249045898042552, 0.0013697822622925414, 0.0010749583650784015, 0.0010974660338928532, 0.0010916401195769782, 0.0010911698460223627, 0.001078350035803241, 0.001045568893730859, 0.001084814094107926, 0.0011569271895574074, 0.0011443737715600441, 0.001247118570911225, 0.0012540589338988402, 0.0011518743058927274, 0.0015513227900919035, 0.0017111857056945697, 0.0015170943725414776, 0.001481423723410487, 0.0011165965530857377, 0.0016210588698042031, 0.002381780790270182, 0.0011541179547393296, 0.0013710694562572288, 0.0012280710985459404, 0.001037340645381916, 0.0010694707121014697, 0.0009750368017871479, 0.001008019566722004, 0.0011101727457727573, 0.0012511928422843225, 0.001071397447170223, 0.0011470449074543591, 0.0015238439674756194, 0.0010109543884446336, 0.0011297101726488506, 0.001058421874235954, 0.001103364821769398, 0.001025826505723811, 0.0010999036314539848, 0.001329398845137427, 0.0017114742325290903, 0.0011102726525873048, 0.0011274378092930091, 0.0011542693009646294, 0.0011940637438370937, 0.0012636104229160712, 0.0013925317771055863, 0.00100061368093664, 0.0011615896552567776, 0.0010081333990953022, 0.001092779955855081], "val_loss": [0.004923976918584422, 0.00991965542106252, 0.01076323433877214, 0.003843901578434988, 0.011352231488318036, 0.0016448196832482753, 0.0016787166668923179, 0.02015221753696862, 0.003941944209049995, 0.0029026116281257648, 0.0045372380556440665, 0.0014563935330155992, 0.001654032355260202, 0.0013641683946173688, 0.0015195327850769421, 0.0011578488043405262, 0.0012232080662598539, 0.00509458419768826, 0.0012246073744455843, 0.0023663273782738923, 0.0011423173363590124, 0.006865146876263775, 0.0020036918448137217, 0.007316410553788668, 0.001553758288929729, 0.0013593508740948317, 0.0025380967877266045, 0.0023082743653120765, 0.0013224915555359697, 0.00858367411909919, 0.0009927515703326974, 0.0010470627885201553, 0.0011798253622959907, 0.0024045295798905976, 0.0015412836871722614, 0.0038771789925368992, 0.0015362703578399592, 0.001756014192697445, 0.00334732801114258, 0.000975109149983741, 0.0046767660281866, 0.0018946981394516401, 0.0021767043220614524, 0.004211987026869074, 0.0009522750635177975, 0.0021094563270085734, 0.0037733877482088772, 0.001548874757549799, 0.0027838850510306656, 0.008273044527557335, 0.00123940688829048, 0.0016841785786183257, 0.0009756766973479994, 0.001928586675479126, 0.0011492695222075685, 0.0012013394433827336, 0.0010477521618380897, 0.00121309975940293, 0.0030147337820380926, 0.0013649057897150909, 0.0023210468165895067, 0.0011219763923068775, 0.0017544153219971217, 0.0030385015789713516, 0.0016239398731206739, 0.0031037202962723217, 0.002162101651590906, 0.003717466969484169, 0.0033957000386803165, 0.0009902583321826043, 0.00193247984708777, 0.001976198960389746, 0.0027693257654870028, 0.0025635553493262514, 0.0013357499459648113, 0.0012082410958675226, 0.001168333794186382, 0.0025652966841957286, 0.0010059437105634347, 0.0009358364489775054, 0.0036403173617528457, 0.0009317236960142556, 0.0015049418612187238, 0.0017247698554695634, 0.0010254738238903596, 0.001047871537096063, 0.0009514076437641818, 0.0036001800608478096, 0.014663169037942824, 0.002012193938227076, 0.001970677826351388, 0.0037272977164799445, 0.0012484785829560438, 0.002330363199182198, 0.0011025683723059237, 0.0013020975239525893, 0.001059662765137067, 0.0009167807317632986, 0.0009290355350362676, 0.0012791703282058924]}

接下来看下原始loss数据绘制出来的对比可视化曲线:

整体波形不断,也反映了模型实际训练过程并不够稳定,这里抛开模型训练的因素,单纯地基于曲线数据进行分析,想要对其进行平滑处理,得到的效果如下:

核心实现就是使用scipy.signal.savgol_filter方法,scipy.signal.savgol_filter 是 SciPy 库中的一个函数,用于对一维数据序列应用 Savitzky-Golay 平滑滤波器。这个滤波器是一种局部多项式回归的技术,能够在平滑数据的同时尽量保留数据的特征形状,如峰值和谷值,因此特别适用于信号去噪和数据平滑处理,尤其是那些包含噪声的实验数据或时序数据。

以下是 savgol_filter 函数的主要参数及其说明:

  • x (array_like):要过滤的一维数据序列。如果 x 不是单精度或双精度浮点数数组,它将在过滤前被转换为这种类型。

  • window_length (int):滤波器窗口的长度,即应用于数据点上的局部多项式拟合所使用的相邻数据点的数量。这个值必须是奇数,并且 polyorder + 1 <= window_length

  • polyorder (int):拟合局部数据点的多项式的阶数。它决定了平滑程度,阶数越高,可以拟合更复杂的曲线,但也会更多地改变原始数据的特性。必须满足 polyorder < window_length

  • deriv (int, 可选):指定是否计算导数以及计算哪阶导数。默认为0,表示直接平滑数据;大于0的值用于计算相应阶数的导数。

  • delta (float, 可选):采样点之间的间距,默认为1.0。仅在计算导数(deriv > 0)时使用。

  • axis (int, 可选):当输入数据 x 的维度大于1时,指定沿哪个轴应用滤波器。默认为-1,表示最后一个轴。

  • mode (str, 可选):决定如何处理边界效应,可选值有 'mirror''constant''nearest''wrap''interp'。默认为 'interp',表示通过线性插值来扩展数据以处理边界。选择 'mirror' 会在边界处镜像数据,而 'constant' 则会使用边缘值填充。

  • cval (float, 可选):当 mode'constant' 时使用的常数值。默认为0.0。

使用示例

假设我们有一个包含噪声的一维数据列表 data,我们可以使用 savgol_filter 来平滑这些数据:

python 复制代码
from scipy.signal import savgol_filter
import numpy as np

# 假设 data 是一个包含噪声的数据序列
data = np.random.randn(100)  # 生成随机噪声数据作为示例
window_length = 5  # 窗口长度
polyorder = 3  # 多项式阶数

# 应用 Savitzky-Golay 滤波器
smoothed_data = savgol_filter(data, window_length, polyorder)

# 然后可以绘制原始数据和平滑后的数据进行对比
import matplotlib.pyplot as plt

plt.figure()
plt.plot(data, label='Noisy data')
plt.plot(smoothed_data, label='Smoothed data')
plt.legend()
plt.show()

借助于scipy.signal.savgol_filter方法,我们可以非常方便快捷地实现对原生loss曲线的平滑化处理,这里为了直观对比效果,我们绘制对比可视化曲线,如下所示:

有需要的也都可以尝试下。

完整代码实现如下:

python 复制代码
def lossPloter(train_loss,val_loss):
    """
    loss曲线对比可视化
    """
    iters = range(len(train_loss))
    #单独绘制原始loss曲线
    plt.clf()
    plt.figure(figsize=(10,6))
    plt.plot(iters, train_loss, 'red', linewidth = 2, label='train loss')
    plt.plot(iters, val_loss, 'coral', linewidth = 2, label='val loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('A Loss Curve')
    plt.legend(loc="upper right")
    plt.savefig("original_loss.png")
    num = 5 if len(train_loss)<25 else 15
    #插值平滑处理
    train_loss_smooth=scipy.signal.savgol_filter(train_loss, num, 3)
    val_loss_smooth=scipy.signal.savgol_filter(val_loss, num, 3)
    for i in range(5):
        val_loss_smooth=scipy.signal.savgol_filter(val_loss_smooth, num, 3)
    #二者同时绘制
    plt.clf()
    plt.figure(figsize=(10,6))
    plt.plot(iters, train_loss, 'red', linewidth = 2, label='train loss')
    plt.plot(iters, val_loss, 'coral', linewidth = 2, label='val loss')
    plt.plot(iters, train_loss_smooth, 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
    plt.plot(iters, val_loss_smooth, '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('A Loss Curve')
    plt.legend(loc="upper right")
    plt.savefig("compare_loss.png")
    plt.cla()
    plt.close("all")
    #单独绘制平滑曲线
    plt.clf()
    plt.figure(figsize=(10,6))
    plt.plot(iters, train_loss_smooth, 'green', linewidth = 2, label='smooth train loss')
    plt.plot(iters, val_loss_smooth, 'blue', linewidth = 2, label='smooth val loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend(loc="upper right")
    plt.savefig("smooth_loss.png")

会得到三幅图像:

original_loss.png: 原始loss对比曲线

smooth_loss.png: 平滑化的loss对比曲线

compare_loss.png: 二者对比曲线

感兴趣的话可以尝试下!

相关推荐
程序猿小D13 分钟前
第二百六十七节 JPA教程 - JPA查询AND条件示例
java·开发语言·前端·数据库·windows·python·jpa
Yvemil718 分钟前
RabbitMQ 入门到精通指南
开发语言·后端·ruby
潘多编程27 分钟前
Java中的状态机实现:使用Spring State Machine管理复杂状态流转
java·开发语言·spring
冷静 包容1 小时前
C语言学习之 没有重复项数字的全排列
c语言·开发语言·学习
碳苯1 小时前
【rCore OS 开源操作系统】Rust 枚举与模式匹配
开发语言·人工智能·后端·rust·操作系统·os
结衣结衣.1 小时前
C++ 类和对象的初步介绍
java·开发语言·数据结构·c++·笔记·学习·算法
学习使我变快乐1 小时前
C++:静态成员
开发语言·c++
TJKFYY1 小时前
Java.数据结构.HashSet
java·开发语言·数据结构
杰哥在此2 小时前
Python知识点:如何使用Multiprocessing进行并行任务管理
linux·开发语言·python·面试·编程
小白学大数据2 小时前
User-Agent在WebMagic爬虫中的重要性
开发语言·爬虫·http