神经网络模型---ResNet

一、ResNet

1.导入包

python 复制代码
import tensorflow as tf
from tensorflow.keras import layers, models, datasets, optimizers

optimizers是用于更新模型参数以最小化损失函数的算法

2.加载数据集、归一化、转为独热编码的内容一致

3.增加颜色通道

python 复制代码
train_images = train_images[..., tf.newaxis].astype("float32")
test_images = test_images[..., tf.newaxis].astype("float32")

在train_images和test_images最后一个维度增加一个新的维度
这两行代码还将图像数据转换为浮点数类型

4.定义一个用于图像预处理的模型

4.1创造模型

python 复制代码
preprocessing = models.Sequential([

4.2添加一个卷积层,该层有3个1x1的卷积核,激活函数为relu,并且指定了输入形状为28x28像素的单通道图像

python 复制代码
layers.Conv2D(3, (1, 1), activation='relu', input_shape=(28, 28, 1)),

4.3 将图像尺寸增加到56x56

python 复制代码
    layers.UpSampling2D((2, 2)), 
])

5.应用预处理模型到训练和测试图像上

python 复制代码
train_images = preprocessing(train_images)
test_images = preprocessing(test_images)

6.加载ResNet50模型并冻结所有层

python 复制代码
base_model=tf.keras.applications.ResNet50(weights='imagenet',include_top=False,input_shape=(56, 56, 3))

ResNet50是一个预训练的卷积神经网络模型,
参数1:加载模型的权重
参数2:是否包括模型顶部的全连接层,设置False意味着不包括这些层,由此可以得到模型的特征提取部分
参数3:输入图像的尺寸
base_model.trainable = False
使ResNet50模型的所有层都不可训练

7.创建模型

python 复制代码
model = models.Sequential([

7.1放在模型的第一层添加到序列中,用于提取图像特征

python 复制代码
base_model,

7.2在Keras中添加的一个全局平均池化层

python 复制代码
layers.GlobalAveragePooling2D(),

7.3在Keras中添加的一个全连接层,使用softmax为激活函数

python 复制代码
    layers.Dense(10, activation='softmax')
])

8.编译模型

python 复制代码
model.compile(optimizer=optimizers.Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

  • 和上一个博客的模型的内容一样,此处省略

9.训练模型

python 复制代码
model.fit(train_images, train_labels, epochs=10, batch_size=64, validation_data=(test_images, test_labels))

结果:

10.保存文件

python 复制代码
model.save('ResNet.h5')

结果:

相关推荐
qzhqbb11 分钟前
基于统计方法的语言模型
人工智能·语言模型·easyui
冷眼看人间恩怨36 分钟前
【话题讨论】AI大模型重塑软件开发:定义、应用、优势与挑战
人工智能·ai编程·软件开发
2401_8830410837 分钟前
新锐品牌电商代运营公司都有哪些?
大数据·人工智能
AI极客菌2 小时前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭2 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^2 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt
Power20246663 小时前
NLP论文速读|LongReward:基于AI反馈来提升长上下文大语言模型
人工智能·深度学习·机器学习·自然语言处理·nlp
数据猎手小k3 小时前
AIDOVECL数据集:包含超过15000张AI生成的车辆图像数据集,目的解决旨在解决眼水平分类和定位问题。
人工智能·分类·数据挖掘
好奇龙猫3 小时前
【学习AI-相关路程-mnist手写数字分类-win-硬件:windows-自我学习AI-实验步骤-全连接神经网络(BPnetwork)-操作流程(3) 】
人工智能·算法
沉下心来学鲁班3 小时前
复现LLM:带你从零认识语言模型
人工智能·语言模型