Tensorflow实现深度学习案例7:咖啡豆识别

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

一、前期工作

1. 导入数据

python 复制代码
from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlib
import tensorflow as tf
import warnings as w
w.filterwarnings('ignore')

data_dir = "./coffee/"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.png')))

print("图片总数为:",image_count)
复制代码
图片总数为: 1200

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

python 复制代码
batch_size = 32
img_height = 224
img_width = 224
python 复制代码
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.

Using 960 files for training.

python 复制代码
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.

Using 240 files for validation.

python 复制代码
class_names = train_ds.class_names
print(class_names)
复制代码
['Dark', 'Green', 'Light', 'Medium']

2.数据可视化

python 复制代码
plt.figure(figsize=(10, 4))  # 图形的宽为10高为5

for images, labels in train_ds.take(1):
    for i in range(10):
        
        ax = plt.subplot(2, 5, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")
python 复制代码
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
python 复制代码
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(32, 224, 224, 3)

(32,)

3. 配置数据集

  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
python 复制代码
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

并且将数据归一化

python 复制代码
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y))

image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]

# 查看归一化后的数据
print(np.min(first_image), np.max(first_image))

0.0 1.0

三、构建VGG-16网络

1.VGG优缺点分析:

  • VGG优点

VGG的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)

  • VGG缺点

1)训练时间过长,调参难度大。2)需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

2.网络结构图

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcXpredictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16****包含了16个隐藏层(13个卷积层和3个全连接层),故称为 VGG-16

python 复制代码
model = tf.keras.applications.VGG16(weights='imagenet')
model.summary()
复制代码
Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 fc1 (Dense)                 (None, 4096)              102764544 
                                                                 
 fc2 (Dense)                 (None, 4096)              16781312  
                                                                 
 predictions (Dense)         (None, 1000)              4097000   
                                                                 
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
python 复制代码
# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=30,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型

python 复制代码
epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
复制代码
Epoch 1/20
30/30 [==============================] - 346s 11s/step - loss: 1.7546 - accuracy: 0.2625 - val_loss: 1.4646 - val_accuracy: 0.2125
Epoch 2/20
30/30 [==============================] - 352s 12s/step - loss: 1.3637 - accuracy: 0.3104 - val_loss: 1.0428 - val_accuracy: 0.4583
Epoch 3/20
30/30 [==============================] - 338s 11s/step - loss: 0.7237 - accuracy: 0.6458 - val_loss: 0.4818 - val_accuracy: 0.7833
Epoch 4/20
30/30 [==============================] - 336s 11s/step - loss: 0.3633 - accuracy: 0.8479 - val_loss: 1.1034 - val_accuracy: 0.6167
Epoch 5/20
30/30 [==============================] - 340s 11s/step - loss: 0.2880 - accuracy: 0.8927 - val_loss: 0.1480 - val_accuracy: 0.9500
Epoch 6/20
30/30 [==============================] - 338s 11s/step - loss: 0.1802 - accuracy: 0.9333 - val_loss: 0.4709 - val_accuracy: 0.8458
Epoch 7/20
30/30 [==============================] - 334s 11s/step - loss: 0.1468 - accuracy: 0.9490 - val_loss: 0.0214 - val_accuracy: 1.0000
Epoch 8/20
30/30 [==============================] - 339s 11s/step - loss: 0.0174 - accuracy: 0.9969 - val_loss: 0.0196 - val_accuracy: 0.9875
Epoch 9/20
30/30 [==============================] - 329s 11s/step - loss: 0.0399 - accuracy: 0.9875 - val_loss: 0.2539 - val_accuracy: 0.9292
Epoch 10/20
30/30 [==============================] - 330s 11s/step - loss: 0.2606 - accuracy: 0.9073 - val_loss: 0.0737 - val_accuracy: 0.9917
Epoch 11/20
30/30 [==============================] - 334s 11s/step - loss: 0.0610 - accuracy: 0.9812 - val_loss: 0.0070 - val_accuracy: 1.0000
Epoch 12/20
30/30 [==============================] - 341s 11s/step - loss: 0.0296 - accuracy: 0.9917 - val_loss: 0.0256 - val_accuracy: 0.9875
Epoch 13/20
30/30 [==============================] - 335s 11s/step - loss: 0.0252 - accuracy: 0.9917 - val_loss: 0.0431 - val_accuracy: 0.9833
Epoch 14/20
30/30 [==============================] - 345s 12s/step - loss: 0.0058 - accuracy: 0.9979 - val_loss: 0.0088 - val_accuracy: 0.9958
Epoch 15/20
30/30 [==============================] - 557s 19s/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0144 - val_accuracy: 0.9917
Epoch 16/20
30/30 [==============================] - 340s 11s/step - loss: 3.6823e-04 - accuracy: 1.0000 - val_loss: 0.0052 - val_accuracy: 0.9958
Epoch 17/20
30/30 [==============================] - 347s 12s/step - loss: 5.9116e-05 - accuracy: 1.0000 - val_loss: 0.0064 - val_accuracy: 0.9958
Epoch 18/20
30/30 [==============================] - 347s 12s/step - loss: 2.5309e-05 - accuracy: 1.0000 - val_loss: 0.0048 - val_accuracy: 0.9958
Epoch 19/20
30/30 [==============================] - 350s 12s/step - loss: 1.0864e-05 - accuracy: 1.0000 - val_loss: 0.0033 - val_accuracy: 1.0000
Epoch 20/20
30/30 [==============================] - 341s 11s/step - loss: 6.0013e-06 - accuracy: 1.0000 - val_loss: 0.0045 - val_accuracy: 0.9958

六 可视化结果

python 复制代码
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

预测图片

python 复制代码
import numpy as np

# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5
plt.suptitle("预测结果展示")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(1,8, i + 1)  
        
        # 显示图片
        plt.imshow(images[i].numpy())
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测图片中的人物
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")
复制代码
1/1 [==============================] - 0s 279ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 118ms/step
1/1 [==============================] - 0s 109ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 104ms/step
1/1 [==============================] - 0s 111ms/step
1/1 [==============================] - 0s 115ms/step
相关推荐
jndingxin8 分钟前
OpenCV特征检测(1)检测图像中的线段的类LineSegmentDe()的使用
人工智能·opencv·计算机视觉
@月落18 分钟前
alibaba获得店铺的所有商品 API接口
java·大数据·数据库·人工智能·学习
z千鑫27 分钟前
【人工智能】如何利用AI轻松将java,c++等代码转换为Python语言?程序员必读
java·c++·人工智能·gpt·agent·ai编程·ai工具
MinIO官方账号1 小时前
从 HDFS 迁移到 MinIO 企业对象存储
人工智能·分布式·postgresql·架构·开源
aWty_1 小时前
机器学习--K-Means
人工智能·机器学习·kmeans
草莓屁屁我不吃1 小时前
AI大语言模型的全面解读
人工智能·语言模型·自然语言处理·chatgpt
WPG大大通1 小时前
有奖直播 | onsemi IPM 助力汽车电气革命及电子化时代冷热管理
大数据·人工智能·汽车·方案·电气·大大通·研讨会
百锦再1 小时前
AI对汽车行业的冲击和比亚迪新能源汽车市场占比
人工智能·汽车
ws2019071 小时前
抓机遇,促发展——2025第十二届广州国际汽车零部件加工技术及汽车模具展览会
大数据·人工智能·汽车
Zhangci]1 小时前
Opencv图像预处理(三)
人工智能·opencv·计算机视觉