TensorFlow Deep Learning Case 7: Coffee Bean Recognition

This article is an internal article of the 🔗365天深度学习训练营 (365-day deep learning training camp).

Original author: K同学啊

I. Preliminary Work

1. Import Data

python
from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlib
import tensorflow as tf
import warnings as w
w.filterwarnings('ignore')

data_dir = "./coffee/"
data_dir = pathlib.Path(data_dir)

# Count all .png images under the class sub-directories
image_count = len(list(data_dir.glob('*/*.png')))

print("Total number of images:", image_count)
Total number of images: 1200

II. Data Preprocessing

1. Load the Data

Use the image_dataset_from_directory method to load the images from disk into a tf.data.Dataset.

python
batch_size = 32
img_height = 224
img_width = 224
python
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.

Using 960 files for training.

python
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

Found 1200 files belonging to 4 classes.

Using 240 files for validation.

python
class_names = train_ds.class_names
print(class_names)
['Dark', 'Green', 'Light', 'Medium']

2. Data Visualization

python
plt.figure(figsize=(10, 4))  # Figure width 10, height 4

for images, labels in train_ds.take(1):
    for i in range(10):
        
        ax = plt.subplot(2, 5, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")
python
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(32, 224, 224, 3)

(32,)

3. Configure the Dataset

  • prefetch(): prefetches data to speed up execution; a more detailed introduction can be found in my previous two articles.
  • cache(): caches the dataset in memory to speed up execution.
python
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

Then normalize the pixel values to the [0, 1] range.

python
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y))

image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]

# Check the value range after normalization
print(np.min(first_image), np.max(first_image))

0.0 1.0
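As an aside, the same normalization can also live inside the model, so that raw [0, 255] images can be fed directly at inference time. Below is a minimal, illustrative sketch that is not part of the original pipeline; it assumes a newer TensorFlow version where the layer is exposed as tf.keras.layers.Rescaling.

python
# Illustrative: a tiny preprocessing model that rescales pixels from [0, 255] to [0, 1]
preprocess = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(img_height, img_width, 3)),
    tf.keras.layers.Rescaling(1. / 255),
])

dummy = tf.random.uniform((1, img_height, img_width, 3), maxval=255.0)
print(float(tf.reduce_max(preprocess(dummy))))  # should be <= 1.0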

III. Building the VGG-16 Network

1. Pros and Cons of VGG

  • Advantages of VGG

The VGG architecture is very simple: the entire network uses the same convolution kernel size (3x3) and max-pooling size (2x2).

  • Disadvantages of VGG

1) Training takes a long time and hyperparameter tuning is difficult. 2) It requires a large amount of storage, which makes deployment harder. For example, the VGG-16 weight file is over 500 MB, which is unfavorable for installation on embedded systems.

2. Network Structure

Structure description:

  • 13 convolutional layers, denoted blockX_convX
  • 3 fully connected layers, denoted fc1, fc2, and predictions
  • 5 pooling layers, denoted blockX_pool

VGG-16 contains 16 layers with weights (13 convolutional layers and 3 fully connected layers), hence the name VGG-16.
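For reference, the structure above can be reproduced layer by layer. The sketch below is illustrative (build_vgg16 is a hypothetical helper, not part of the original code); the article itself simply loads the official implementation right after.

python
from tensorflow.keras import layers, models

def build_vgg16(num_classes=1000, input_shape=(224, 224, 3)):
    # Five convolutional blocks followed by the classifier head,
    # using the blockX_convX / blockX_pool / fc1 / fc2 / predictions naming above.
    model = models.Sequential(name="vgg16_manual")
    model.add(layers.InputLayer(input_shape=input_shape))
    for block, (filters, n_convs) in enumerate(
            [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], start=1):
        for conv in range(1, n_convs + 1):
            model.add(layers.Conv2D(filters, 3, padding="same", activation="relu",
                                    name=f"block{block}_conv{conv}"))
        model.add(layers.MaxPooling2D(2, name=f"block{block}_pool"))
    model.add(layers.Flatten(name="flatten"))
    model.add(layers.Dense(4096, activation="relu", name="fc1"))
    model.add(layers.Dense(4096, activation="relu", name="fc2"))
    model.add(layers.Dense(num_classes, activation="softmax", name="predictions"))
    return model

Calling build_vgg16().summary() should reproduce the layer shapes and parameter counts shown in the summary below.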

python
# Load the official VGG-16 with ImageNet weights (keeps the original 1000-class classifier head)
model = tf.keras.applications.VGG16(weights='imagenet')
model.summary()
Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 fc1 (Dense)                 (None, 4096)              102764544 
                                                                 
 fc2 (Dense)                 (None, 4096)              16781312  
                                                                 
 predictions (Dense)         (None, 1000)              4097000   
                                                                 
=================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________
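Note that weights='imagenet' keeps the full 1000-class ImageNet classifier head, and that entire model is what gets fine-tuned in the sections below (the 4 coffee-bean labels simply map onto the first 4 of the 1000 output units). A more conventional way to adapt VGG-16 to 4 classes is transfer learning with a new head; the sketch below is illustrative (transfer_model is a hypothetical name) and is not the approach used for the training log in this article.

python
# Illustrative alternative: reuse the convolutional base and train a small 4-class head
base = tf.keras.applications.VGG16(weights='imagenet', include_top=False,
                                   input_shape=(img_height, img_width, 3))
base.trainable = False  # freeze the pretrained backbone

transfer_model = models.Sequential([
    base,
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dense(len(class_names), activation='softmax'),  # 4 coffee-bean classes
])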

IV. Compilation

Before the model is ready for training, a few more settings are needed. The following are added in the model's compile step:

  • Loss function (loss): measures how accurate the model is during training.
  • Optimizer (optimizer): determines how the model is updated based on the data it sees and its own loss function.
  • Metrics (metrics): used to monitor the training and testing steps. The example below uses accuracy, i.e., the fraction of images that are correctly classified.
python
# Set the initial learning rate
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=30,      # Note: this counts steps, not epochs
        decay_rate=0.92,     # each decay multiplies the learning rate by decay_rate
        staircase=True)

# Set the optimizer
# Note: lr_schedule is defined above but not passed in here, so the optimizer
# actually keeps the fixed initial_learning_rate throughout training.
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

model.compile(optimizer=opt,
              # Note: VGG16's 'predictions' layer already applies softmax, so
              # from_logits=False would be the strictly correct setting; training
              # still converges with this configuration, as the log below shows.
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
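If the exponential decay is actually meant to take effect, the schedule object itself can be passed to the optimizer; a minimal sketch (the training log below was produced with the fixed learning rate above):

python
# Sketch: pass the schedule so the learning rate decays by a factor of 0.92
# every 30 optimizer steps instead of staying fixed at 1e-4
opt_with_decay = tf.keras.optimizers.Adam(learning_rate=lr_schedule)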

V. Training the Model

python
epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
Epoch 1/20
30/30 [==============================] - 346s 11s/step - loss: 1.7546 - accuracy: 0.2625 - val_loss: 1.4646 - val_accuracy: 0.2125
Epoch 2/20
30/30 [==============================] - 352s 12s/step - loss: 1.3637 - accuracy: 0.3104 - val_loss: 1.0428 - val_accuracy: 0.4583
Epoch 3/20
30/30 [==============================] - 338s 11s/step - loss: 0.7237 - accuracy: 0.6458 - val_loss: 0.4818 - val_accuracy: 0.7833
Epoch 4/20
30/30 [==============================] - 336s 11s/step - loss: 0.3633 - accuracy: 0.8479 - val_loss: 1.1034 - val_accuracy: 0.6167
Epoch 5/20
30/30 [==============================] - 340s 11s/step - loss: 0.2880 - accuracy: 0.8927 - val_loss: 0.1480 - val_accuracy: 0.9500
Epoch 6/20
30/30 [==============================] - 338s 11s/step - loss: 0.1802 - accuracy: 0.9333 - val_loss: 0.4709 - val_accuracy: 0.8458
Epoch 7/20
30/30 [==============================] - 334s 11s/step - loss: 0.1468 - accuracy: 0.9490 - val_loss: 0.0214 - val_accuracy: 1.0000
Epoch 8/20
30/30 [==============================] - 339s 11s/step - loss: 0.0174 - accuracy: 0.9969 - val_loss: 0.0196 - val_accuracy: 0.9875
Epoch 9/20
30/30 [==============================] - 329s 11s/step - loss: 0.0399 - accuracy: 0.9875 - val_loss: 0.2539 - val_accuracy: 0.9292
Epoch 10/20
30/30 [==============================] - 330s 11s/step - loss: 0.2606 - accuracy: 0.9073 - val_loss: 0.0737 - val_accuracy: 0.9917
Epoch 11/20
30/30 [==============================] - 334s 11s/step - loss: 0.0610 - accuracy: 0.9812 - val_loss: 0.0070 - val_accuracy: 1.0000
Epoch 12/20
30/30 [==============================] - 341s 11s/step - loss: 0.0296 - accuracy: 0.9917 - val_loss: 0.0256 - val_accuracy: 0.9875
Epoch 13/20
30/30 [==============================] - 335s 11s/step - loss: 0.0252 - accuracy: 0.9917 - val_loss: 0.0431 - val_accuracy: 0.9833
Epoch 14/20
30/30 [==============================] - 345s 12s/step - loss: 0.0058 - accuracy: 0.9979 - val_loss: 0.0088 - val_accuracy: 0.9958
Epoch 15/20
30/30 [==============================] - 557s 19s/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0144 - val_accuracy: 0.9917
Epoch 16/20
30/30 [==============================] - 340s 11s/step - loss: 3.6823e-04 - accuracy: 1.0000 - val_loss: 0.0052 - val_accuracy: 0.9958
Epoch 17/20
30/30 [==============================] - 347s 12s/step - loss: 5.9116e-05 - accuracy: 1.0000 - val_loss: 0.0064 - val_accuracy: 0.9958
Epoch 18/20
30/30 [==============================] - 347s 12s/step - loss: 2.5309e-05 - accuracy: 1.0000 - val_loss: 0.0048 - val_accuracy: 0.9958
Epoch 19/20
30/30 [==============================] - 350s 12s/step - loss: 1.0864e-05 - accuracy: 1.0000 - val_loss: 0.0033 - val_accuracy: 1.0000
Epoch 20/20
30/30 [==============================] - 341s 11s/step - loss: 6.0013e-06 - accuracy: 1.0000 - val_loss: 0.0045 - val_accuracy: 0.9958

VI. Visualizing the Results

python
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

Predicting Images

python
import numpy as np

# Use the trained model to inspect the prediction results
plt.figure(figsize=(18, 3))  # Figure width 18, height 3
plt.suptitle("Prediction Results")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(1,8, i + 1)  
        
        # Display the image
        plt.imshow(images[i].numpy())
        
        # Add a batch dimension to the image
        img_array = tf.expand_dims(images[i], 0) 
        
        # Use the model to predict the coffee-bean class of the image
        predictions = model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")
1/1 [==============================] - 0s 279ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 118ms/step
1/1 [==============================] - 0s 109ms/step
1/1 [==============================] - 0s 110ms/step
1/1 [==============================] - 0s 104ms/step
1/1 [==============================] - 0s 111ms/step
1/1 [==============================] - 0s 115ms/step
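For completeness, a minimal sketch of saving the trained model and reloading it for later inference (the file name coffee_vgg16.h5 and the variable new_model are illustrative, not from the original article):

python
# Save the trained model and reload it for later inference
model.save("coffee_vgg16.h5")
new_model = tf.keras.models.load_model("coffee_vgg16.h5")

# Sanity check: the reloaded model should produce the same prediction
predictions = new_model.predict(img_array)
print(class_names[np.argmax(predictions)])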