文章目录
解决问题
针对经典猫狗数据集,基于卷积神经网络,构建猫狗二元分类模型,使用数据集进行参数训练,模型评估,然后使用模型进行分类预测,最后对模型进行保存,供后续使用。
数据集
数据集来源
探索性数据分析
查看待训练识别图片
from matplotlib import pyplot as plt
import os
import random
# 获取文件名
_,_,cat_images = next(os.walk('../../dataset/kagglecatsanddogs_5340/PetImages/Cat'))
# 准备3*3 图表
fig, ax = plt.subplots(3, 3, figsize=(20, 10))
# 随机选择一幅图像并绘制
for idx, img in enumerate(random.sample(cat_images, 9)):
img_read = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Cat/' + img)
ax[int(idx / 3), idx % 3].imshow(img_read)
ax[int(idx / 3), idx % 3].set_title('cat/' + img)
ax[int(idx / 3), idx % 3].axis('off')
plt.show()
查看狗图片类似,将Cat目录换成Dog即可
数据预处理
数据集分割
由于下载的图片猫和狗各在一个文件夹内,如下:
需要将数据按80%:20%进行分割,分为训练集和测试集。目录结构如下:
下面进行数据拆分,核心代码(以猫图片为例)如下:
# 训练数据集80% 测试数据集20%
train_size = 0.8
# 获取猫图像数量
_, _, cat_images = next(os.walk(src_folder+'Cat/'))
num_cat_images = len(cat_images)
num_cat_images_train = int(train_size * num_cat_images)
num_cat_images_test = num_cat_images - num_cat_images_train
# 分割猫图像
cat_train_images = random.sample(cat_images, num_cat_images_train)
for img in cat_train_images:
shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Train/Cat/')
cat_test_images = [img for img in cat_images if img not in cat_train_images]
for img in cat_test_images:
shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Test/Cat/')
数据预处理
这一步要将分割后的数据集转成和模型结构匹配的数据类型。使用keras提供的ImageDataGenerator类和flow_from_directory()方法
ImageDataGenerator类:图像增强类,可以进行图像旋转、图像平移、水平翻转、图像缩放等操作;
flow_from_directory()方法:ImageDataGenerator类的方法,支持以图像路径为输入,按批次加载图像到内存,防止训练数据量过大,机器内存不足问题;还支持对图像进行预处理操作,例如尺寸缩放和图像增强
# 训练数据预处理
training_data_generator = ImageDataGenerator(rescale=1./255)
training_set = training_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/train/',target_size=(32, 32),batch_size=16,class_mode='binary')
# 测试数据预处理
testing_data_generator = ImageDataGenerator(rescale= 1./255)
testing_set = testing_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/test/',target_size=(32, 32),batch_size=16, class_mode='binary')
构建模型并训练
构建模型
# 定义超参数
# 特征滤波器尺寸
FILTER_SIZE = 3
# 特征滤波器数量
FILTER_NUM = 32
# 图片输入尺寸
INPUT_SIZE = 32
# 最大池化尺寸
MAXPOOL_SIZE = 2
# 批量处理图片的大小
BATCH_SIZE = 16
STEPS_PER_EPOCH = 20000 // BATCH_SIZE
# 训练轮次
EPOCHS = 10
# 定义模型
model = Sequential()
# 添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 再添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 对输出结果进行降维处理,转成一维张量
model.add(Flatten())
# 添加全链接层,根据特征进行分类预测
model.add(Dense(units=128, activation='relu'))
# 添加dropout层,随机将一部分输入设置为0,防止模型复杂,出现过拟合现象
model.add(Dropout(0.5))
# 添加输出层,一个节点
model.add(Dense(units=1, activation='sigmoid'))
该模型结构分为,卷积池化层,卷积池化层,Flatten层,全链接层1,全链接层2(输出层)如下:
其中,第一列是神经网络的层,第二列是每层的输出形状,第三层是每层训练的参数
可以看到,该模型图像输入尺寸是(32,32),经过一层卷积(32个特征过滤器)输出为(30,30,32),经过一层最大池化层,输出为(15,15,32);其中特征滤波器尺寸为3*3,所以滤波后的尺寸会是32-(3-1)=30,经过最大池化(2x2)尺寸减半,为15。
训练模型
# 模型训练
model.fit(training_set, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS, verbose=1)
结果分析与评估
model.evaluate(testing_set,steps=len(testing_set),verbose=1)
准确度达到了0.7856
模型保存
from joblib import dump, load
# 模型持久化 到磁盘
dump(model, './猫狗分类.onnx')
结果预测
引入保存模型,随机选取一张图片进行预测分类
from matplotlib import pyplot as plt
fig, ax = plt.subplots()
img = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg')
ax.imshow(img)
plt.show()
from joblib import dump, load
model = load('./猫狗分类.onnx')
from tensorflow.keras.preprocessing.image import img_to_array,load_img
img = load_img('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg',target_size=(32,32))
img = img_to_array(img)
img /= 255
import numpy as np
img_array = np.expand_dims(img, axis=0)
print(img_array.shape)
model.predict(img_array)
由于是二元分类,0和1分别表示猫狗,输出概率接近表示是狗,接近0表示是猫狗。但具体为啥0表示猫1表示狗而不是反过来表示,还待研究。
经验总结
1 在使用next()加载图像时,要确保路径正确,否则会报StopIteration错误,原因是路径错误,找不到可迭代的数据。