Tensorflow——第三讲神经网络八股

前两讲我们学习了使用tensorflow原生代码搭建神经网络,本讲主要学习使用Tensorflow API:tf.keras****搭建神经网络

一、搭建网络八股Sequential

六步法:

1.import:import 相关模块,如 import tensorflow as tf

2.train, test:指定输入网络的训练集和测试集,如指定训练集的输入 x_train 和标签

y_train,测试集的输入 x_test 和标签 y_test。

3.model = tf.keras.models.Sequential:逐层搭建网络结构

4.model.compile:在 model.compile()中配置训练方法,选择训练时使用的优化器、损失

函数和最终评价指标。

5.model.fit:在 model.fit()中执行训练过程,告知训练集和测试集的输入值和标签、

每个 batch 的大小(batchsize)和数据集的迭代次数(epoch)

6.model.summary:使用 model.summary()打印网络结构,统计参数数目。

model = tf.keras.models.Sequential的使用:

model.compile的使用

:from_logits=False:神经网络末端如果使用了softmax函数,输出为概率分布而不是原始输出,from_logits就为false,否则为True

model.fit()的使用

model.summary()的使用

二、搭建网络八股class

用Sequential能搭建上层输入就是下层输出的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构,这个时候我们可以选择用类class搭建神经网络结构。

class的使用 :

对比 Sequential和class搭建神经网络的过程:

以实现鸢尾花分类为例

Sequential

python 复制代码
import tensorflow as tf
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

model.summary()

class

python 复制代码
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

class IrisModel(Model):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, x):
        y = self.d1(x)
        return y

model = IrisModel()

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

三、MNIST数据集 ---手写数字识别训练

1.数据集的介绍

(1)MNIST数据集:
提供 6万张 28*28 像素点的0~9手写数字图片和标签,用于训练。
提供 1万张 28*28 像素点的0~9手写数字图片和标签,用于测试。

(2)导入MNIST数据集:

mnist = tf.keras.datasets.mnist

(x_train, y_train) , (x_test, y_test) = mnist.load_data()

(3)作为输入特征,输入神经网络时,将数据拉伸为一维数组:

tf.keras.layers.Flatten( )

0 0 0 48 238 252 252 ...... ...... ...... 253 186 12 0 0 0 0 0

注:不知道这里大家有没有这样一个疑问,为什么鸢尾花的数据集不需要拉伸:

原因:鸢尾花数据集不需要拉直为一维是因为它的特征已经是数值型的,可以直接用于机器学习模型的训练和预测。而手写数字数据需要拉直为一维是因为它们的原始数据是图像形式的,需要通过转换才能被机器学习算法处理。

(4)观察数据集

2.代码实现书写数字识别

python 复制代码
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
python 复制代码
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


class MnistModel(Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.flatten(x)
        x = self.d1(x)
        y = self.d2(x)
        return y


model = MnistModel()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

后面还有FASHION数据集数据集,与MNIST数据集处理方式类似,就不再赘述。

相关推荐
说私域1 小时前
日本零售精髓赋能下 链动2+1模式驱动新零售本质回归与发展格局研究
人工智能·小程序·数据挖掘·回归·流量运营·零售·私域运营
千里马也想飞1 小时前
汉语言文学《朝花夕拾》叙事艺术研究论文写作实操:AI 辅助快速完成框架 + 正文创作
人工智能
玉梅小洋1 小时前
解决 VS Code Claude Code 插件「Allow this bash command_」弹窗问题
人工智能·ai·大模型·ai编程
肾透侧视攻城狮1 小时前
《解锁计算机视觉:深度解析 PyTorch torchvision 核心与进阶技巧》
人工智能·深度学习·计算机视觉模快·支持的数据集类型·常用变换方法分类·图像分类流程实战·视觉模快高级功能
一战成名9961 小时前
AI 模型持续集成流水线:CANN 支持的 DevOps 最佳实践
人工智能·ci/cd·devops
23遇见1 小时前
AI视角下的 CANN 仓库架构全解析:高效计算的核心
人工智能
有趣的杰克1 小时前
开源|macOS 菜单栏 AI 启动器 GroAsk:⌥Space 一键直达 ChatGPT / Claude / Gemini
人工智能·macos·chatgpt
yumgpkpm1 小时前
预测:2026年大数据软件+AI大模型的发展趋势
大数据·人工智能·算法·zookeeper·kafka·开源·cloudera
星爷AG I1 小时前
11-2 距离知觉(AGI基础理论)
人工智能·agi
算法狗21 小时前
大模型面试题:在混合精度训练中如何选择合适的精度
人工智能·深度学习·机器学习·语言模型