Tensorflow——第三讲神经网络八股

前两讲我们学习了使用tensorflow原生代码搭建神经网络,本讲主要学习使用Tensorflow API:tf.keras****搭建神经网络

一、搭建网络八股Sequential

六步法:

1.import:import 相关模块,如 import tensorflow as tf

2.train, test:指定输入网络的训练集和测试集,如指定训练集的输入 x_train 和标签

y_train,测试集的输入 x_test 和标签 y_test。

3.model = tf.keras.models.Sequential:逐层搭建网络结构

4.model.compile:在 model.compile()中配置训练方法,选择训练时使用的优化器、损失

函数和最终评价指标。

5.model.fit:在 model.fit()中执行训练过程,告知训练集和测试集的输入值和标签、

每个 batch 的大小(batchsize)和数据集的迭代次数(epoch)

6.model.summary:使用 model.summary()打印网络结构,统计参数数目。

model = tf.keras.models.Sequential的使用:

model.compile的使用

:from_logits=False:神经网络末端如果使用了softmax函数,输出为概率分布而不是原始输出,from_logits就为false,否则为True

model.fit()的使用

model.summary()的使用

二、搭建网络八股class

用Sequential能搭建上层输入就是下层输出的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构,这个时候我们可以选择用类class搭建神经网络结构。

class的使用 :

对比 Sequential和class搭建神经网络的过程:

以实现鸢尾花分类为例

Sequential

python 复制代码
import tensorflow as tf
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

model.summary()

class

python 复制代码
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

class IrisModel(Model):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, x):
        y = self.d1(x)
        return y

model = IrisModel()

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()

三、MNIST数据集 ---手写数字识别训练

1.数据集的介绍

(1)MNIST数据集:
提供 6万张 28*28 像素点的0~9手写数字图片和标签,用于训练。
提供 1万张 28*28 像素点的0~9手写数字图片和标签,用于测试。

(2)导入MNIST数据集:

mnist = tf.keras.datasets.mnist

(x_train, y_train) , (x_test, y_test) = mnist.load_data()

(3)作为输入特征,输入神经网络时,将数据拉伸为一维数组:

tf.keras.layers.Flatten( )

[ 0 0 0 48 238 252 252 ...... ...... ...... 253 186 12 0 0 0 0 0]

注:不知道这里大家有没有这样一个疑问,为什么鸢尾花的数据集不需要拉伸:

原因:鸢尾花数据集不需要拉直为一维是因为它的特征已经是数值型的,可以直接用于机器学习模型的训练和预测。而手写数字数据需要拉直为一维是因为它们的原始数据是图像形式的,需要通过转换才能被机器学习算法处理。

(4)观察数据集

2.代码实现书写数字识别

python 复制代码
import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
python 复制代码
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


class MnistModel(Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.flatten(x)
        x = self.d1(x)
        y = self.d2(x)
        return y


model = MnistModel()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

后面还有FASHION数据集数据集,与MNIST数据集处理方式类似,就不再赘述。

相关推荐
少说多想勤做26 分钟前
【计算机视觉前沿研究 热点 顶会】ECCV 2024中Mamba有关的论文
人工智能·计算机视觉·目标跟踪·论文笔记·mamba·状态空间模型·eccv
宜向华1 小时前
opencv 实现两个图片的拼接去重功能
人工智能·opencv·计算机视觉
OpenVINO生态社区2 小时前
【了解ADC差分非线性(DNL)错误】
人工智能
醉后才知酒浓2 小时前
图像处理之蒸馏
图像处理·人工智能·深度学习·计算机视觉
炸弹气旋3 小时前
基于CNN卷积神经网络迁移学习的图像识别实现
人工智能·深度学习·神经网络·计算机视觉·cnn·自动驾驶·迁移学习
python_知世3 小时前
时下改变AI的6大NLP语言模型
人工智能·深度学习·自然语言处理·nlp·大语言模型·ai大模型·大模型应用
愤怒的可乐3 小时前
Sentence-BERT实现文本匹配【CoSENT损失】
人工智能·深度学习·bert
冻感糕人~3 小时前
HRGraph: 利用大型语言模型(LLMs)构建基于信息传播的HR数据知识图谱与职位推荐
人工智能·深度学习·自然语言处理·知识图谱·ai大模型·llms·大模型应用
花生糖@3 小时前
Midjourney即将推出的AI生视频产品:CEO洞见分享
人工智能·ai·aigc·midjourney
小言从不摸鱼3 小时前
【NLP自然语言处理】文本处理的基本方法
人工智能·python·自然语言处理