【李宏毅机器学习·学习笔记】Deep Learning General Guidance

本节课可视为机器学习系列课程的一个前期攻略,这节课主要对Machine Learning 的框架进行了简单的介绍;并以training data上的loss大小为切入点,介绍了几种常见的在模型训练的过程中容易出现的情况。

课程视频:

Youtube: https://www.youtube.com/watch?v=WeHM2xpYQpw
课程PPT:
https://view.officeapps.live.com/op/view.aspx?src=https%3A%2F%2Fspeech.ee.ntu.edu.tw%2F~hylee%2Fml%2Fml2021-course-data%2Foverfit-v6.pptx&wdOrigin=BROWSELINK

以下是本节课的课程笔记。

一、Framework of ML

机器学习的数据集总体上分为训练集 (training data)和测试集 (testing data)。其中训练集由feature x 和ground truth y组成,模型在训练集上学习x和y之间的隐含关系,再在测试集上对模型的好坏进行验证。

模型在训练集上的training大致可以分为以下三个steps:

Step1:初步划定一个model set:y = f(x),其中模型 f 由系列参数 𝜽 确定,如果𝜽的值不同,我们则说模型不同。

Step2:划定好model set后就需要定义一个loss function 来对模型的好坏进行评估,通常,loss function反映的是模型的预测值和ground truth之间的差距,差距越小(loss值越小),则模型越好。

Step3:定义好loss function后,就开始对模型进行优化,找到让loss指最小的参数集合𝜽*,𝜽* 所对应的model f* 即为我们最终想要学习到的模型。

二、General Guide

在训练模型的过程中,我们往往会根据training data上的loss值来初步判断模型的好坏。

1.training data上loss过大

导致training data上loss值过大的原因主要有以下两个:

(1)model bias

即模型模型太简单(大海捞针,但针不在海里),通常的解决措施是重新设计模型使其具有更大的弹性,例如,在输入中增加更多的feature,或者使用deep learning以增加模型的弹性(more neurons, layers)。

(2)optimization

optimization做得不好,没有找到最优的function(大海捞针,针在海里,但就是没捞到)。例如我们通常使用gradient decent的optimization方法,但这种方法可能会卡在local minimum的地方,从而导致我们没有找到全局最小解。如果是optimization做的不好,我们需要使用更powerful的optimization方法,这在后面的学习中会有介绍。

Q:如何判断训练集上的loss大时由model bias还是optimization引起的?

(参考文献:Deep Residual Learning for Image Recognition

主要是通过对不同的模型进行比较来判断(判断模型是否足够大)。当我们看到一个从来没有做过的问题,可以先跑一些比较浅的network,甚至一些不属于DL的方法,因为这些方法不太会有optimization失败的问题。如果在训练集上deeper network反而没有得到更小的loss,则可能是optimization出了问题。(注意:过拟合是deeper network在训练集上loss小,在测试集上loss大)

例如,在下图右部分,56-layer的loss值较之20-layer的反而更大,则很可能是opyimization出了问题。

2.training data上loss值较小

如果在training data上的loss值比较小,则可以看看模型在测试集上的表现了。如果测试值上loss值很小,那这正是我们期待的结果。如果很不幸,模型在测试集上loss较大,此时又可大致分为两种情况:

(1)overfitting

overfitting即模型过度地对训练数据进行了拟合,把一些非common feature当做common data学习到了。此事的solution主要有:

A. more training data,即增加更多的训练数据;

B. data augmentation,如果训练数据有限,则可以在原有数据的基础上,通过一些特殊处理,创造一些资料。

C. make your model simpler,常见的举措有:

  • less parameters/ sharing parameters (让一些model共用参数)
  • early stopping
  • regularization
  • dropout

(2)mismatch

mismatch则是由于训练资料和测试资料的分布不一致导致的,这个时候增加训练资料也没用。在HW11中会具体讲解这类情况。

三、如何保证选择的model是合理的

如果一个模型只是在训练集上强行将输入x 和ground truth y相关联,而没有学习到一些实质性的东西,那么到了测试集上模型的表现将会是很差的。通常的解决措施是引入交叉验证。

1.Cross Validation

在训练时,我们从测试集中划出部分数据作为validation set来衡量loss,根据validation set上的得分情况去挑选最优的模型,再在测试集上对模型的好坏进行验证。

2.N-fold Cross Validation

如果训练数据较少,可采用N折交叉验证的方法。即将训练数据分为N等份,依次以第一份、第二份......第 i 份作为验证集(其余作为测试集),这样重复N次对模型进行训练、验证。在将这N次训练中各个模型在验证集上的N次得分的平均进行比较,选择loss最小的模型作为我们的最优模型,并用它在测试集上对模型进行评分。

相关推荐
moxiaoran57537 分钟前
uni-app学习笔记十八--uni-app static目录简介
笔记·学习·uni-app
#guiyin1113 分钟前
基于机器学习的心脏病预测模型构建与可解释性分析
人工智能·机器学习
不会敲代码的灵长类20 分钟前
机器学习算法-k-means
算法·机器学习·kmeans
Studying 开龙wu21 分钟前
机器学习有监督学习sklearn实战二:六种算法对鸢尾花(Iris)数据集进行分类和特征可视化
学习·算法·机器学习
IMA小队长28 分钟前
06.概念二:神经网络
人工智能·深度学习·机器学习·transformer
罗西的思考31 分钟前
探秘Transformer系列之(35)--- 大模型量化基础
人工智能·深度学习·机器学习
Lester_11011 小时前
嵌入式学习笔记 - STM32 HAL库以及标准库内核以及外设头文件区别问题
笔记·stm32·单片机·学习
Moonnnn.2 小时前
2023年电赛C题——电感电容测量装置
笔记·学习·硬件工程
拾忆-eleven3 小时前
NLP学习路线图(十六):N-gram模型
人工智能·学习·自然语言处理
编程有点难3 小时前
Python训练打卡Day39
人工智能·python·深度学习