吴恩达2022机器学习专项课程C2W2:2.17 TensorFlow实现 & 2.18 训练细节

这里写目录标题

本周任务

神经网络如何训练?上周的内容,我们没有关注每一层的最合适的w,b是如何计算出来的,只是直接拿了现成的合适的结果,这周要关注这些参数是如何计算出最优解的。

TensorFlow训练神经网络模型的简要过程

  • :如果给定一个训练集,包含图像 X 和其对应的真实标签 Y,你会如何训练这个神经网络的参数?
  • :第一步,使用Sequential构键一个前向传播的3层神经网络。第二步,指定BinaryCrossentropy(二元交叉熵)为编译模型需要的损失函数。第三步,调用fit函数训练模型。

训练模型的三个步骤

1.自行训练逻辑回归模型

(1)定义模型:定义将输入特征 x 通过参数w和b计算输出的公式,即sigmoid函数。

(2)指定损失函数和成本函数:为了训练逻辑回归模型,需要定义损失函数和成本函数,损失函数(loss)衡量单个训练样本的表现,成本函数则是所有训练样本的平均损失(损失函数求和,然后除以训练样本数m)。

(3)最小化成本函数:使用梯度下降算法来更新参数w和b,以最小化成本函数。

2.TensorFlow训练神经网络模型

(1)定义模型:使用Sequential定义模型。

(2)指定损失函数和成本函数:编译模型时指定二元交叉熵损失函数,该函数对每个训练样本计算损失,整个训练集的平均损失就是神经网络的成本函数。

(3)最小化成本函数:调用fit来最小化神经网络的成本。

TensorFlow训练神经网络模型的代码含义

1.定义模型

Tensorflow指定了神经网络的整体架构,在已经每层参数w,b的情况下,TensorFlow能够计算出最终输出。

2.指定损失函数和成本函数

(1)分类问题:如果解决二分类问题,可以指定TensorFlow内置的损失函数BinaryCrossentropy(二元交叉熵),TensorFlow会自动处理损失函数的构建和最小化过程。


(2)回归问题:解决回归问题可以指定损失函数MeanSquaredError(平方误差)损失函数。

(3)如果自己构建成本函数,需要用到神经网络里的所有参数。

3.最小化成本函数

如果我们自己使用梯度下降来训练神经网络的参数,那么你将重复地对每一层 l 和每个单元 j 更新 w_lj,根据 w_lj 减去学习率 alpha 乘以成本函数 j 的参数导数,以及对参数 b 也是类似的更新。而TensorFlow可以我们完成所有这些工作。它在这个名为 fit 的函数中实现了反向传播。你只需要调用 model.fit,使用x,y 作为你的训练集,并设置100次迭代或100个周期。

总结

首先,介绍了TensorFlow训练神经网络的基本思路,并与逻辑回归模型的训练步骤进行了简要对比,强调了在神经网络训练中需要关注的关键步骤。然后,详细描述了TensorFlow训练神经网络的每个步骤的含义,强调了TensorFlow在简化神经网络训练过程中的重要性。

Quiz

Quiz1

哪种类型的任务中使用二元交叉熵损失函数(BinaryCrossentropy)?

A:BinaryCrossentropy()不应用于任何任务。

B:具有3个或更多类别(类别)的分类任务

C:二元分类(正好有2个分类)

D:回归任务(预测数字的任务)

正确答案C:二元交叉熵,我们也称之为逻辑损失,用于分类两个类别(两个类别)之间。

Quiz2


哪行代码用于更新网络参数以减少成本?

A:model = Sequential([...)

B:以上都不是-这段代码不会更新参数

C:model.fit(X, y, epochs=100)

D:model.compile(loss=BinaryCrossentropy()

正确答案C:模型训练的第三步是在数据上训练模型以最小化损失(和成本)。

相关推荐
荒古前27 分钟前
龟兔赛跑 PTA
c语言·算法
Colinnian30 分钟前
Codeforces Round 994 (Div. 2)-D题
算法·动态规划
用户00993831430136 分钟前
代码随想录算法训练营第十三天 | 二叉树part01
数据结构·算法
shinelord明40 分钟前
【再谈设计模式】享元模式~对象共享的优化妙手
开发语言·数据结构·算法·设计模式·软件工程
დ旧言~1 小时前
专题八:背包问题
算法·leetcode·动态规划·推荐算法
_WndProc1 小时前
C++ 日志输出
开发语言·c++·算法
biter00881 小时前
opencv(15) OpenCV背景减除器(Background Subtractors)学习
人工智能·opencv·学习
努力学习编程的伍大侠1 小时前
基础排序算法
数据结构·c++·算法
qq_529025292 小时前
Torch.gather
python·深度学习·机器学习