Tensorflow实现深度学习案例7:咖啡豆识别

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

一、前期工作

1. 导入数据

python 复制代码
from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlib
import tensorflow as tf
import warnings as w
w.filterwarnings('ignore')

data_dir = "./coffee/"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.png')))

print("图片总数为:",image_count)
复制代码
图片总数为: 1200

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

python 复制代码
batch_size = 32
img_height = 224
img_width = 224
python 复制代码
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.

Using 960 files for training.

python 复制代码
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.

Using 240 files for validation.

python 复制代码
class_names = train_ds.class_names
print(class_names)
复制代码
['Dark', 'Green', 'Light', 'Medium']

2.数据可视化

python 复制代码
plt.figure(figsize=(10, 4))  # 图形的宽为10高为5

for images, labels in train_ds.take(1):
    for i in range(10):
        
        ax = plt.subplot(2, 5, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")
python 复制代码
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
python 复制代码
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(32, 224, 224, 3)

(32,)

3. 配置数据集

  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
python 复制代码
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

并且将数据归一化

python 复制代码
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y))

image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]

# 查看归一化后的数据
print(np.min(first_image), np.max(first_image))

0.0 1.0

三、构建VGG-16网络

1.VGG优缺点分析:

  • VGG优点

VGG的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)

  • VGG缺点

1)训练时间过长,调参难度大。2)需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

2.网络结构图

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcXpredictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16****包含了16个隐藏层(13个卷积层和3个全连接层),故称为 VGG-16

python 复制代码
model = tf.keras.applications.VGG16(weights='imagenet')
model.summary()
复制代码
Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 fc1 (Dense)                 (None, 4096)              102764544 
                                                                 
 fc2 (Dense)                 (None, 4096)              16781312  
                                                                 
 predictions (Dense)         (None, 1000)              4097000   
                                                                 
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
python 复制代码
# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=30,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型

python 复制代码
epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
复制代码
Epoch 1/20
30/30 [==============================] - 346s 11s/step - loss: 1.7546 - accuracy: 0.2625 - val_loss: 1.4646 - val_accuracy: 0.2125
Epoch 2/20
30/30 [==============================] - 352s 12s/step - loss: 1.3637 - accuracy: 0.3104 - val_loss: 1.0428 - val_accuracy: 0.4583
Epoch 3/20
30/30 [==============================] - 338s 11s/step - loss: 0.7237 - accuracy: 0.6458 - val_loss: 0.4818 - val_accuracy: 0.7833
Epoch 4/20
30/30 [==============================] - 336s 11s/step - loss: 0.3633 - accuracy: 0.8479 - val_loss: 1.1034 - val_accuracy: 0.6167
Epoch 5/20
30/30 [==============================] - 340s 11s/step - loss: 0.2880 - accuracy: 0.8927 - val_loss: 0.1480 - val_accuracy: 0.9500
Epoch 6/20
30/30 [==============================] - 338s 11s/step - loss: 0.1802 - accuracy: 0.9333 - val_loss: 0.4709 - val_accuracy: 0.8458
Epoch 7/20
30/30 [==============================] - 334s 11s/step - loss: 0.1468 - accuracy: 0.9490 - val_loss: 0.0214 - val_accuracy: 1.0000
Epoch 8/20
30/30 [==============================] - 339s 11s/step - loss: 0.0174 - accuracy: 0.9969 - val_loss: 0.0196 - val_accuracy: 0.9875
Epoch 9/20
30/30 [==============================] - 329s 11s/step - loss: 0.0399 - accuracy: 0.9875 - val_loss: 0.2539 - val_accuracy: 0.9292
Epoch 10/20
30/30 [==============================] - 330s 11s/step - loss: 0.2606 - accuracy: 0.9073 - val_loss: 0.0737 - val_accuracy: 0.9917
Epoch 11/20
30/30 [==============================] - 334s 11s/step - loss: 0.0610 - accuracy: 0.9812 - val_loss: 0.0070 - val_accuracy: 1.0000
Epoch 12/20
30/30 [==============================] - 341s 11s/step - loss: 0.0296 - accuracy: 0.9917 - val_loss: 0.0256 - val_accuracy: 0.9875
Epoch 13/20
30/30 [==============================] - 335s 11s/step - loss: 0.0252 - accuracy: 0.9917 - val_loss: 0.0431 - val_accuracy: 0.9833
Epoch 14/20
30/30 [==============================] - 345s 12s/step - loss: 0.0058 - accuracy: 0.9979 - val_loss: 0.0088 - val_accuracy: 0.9958
Epoch 15/20
30/30 [==============================] - 557s 19s/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0144 - val_accuracy: 0.9917
Epoch 16/20
30/30 [==============================] - 340s 11s/step - loss: 3.6823e-04 - accuracy: 1.0000 - val_loss: 0.0052 - val_accuracy: 0.9958
Epoch 17/20
30/30 [==============================] - 347s 12s/step - loss: 5.9116e-05 - accuracy: 1.0000 - val_loss: 0.0064 - val_accuracy: 0.9958
Epoch 18/20
30/30 [==============================] - 347s 12s/step - loss: 2.5309e-05 - accuracy: 1.0000 - val_loss: 0.0048 - val_accuracy: 0.9958
Epoch 19/20
30/30 [==============================] - 350s 12s/step - loss: 1.0864e-05 - accuracy: 1.0000 - val_loss: 0.0033 - val_accuracy: 1.0000
Epoch 20/20
30/30 [==============================] - 341s 11s/step - loss: 6.0013e-06 - accuracy: 1.0000 - val_loss: 0.0045 - val_accuracy: 0.9958

六 可视化结果

python 复制代码
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

预测图片

python 复制代码
import numpy as np

# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5
plt.suptitle("预测结果展示")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(1,8, i + 1)  
        
        # 显示图片
        plt.imshow(images[i].numpy())
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测图片中的人物
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")
复制代码
1/1 [==============================] - 0s 279ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 118ms/step
1/1 [==============================] - 0s 109ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 104ms/step
1/1 [==============================] - 0s 111ms/step
1/1 [==============================] - 0s 115ms/step
相关推荐
Tony聊跨境10 分钟前
独立站SEO类型及优化:来检查这些方面你有没有落下
网络·人工智能·tcp/ip·ip
懒惰才能让科技进步16 分钟前
从零学习大模型(十二)-----基于梯度的重要性剪枝(Gradient-based Pruning)
人工智能·深度学习·学习·算法·chatgpt·transformer·剪枝
Qspace丨轻空间27 分钟前
气膜场馆:推动体育文化旅游创新发展的关键力量—轻空间
大数据·人工智能·安全·生活·娱乐
没有不重的名么28 分钟前
门控循环单元GRU
人工智能·深度学习·gru
love_and_hope31 分钟前
Pytorch学习--神经网络--搭建小实战(手撕CIFAR 10 model structure)和 Sequential 的使用
人工智能·pytorch·python·深度学习·学习
2403_875736871 小时前
道品科技智慧农业中的自动气象检测站
网络·人工智能·智慧城市
学术头条1 小时前
AI 的「phone use」竟是这样练成的,清华、智谱团队发布 AutoGLM 技术报告
人工智能·科技·深度学习·语言模型
准橙考典1 小时前
怎么能更好的通过驾考呢?
人工智能·笔记·自动驾驶·汽车·学习方法
ai_xiaogui1 小时前
AIStarter教程:快速学会卸载AI项目【AI项目管理平台】
人工智能·ai作画·语音识别·ai写作·ai软件
孙同学要努力1 小时前
《深度学习》——深度学习基础知识(全连接神经网络)
人工智能·深度学习·神经网络