吴恩达2022机器学习专项课程C2W2:2.17 TensorFlow实现 & 2.18 训练细节

这里写目录标题

本周任务

神经网络如何训练?上周的内容,我们没有关注每一层的最合适的w,b是如何计算出来的,只是直接拿了现成的合适的结果,这周要关注这些参数是如何计算出最优解的。

TensorFlow训练神经网络模型的简要过程

  • :如果给定一个训练集,包含图像 X 和其对应的真实标签 Y,你会如何训练这个神经网络的参数?
  • :第一步,使用Sequential构键一个前向传播的3层神经网络。第二步,指定BinaryCrossentropy(二元交叉熵)为编译模型需要的损失函数。第三步,调用fit函数训练模型。

训练模型的三个步骤

1.自行训练逻辑回归模型

(1)定义模型:定义将输入特征 x 通过参数w和b计算输出的公式,即sigmoid函数。

(2)指定损失函数和成本函数:为了训练逻辑回归模型,需要定义损失函数和成本函数,损失函数(loss)衡量单个训练样本的表现,成本函数则是所有训练样本的平均损失(损失函数求和,然后除以训练样本数m)。

(3)最小化成本函数:使用梯度下降算法来更新参数w和b,以最小化成本函数。

2.TensorFlow训练神经网络模型

(1)定义模型:使用Sequential定义模型。

(2)指定损失函数和成本函数:编译模型时指定二元交叉熵损失函数,该函数对每个训练样本计算损失,整个训练集的平均损失就是神经网络的成本函数。

(3)最小化成本函数:调用fit来最小化神经网络的成本。

TensorFlow训练神经网络模型的代码含义

1.定义模型

Tensorflow指定了神经网络的整体架构,在已经每层参数w,b的情况下,TensorFlow能够计算出最终输出。

2.指定损失函数和成本函数

(1)分类问题:如果解决二分类问题,可以指定TensorFlow内置的损失函数BinaryCrossentropy(二元交叉熵),TensorFlow会自动处理损失函数的构建和最小化过程。


(2)回归问题:解决回归问题可以指定损失函数MeanSquaredError(平方误差)损失函数。

(3)如果自己构建成本函数,需要用到神经网络里的所有参数。

3.最小化成本函数

如果我们自己使用梯度下降来训练神经网络的参数,那么你将重复地对每一层 l 和每个单元 j 更新 w_lj,根据 w_lj 减去学习率 alpha 乘以成本函数 j 的参数导数,以及对参数 b 也是类似的更新。而TensorFlow可以我们完成所有这些工作。它在这个名为 fit 的函数中实现了反向传播。你只需要调用 model.fit,使用x,y 作为你的训练集,并设置100次迭代或100个周期。

总结

首先,介绍了TensorFlow训练神经网络的基本思路,并与逻辑回归模型的训练步骤进行了简要对比,强调了在神经网络训练中需要关注的关键步骤。然后,详细描述了TensorFlow训练神经网络的每个步骤的含义,强调了TensorFlow在简化神经网络训练过程中的重要性。

Quiz

Quiz1

哪种类型的任务中使用二元交叉熵损失函数(BinaryCrossentropy)?

A:BinaryCrossentropy()不应用于任何任务。

B:具有3个或更多类别(类别)的分类任务

C:二元分类(正好有2个分类)

D:回归任务(预测数字的任务)

正确答案C:二元交叉熵,我们也称之为逻辑损失,用于分类两个类别(两个类别)之间。

Quiz2


哪行代码用于更新网络参数以减少成本?

A:model = Sequential([...)

B:以上都不是-这段代码不会更新参数

C:model.fit(X, y, epochs=100)

D:model.compile(loss=BinaryCrossentropy()

正确答案C:模型训练的第三步是在数据上训练模型以最小化损失(和成本)。

相关推荐
想跑步的小弱鸡5 小时前
Leetcode hot 100(day 3)
算法·leetcode·职场和发展
Uzuki5 小时前
AI可解释性 II | Saliency Maps-based 归因方法(Attribution)论文导读(持续更新)
深度学习·机器学习·可解释性
xyliiiiiL6 小时前
ZGC初步了解
java·jvm·算法
爱的叹息7 小时前
RedisTemplate 的 6 个可配置序列化器属性对比
算法·哈希算法
蹦蹦跳跳真可爱5897 小时前
Python----机器学习(KNN:使用数学方法实现KNN)
人工智能·python·机器学习
独好紫罗兰7 小时前
洛谷题单2-P5713 【深基3.例5】洛谷团队系统-python-流程图重构
开发语言·python·算法
zhuyixiangyyds8 小时前
day21和day22学习Pandas库
笔记·学习·pandas
每次的天空8 小时前
Android学习总结之算法篇四(字符串)
android·学习·算法
请来次降维打击!!!8 小时前
优选算法系列(5.位运算)
java·前端·c++·算法
qystca9 小时前
蓝桥云客 刷题统计
算法·模拟