Tensorflow实现深度学习案例7:咖啡豆识别

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

一、前期工作

1. 导入数据

python 复制代码
from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlib
import tensorflow as tf
import warnings as w
w.filterwarnings('ignore')

data_dir = "./coffee/"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.png')))

print("图片总数为:",image_count)
复制代码
图片总数为: 1200

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

python 复制代码
batch_size = 32
img_height = 224
img_width = 224
python 复制代码
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.

Using 960 files for training.

python 复制代码
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.

Using 240 files for validation.

python 复制代码
class_names = train_ds.class_names
print(class_names)
复制代码
['Dark', 'Green', 'Light', 'Medium']

2.数据可视化

python 复制代码
plt.figure(figsize=(10, 4))  # 图形的宽为10高为5

for images, labels in train_ds.take(1):
    for i in range(10):
        
        ax = plt.subplot(2, 5, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")
python 复制代码
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
python 复制代码
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(32, 224, 224, 3)

(32,)

3. 配置数据集

  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
python 复制代码
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

并且将数据归一化

python 复制代码
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y))

image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]

# 查看归一化后的数据
print(np.min(first_image), np.max(first_image))

0.0 1.0

三、构建VGG-16网络

1.VGG优缺点分析:

  • VGG优点

VGG的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)

  • VGG缺点

1)训练时间过长,调参难度大。2)需要的存储容量大,不利于部署。例如存储VGG-16权重值文件的大小为500多MB,不利于安装到嵌入式系统中。

2.网络结构图

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcXpredictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16****包含了16个隐藏层(13个卷积层和3个全连接层),故称为 VGG-16

python 复制代码
model = tf.keras.applications.VGG16(weights='imagenet')
model.summary()
复制代码
Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 fc1 (Dense)                 (None, 4096)              102764544 
                                                                 
 fc2 (Dense)                 (None, 4096)              16781312  
                                                                 
 predictions (Dense)         (None, 1000)              4097000   
                                                                 
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
python 复制代码
# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=30,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型

python 复制代码
epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
复制代码
Epoch 1/20
30/30 [==============================] - 346s 11s/step - loss: 1.7546 - accuracy: 0.2625 - val_loss: 1.4646 - val_accuracy: 0.2125
Epoch 2/20
30/30 [==============================] - 352s 12s/step - loss: 1.3637 - accuracy: 0.3104 - val_loss: 1.0428 - val_accuracy: 0.4583
Epoch 3/20
30/30 [==============================] - 338s 11s/step - loss: 0.7237 - accuracy: 0.6458 - val_loss: 0.4818 - val_accuracy: 0.7833
Epoch 4/20
30/30 [==============================] - 336s 11s/step - loss: 0.3633 - accuracy: 0.8479 - val_loss: 1.1034 - val_accuracy: 0.6167
Epoch 5/20
30/30 [==============================] - 340s 11s/step - loss: 0.2880 - accuracy: 0.8927 - val_loss: 0.1480 - val_accuracy: 0.9500
Epoch 6/20
30/30 [==============================] - 338s 11s/step - loss: 0.1802 - accuracy: 0.9333 - val_loss: 0.4709 - val_accuracy: 0.8458
Epoch 7/20
30/30 [==============================] - 334s 11s/step - loss: 0.1468 - accuracy: 0.9490 - val_loss: 0.0214 - val_accuracy: 1.0000
Epoch 8/20
30/30 [==============================] - 339s 11s/step - loss: 0.0174 - accuracy: 0.9969 - val_loss: 0.0196 - val_accuracy: 0.9875
Epoch 9/20
30/30 [==============================] - 329s 11s/step - loss: 0.0399 - accuracy: 0.9875 - val_loss: 0.2539 - val_accuracy: 0.9292
Epoch 10/20
30/30 [==============================] - 330s 11s/step - loss: 0.2606 - accuracy: 0.9073 - val_loss: 0.0737 - val_accuracy: 0.9917
Epoch 11/20
30/30 [==============================] - 334s 11s/step - loss: 0.0610 - accuracy: 0.9812 - val_loss: 0.0070 - val_accuracy: 1.0000
Epoch 12/20
30/30 [==============================] - 341s 11s/step - loss: 0.0296 - accuracy: 0.9917 - val_loss: 0.0256 - val_accuracy: 0.9875
Epoch 13/20
30/30 [==============================] - 335s 11s/step - loss: 0.0252 - accuracy: 0.9917 - val_loss: 0.0431 - val_accuracy: 0.9833
Epoch 14/20
30/30 [==============================] - 345s 12s/step - loss: 0.0058 - accuracy: 0.9979 - val_loss: 0.0088 - val_accuracy: 0.9958
Epoch 15/20
30/30 [==============================] - 557s 19s/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0144 - val_accuracy: 0.9917
Epoch 16/20
30/30 [==============================] - 340s 11s/step - loss: 3.6823e-04 - accuracy: 1.0000 - val_loss: 0.0052 - val_accuracy: 0.9958
Epoch 17/20
30/30 [==============================] - 347s 12s/step - loss: 5.9116e-05 - accuracy: 1.0000 - val_loss: 0.0064 - val_accuracy: 0.9958
Epoch 18/20
30/30 [==============================] - 347s 12s/step - loss: 2.5309e-05 - accuracy: 1.0000 - val_loss: 0.0048 - val_accuracy: 0.9958
Epoch 19/20
30/30 [==============================] - 350s 12s/step - loss: 1.0864e-05 - accuracy: 1.0000 - val_loss: 0.0033 - val_accuracy: 1.0000
Epoch 20/20
30/30 [==============================] - 341s 11s/step - loss: 6.0013e-06 - accuracy: 1.0000 - val_loss: 0.0045 - val_accuracy: 0.9958

六 可视化结果

python 复制代码
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

预测图片

python 复制代码
import numpy as np

# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5
plt.suptitle("预测结果展示")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(1,8, i + 1)  
        
        # 显示图片
        plt.imshow(images[i].numpy())
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测图片中的人物
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")
复制代码
1/1 [==============================] - 0s 279ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 118ms/step
1/1 [==============================] - 0s 109ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 104ms/step
1/1 [==============================] - 0s 111ms/step
1/1 [==============================] - 0s 115ms/step
相关推荐
studyer_domi22 分钟前
Matlab 三维时频图
开发语言·人工智能·matlab
小森77671 小时前
(四)机器学习---逻辑回归及其Python实现
人工智能·python·算法·机器学习·逻辑回归·线性回归
生信碱移1 小时前
入门级宏基因组数据分析教程,从实验到分析与应用
人工智能·经验分享·python·神经网络·数据挖掘·数据分析·数据可视化
發發期权酱1 小时前
期权中的Gamma指标详解
大数据·人工智能
补三补四1 小时前
【深度学习基础】——机器的神经元:感知机
人工智能·深度学习·算法·机器学习
永洪科技2 小时前
AI领域再突破,永洪科技荣获“2025人工智能+创新案例”奖
大数据·人工智能·科技·数据分析·数据可视化
that's boy2 小时前
Google 发布 Sec-Gemini v1:用 AI 重塑网络安全防御格局?
人工智能·安全·web安全·chatgpt·midjourney·ai编程·ai写作
Sui_Network2 小时前
Crossmint 与 Walrus 合作,将协议集成至其跨链铸造 API 中
人工智能·物联网·游戏·区块链·智能合约
liruiqiang052 小时前
循环神经网络 - 长短期记忆网络
人工智能·rnn·深度学习·神经网络·机器学习·ai·lstm