TensorFlow构建CNN卷积神经网络模型的基本步骤:数据处理、模型构建、模型训练

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称 项目名称
1.【人脸识别与管理系统开发 2.【车牌识别与自动收费管理系统开发
3.【手势识别系统开发 4.【人脸面部活体检测系统开发
5.【图片风格快速迁移软件开发 6.【人脸表表情识别系统
7.【YOLOv8多目标识别与自动标注软件开发 8.【基于YOLOv8深度学习的行人跌倒检测系统
9.【基于YOLOv8深度学习的PCB板缺陷检测系统 10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统
11.【基于YOLOv8深度学习的安全帽目标检测系统 12.【基于YOLOv8深度学习的120种犬类检测与识别系统
13.【基于YOLOv8深度学习的路面坑洞检测系统 14.【基于YOLOv8深度学习的火焰烟雾检测系统
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统 16.【基于YOLOv8深度学习的舰船目标分类检测系统
17.【基于YOLOv8深度学习的西红柿成熟度检测系统 18.【基于YOLOv8深度学习的血细胞检测与计数系统
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统 20.【基于YOLOv8深度学习的水稻害虫检测与识别系统
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统 22.【基于YOLOv8深度学习的路面标志线检测与识别系统
23.【基于YOLOv8深度学习的智能小麦害虫检测识别系统 24.【基于YOLOv8深度学习的智能玉米害虫检测识别系统
25.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统 26.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统
27.【基于YOLOv8深度学习的人脸面部表情识别系统 28.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统
29.【基于YOLOv8深度学习的智能肺炎诊断系统 30.【基于YOLOv8深度学习的葡萄簇目标检测系统
31.【基于YOLOv8深度学习的100种中草药智能识别系统 32.【基于YOLOv8深度学习的102种花卉智能识别系统
33.【基于YOLOv8深度学习的100种蝴蝶智能识别系统 34.【基于YOLOv8深度学习的水稻叶片病害智能诊断系统
35.【基于YOLOv8与ByteTrack的车辆行人多目标检测与追踪系统 36.【基于YOLOv8深度学习的智能草莓病害检测与分割系统
37.【基于YOLOv8深度学习的复杂场景下船舶目标检测系统 38.【基于YOLOv8深度学习的农作物幼苗与杂草检测系统
39.【基于YOLOv8深度学习的智能道路裂缝检测与分析系统 40.【基于YOLOv8深度学习的葡萄病害智能诊断与防治系统
41.【基于YOLOv8深度学习的遥感地理空间物体检测系统 42.【基于YOLOv8深度学习的无人机视角地面物体检测系统
43.【基于YOLOv8深度学习的木薯病害智能诊断与防治系统 44.【基于YOLOv8深度学习的野外火焰烟雾检测系统
45.【基于YOLOv8深度学习的脑肿瘤智能检测系统 46.【基于YOLOv8深度学习的玉米叶片病害智能诊断与防治系统
47.【基于YOLOv8深度学习的橙子病害智能诊断与防治系统 48.【基于深度学习的车辆检测追踪与流量计数系统
49.【基于深度学习的行人检测追踪与双向流量计数系统 50.【基于深度学习的反光衣检测与预警系统
51.【基于深度学习的危险区域人员闯入检测与报警系统 52.【基于深度学习的高密度人脸智能检测与统计系统
53.【基于深度学习的CT扫描图像肾结石智能检测系统 54.【基于深度学习的水果智能检测系统
55.【基于深度学习的水果质量好坏智能检测系统 56.【基于深度学习的蔬菜目标检测与识别系统
57.【基于深度学习的非机动车驾驶员头盔检测系统 58.【太基于深度学习的阳能电池板检测与分析系统
59.【基于深度学习的工业螺栓螺母检测 60.【基于深度学习的金属焊缝缺陷检测系统
61.【基于深度学习的链条缺陷检测与识别系统 62.【基于深度学习的交通信号灯检测识别
63.【基于深度学习的草莓成熟度检测与识别系统 64.【基于深度学习的水下海生物检测识别系统
65.【基于深度学习的道路交通事故检测识别系统 66.【基于深度学习的安检X光危险品检测与识别系统
67.【基于深度学习的农作物类别检测与识别系统 68.【基于深度学习的危险驾驶行为检测识别系统
69.【基于深度学习的维修工具检测识别系统 70.【基于深度学习的维修工具检测识别系统
71.【基于深度学习的建筑墙面损伤检测系统 72.【基于深度学习的煤矿传送带异物检测系统
73.【基于深度学习的老鼠智能检测系统

二、机器学习实战专栏【链接】 ,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
五、YOLOv8改进专栏【链接】持续更新中~~
六、YOLO性能对比专栏【链接】,持续更新中~

《------正文------》

目录

CNN python图像分类网络结构:

CNN的架构

简介

CNN(卷积神经网络)是计算机视觉的支柱。即使是更快的RCNN等对象检测算法也依赖CNN来执行其任务。因此,对CNN有一个很好的理解可以帮助你在计算机视觉领域出类拔萃。我们将使用CNN来执行图像分类。我将分享样板代码,你可以重用。

导入库

第一步是导入必要的库

python 复制代码
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

如果在导入时出现任何错误,请确保安装库。

数据处理

我的目录结构如下:

python 复制代码
images
   |____train
   |       |___class A
   |       |___class B
   |       |___class C
   |
   |-----test
   |       |___class A
   |       |___class B
   |       |___class C
       

我们将加载数据并执行基本的图像预处理 。由于收集的图像大小不同,因此需要对图像进行重新定位。选择150 * 150的图像尺寸。图像的像素值范围为0到255。为了提供更好的结果,像素值被重新缩放,因此所有值都在0和1之间。对现有的图像集执行不同的技术,例如水平翻转旋转

python 复制代码
batch_size = 130
IMG_SHAPE = 150

#Rescaling the images and applying horizontal flip
image_gen = ImageDataGenerator(rescale=1./255, horizontal_flip=True)
train_data_gen = image_gen.flow_from_directory(
batch_size=batch_size,
directory=train_dir,
shuffle=True,
target_size=(IMG_SHAPE,IMG_SHAPE)
)

#Rescaling the images and rotation it by 45 degree
image_gen = ImageDataGenerator(rescale=1./255, rotation_range=45)
train_data_gen = image_gen.flow_from_directory(batch_size=batch_size,
directory=train_dir,
shuffle=True,
target_size=(IMG_SHAPE, IMG_SHAPE))


train_data_gen = image_gen_train.flow_from_directory(
batch_size=batch_size,
directory=train_dir,
shuffle=True,
target_size=(IMG_SHAPE,IMG_SHAPE),
class_mode='sparse'
)


val_data_gen = image_gen_val.flow_from_directory(batch_size=batch_size,
directory=val_dir,
target_size=(IMG_SHAPE, IMG_SHAPE),
class_mode='sparse'))

构建模型

一旦数据预处理完成,我们就可以定义我们的模型架构。你可以使用不同的超参数卷积最大池化 应用于数据集,然后将其发送到输出层。模型被展平 。应用Dropout 以防止图像的过度拟合。我们正在使用准确性作为评估指标来编制模型。

卷积层的作用

python 复制代码
model = Sequential()
model.add(Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_SHAPE,IMG_SHAPE, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(32, 3, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, 3, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# Adding dropout to turn down some neurons
model.add(Flatten())
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(5, activation='softmax'))
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

训练模型

一旦我们定义了架构,我们就可以继续将数据馈送到CNN模型中来训练它。

python 复制代码
epochs = 120
history = model.fit_generator(
train_data_gen,
steps_per_epoch=int(np.ceil(train_data_gen.n / float(batch_size))),
epochs=epochs,
validation_data=val_data_gen,
validation_steps=int(np.ceil(val_data_gen.n / float(batch_size)))
)

可以重复使用提供的样板代码,并根据用例训练CNN模型。


好了,这篇文章就介绍到这里,喜欢的小伙伴感谢给点个赞和关注,更多精彩内容持续更新~~
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

相关推荐
珠海西格电力几秒前
零碳园区边缘计算节点规划:数字底座的硬件部署与能耗控制方案
运维·人工智能·物联网·能源·边缘计算
臼犀4 分钟前
孩子,那不是说明书,那是祈祷文
人工智能·程序员·markdown
黑客思维者5 分钟前
《关于深入实施 “人工智能 +“ 行动的意见》深度解读
人工智能
Sui_Network5 分钟前
Mysten Labs 与不丹王国政府的创新与技术部携手探索离线区块链
大数据·人工智能·web3·去中心化·区块链
互联科技报8 分钟前
GEO优化工具、AI搜索引擎优化软件平台实测报告:四大平台深度体验与选型指南
大数据·人工智能·搜索引擎
山东小木11 分钟前
AI智能问数(ChatBI)开发框架&解决方案&相关产品
人工智能·chatbi·智能问数·jboltai·javaai·ai问数·ai生图表
free-elcmacom16 分钟前
机器学习高阶教程<5>当机器学习遇上运筹学:破解商业决策的“终极难题”
人工智能·python·机器学习
Lun3866buzha27 分钟前
大型铸件表面缺陷检测与分类_YOLO11-C2BRA应用实践
人工智能·分类·数据挖掘
递归尽头是星辰29 分钟前
AI 驱动的报表系统:从传统到智能的落地与演进
大数据·人工智能·大模型应用·spring ai·ai 报表·报表智能化
Wang ruoxi30 分钟前
基于最小二乘法的离散数据拟合
人工智能·算法·机器学习