深度学习笔记15_TensorFlow实现运动鞋品牌识别

一、我的环境

1.语言环境:Python 3.9

2.编译器:Pycharm

3.深度学习环境:TensorFlow 2.10.0

二、GPU设置

若使用的是cpu则可忽略

复制代码
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

三**、导入数据**

复制代码
data_dir = "./data/"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*/*.jpg')))

print("图片总数为:",image_count)
#图片总数为:578

四**、数据预处理**

复制代码
batch_size = 32
img_height = 224
img_width = 224

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "./data/train/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "./data/test/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

运行结果:

复制代码
['adidas', 'nike']

五、可视化图片

复制代码
plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")
plt.show()

运行结果:

​​

再次检查数据:

复制代码
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

运行结果:

复制代码
(32, 224, 224, 3)
(32,)

六、配置数据集

  • shuffle() :打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch():预取数据,加速运行
  • cache():将数据集缓存到内存当中,加速运行
python 复制代码
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

七、构建CNN网络模型

卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入形状是 (180, 180, 3)。我们需要在声明第一层时将形状赋值给参数input_shape

python 复制代码
"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995

layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""

model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  
    layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样
    layers.Dropout(0.3),  
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),  
    
    layers.Flatten(),                       # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取
    layers.Dense(len(class_names))               # 输出层,输出预期结果
])

model.summary()  # 打印网络结构

运行结果:

python 复制代码
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 rescaling (Rescaling)       (None, 224, 224, 3)       0

 conv2d (Conv2D)             (None, 222, 222, 16)      448

 average_pooling2d (AverageP  (None, 111, 111, 16)     0
 ooling2D)

 conv2d_1 (Conv2D)           (None, 109, 109, 32)      4640

 average_pooling2d_1 (Averag  (None, 54, 54, 32)       0
 ePooling2D)

 dropout (Dropout)           (None, 54, 54, 32)        0

 conv2d_2 (Conv2D)           (None, 52, 52, 64)        18496

 dropout_1 (Dropout)         (None, 52, 52, 64)        0

 flatten (Flatten)           (None, 173056)            0

 dense (Dense)               (None, 128)               22151296

 dense_1 (Dense)             (None, 2)                 258

=================================================================
Total params: 22,175,138
Trainable params: 22,175,138
Non-trainable params: 0
_________________________________________________________________

八、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
python 复制代码
# 设置初始学习率
initial_learning_rate = 0.001

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=10,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

早停与保存最佳模型参数

python 复制代码
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

epochs = 50

# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy', 
                             min_delta=0.001,
                             patience=20, 
                             verbose=1)

九、训练模型

python 复制代码
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

运行结果:

python 复制代码
Epoch 1/50
16/16 [==============================] - ETA: 0s - loss: 3.6308 - accuracy: 0.5000
Epoch 1: val_accuracy improved from -inf to 0.48684, saving model to best_model.h5
16/16 [==============================] - 7s 73ms/step - loss: 3.6308 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4868
Epoch 2/50
16/16 [==============================] - ETA: 0s - loss: 0.6951 - accuracy: 0.4880
Epoch 2: val_accuracy improved from 0.48684 to 0.50000, saving model to best_model.h5
16/16 [==============================] - 1s 40ms/step - loss: 0.6951 - accuracy: 0.4880 - val_loss: 0.6949 - val_accuracy: 0.5000
Epoch 3/50
15/16 [===========================>..] - ETA: 0s - loss: 0.6928 - accuracy: 0.4979
Epoch 3: val_accuracy did not improve from 0.50000
16/16 [==============================] - 1s 33ms/step - loss: 0.6927 - accuracy: 0.5060 - val_loss: 0.6932 - val_accuracy: 0.5000
Epoch 4/50
15/16 [===========================>..] - ETA: 0s - loss: 0.6922 - accuracy: 0.5553
Epoch 4: val_accuracy improved from 0.50000 to 0.51316, saving model to best_model.h5
16/16 [==============================] - 1s 41ms/step - loss: 0.6920 - accuracy: 0.5578 - val_loss: 0.6925 - val_accuracy: 0.5132
Epoch 5/50
15/16 [===========================>..] - ETA: 0s - loss: 0.6894 - accuracy: 0.5574
Epoch 5: val_accuracy improved from 0.51316 to 0.65789, saving model to best_model.h5
16/16 [==============================] - 1s 39ms/step - loss: 0.6890 - accuracy: 0.5697 - val_loss: 0.6891 - val_accuracy: 0.6579
Epoch 6/50
15/16 [===========================>..] - ETA: 0s - loss: 0.6883 - accuracy: 0.5340
Epoch 6: val_accuracy did not improve from 0.65789
16/16 [==============================] - 1s 33ms/step - loss: 0.6882 - accuracy: 0.5339 - val_loss: 0.6823 - val_accuracy: 0.6184
Epoch 7/50
15/16 [===========================>..] - ETA: 0s - loss: 0.6810 - accuracy: 0.6191
Epoch 7: val_accuracy did not improve from 0.65789
16/16 [==============================] - 1s 33ms/step - loss: 0.6805 - accuracy: 0.6155 - val_loss: 0.6774 - val_accuracy: 0.6316
Epoch 8/50
15/16 [===========================>..] - ETA: 0s - loss: 0.6737 - accuracy: 0.6043
Epoch 8: val_accuracy improved from 0.65789 to 0.71053, saving model to best_model.h5
16/16 [==============================] - 1s 39ms/step - loss: 0.6738 - accuracy: 0.5996 - val_loss: 0.6608 - val_accuracy: 0.7105
Epoch 9/50
15/16 [===========================>..] - ETA: 0s - loss: 0.6461 - accuracy: 0.6979
Epoch 9: val_accuracy did not improve from 0.71053
16/16 [==============================] - 1s 33ms/step - loss: 0.6424 - accuracy: 0.7012 - val_loss: 0.6200 - val_accuracy: 0.6974
Epoch 10/50
15/16 [===========================>..] - ETA: 0s - loss: 0.6148 - accuracy: 0.6979
Epoch 10: val_accuracy did not improve from 0.71053
16/16 [==============================] - 1s 34ms/step - loss: 0.6114 - accuracy: 0.6972 - val_loss: 0.6302 - val_accuracy: 0.6316
Epoch 11/50
15/16 [===========================>..] - ETA: 0s - loss: 0.5956 - accuracy: 0.7234
Epoch 11: val_accuracy improved from 0.71053 to 0.73684, saving model to best_model.h5
16/16 [==============================] - 1s 39ms/step - loss: 0.5968 - accuracy: 0.7191 - val_loss: 0.5779 - val_accuracy: 0.7368
Epoch 12/50
15/16 [===========================>..] - ETA: 0s - loss: 0.5442 - accuracy: 0.7723
Epoch 12: val_accuracy did not improve from 0.73684
16/16 [==============================] - 1s 33ms/step - loss: 0.5505 - accuracy: 0.7570 - val_loss: 0.6001 - val_accuracy: 0.6579
Epoch 13/50
15/16 [===========================>..] - ETA: 0s - loss: 0.5566 - accuracy: 0.7298
Epoch 13: val_accuracy improved from 0.73684 to 0.75000, saving model to best_model.h5
16/16 [==============================] - 1s 40ms/step - loss: 0.5581 - accuracy: 0.7251 - val_loss: 0.5442 - val_accuracy: 0.7500
Epoch 14/50
15/16 [===========================>..] - ETA: 0s - loss: 0.5194 - accuracy: 0.7617
Epoch 14: val_accuracy did not improve from 0.75000
16/16 [==============================] - 1s 33ms/step - loss: 0.5200 - accuracy: 0.7629 - val_loss: 0.5347 - val_accuracy: 0.7368
Epoch 15/50
15/16 [===========================>..] - ETA: 0s - loss: 0.5114 - accuracy: 0.7681
Epoch 15: val_accuracy did not improve from 0.75000
16/16 [==============================] - 1s 33ms/step - loss: 0.5048 - accuracy: 0.7769 - val_loss: 0.5161 - val_accuracy: 0.7500
Epoch 16/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4836 - accuracy: 0.7830
Epoch 16: val_accuracy improved from 0.75000 to 0.76316, saving model to best_model.h5
16/16 [==============================] - 1s 40ms/step - loss: 0.4901 - accuracy: 0.7789 - val_loss: 0.5069 - val_accuracy: 0.7632
Epoch 17/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4636 - accuracy: 0.7809
Epoch 17: val_accuracy did not improve from 0.76316
16/16 [==============================] - 1s 33ms/step - loss: 0.4585 - accuracy: 0.7888 - val_loss: 0.5071 - val_accuracy: 0.7500
Epoch 18/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4717 - accuracy: 0.7723
Epoch 18: val_accuracy did not improve from 0.76316
16/16 [==============================] - 1s 34ms/step - loss: 0.4655 - accuracy: 0.7769 - val_loss: 0.5034 - val_accuracy: 0.7368
Epoch 19/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4610 - accuracy: 0.8064
Epoch 19: val_accuracy did not improve from 0.76316
16/16 [==============================] - 1s 33ms/step - loss: 0.4567 - accuracy: 0.8088 - val_loss: 0.5440 - val_accuracy: 0.7368
Epoch 20/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4547 - accuracy: 0.7872
Epoch 20: val_accuracy improved from 0.76316 to 0.78947, saving model to best_model.h5
16/16 [==============================] - 1s 40ms/step - loss: 0.4507 - accuracy: 0.7948 - val_loss: 0.4812 - val_accuracy: 0.7895
Epoch 21/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4228 - accuracy: 0.8298
Epoch 21: val_accuracy did not improve from 0.78947
16/16 [==============================] - 1s 33ms/step - loss: 0.4238 - accuracy: 0.8287 - val_loss: 0.4926 - val_accuracy: 0.7632
Epoch 22/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4460 - accuracy: 0.8125
Epoch 22: val_accuracy did not improve from 0.78947
16/16 [==============================] - 1s 33ms/step - loss: 0.4386 - accuracy: 0.8187 - val_loss: 0.4857 - val_accuracy: 0.7632
Epoch 23/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4262 - accuracy: 0.8167
Epoch 23: val_accuracy did not improve from 0.78947
16/16 [==============================] - 1s 34ms/step - loss: 0.4204 - accuracy: 0.8227 - val_loss: 0.4718 - val_accuracy: 0.7632
Epoch 24/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4196 - accuracy: 0.8277
Epoch 24: val_accuracy did not improve from 0.78947
16/16 [==============================] - 1s 33ms/step - loss: 0.4208 - accuracy: 0.8247 - val_loss: 0.5068 - val_accuracy: 0.7632
Epoch 25/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4112 - accuracy: 0.8362
Epoch 25: val_accuracy did not improve from 0.78947
16/16 [==============================] - 1s 33ms/step - loss: 0.4118 - accuracy: 0.8347 - val_loss: 0.4658 - val_accuracy: 0.7895
Epoch 26/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4005 - accuracy: 0.8298
Epoch 26: val_accuracy did not improve from 0.78947
16/16 [==============================] - 1s 34ms/step - loss: 0.3981 - accuracy: 0.8347 - val_loss: 0.4822 - val_accuracy: 0.7632
Epoch 27/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4003 - accuracy: 0.8426
Epoch 27: val_accuracy did not improve from 0.78947
16/16 [==============================] - 1s 34ms/step - loss: 0.4038 - accuracy: 0.8406 - val_loss: 0.4756 - val_accuracy: 0.7763
Epoch 28/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3884 - accuracy: 0.8511
Epoch 28: val_accuracy improved from 0.78947 to 0.80263, saving model to best_model.h5
16/16 [==============================] - 1s 40ms/step - loss: 0.3967 - accuracy: 0.8486 - val_loss: 0.4636 - val_accuracy: 0.8026
Epoch 29/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4139 - accuracy: 0.8489
Epoch 29: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 33ms/step - loss: 0.4091 - accuracy: 0.8486 - val_loss: 0.4735 - val_accuracy: 0.7763
Epoch 30/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3857 - accuracy: 0.8617
Epoch 30: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3870 - accuracy: 0.8586 - val_loss: 0.4655 - val_accuracy: 0.7763
Epoch 31/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3853 - accuracy: 0.8447
Epoch 31: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 33ms/step - loss: 0.3908 - accuracy: 0.8347 - val_loss: 0.4688 - val_accuracy: 0.7763
Epoch 32/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3814 - accuracy: 0.8596
Epoch 32: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3869 - accuracy: 0.8546 - val_loss: 0.4728 - val_accuracy: 0.7632
Epoch 33/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3938 - accuracy: 0.8396
Epoch 33: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3887 - accuracy: 0.8446 - val_loss: 0.4798 - val_accuracy: 0.7763
Epoch 34/50
15/16 [===========================>..] - ETA: 0s - loss: 0.4032 - accuracy: 0.8542
Epoch 34: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3955 - accuracy: 0.8586 - val_loss: 0.4708 - val_accuracy: 0.7632
Epoch 35/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3937 - accuracy: 0.8375
Epoch 35: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3865 - accuracy: 0.8426 - val_loss: 0.4695 - val_accuracy: 0.7632
Epoch 36/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3883 - accuracy: 0.8447
Epoch 36: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3862 - accuracy: 0.8486 - val_loss: 0.4700 - val_accuracy: 0.7632
Epoch 37/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3729 - accuracy: 0.8617
Epoch 37: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 33ms/step - loss: 0.3767 - accuracy: 0.8586 - val_loss: 0.4685 - val_accuracy: 0.7632
Epoch 38/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3831 - accuracy: 0.8479
Epoch 38: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3788 - accuracy: 0.8506 - val_loss: 0.4720 - val_accuracy: 0.7632
Epoch 39/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3872 - accuracy: 0.8468
Epoch 39: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 33ms/step - loss: 0.3872 - accuracy: 0.8466 - val_loss: 0.4648 - val_accuracy: 0.7895
Epoch 40/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3800 - accuracy: 0.8489
Epoch 40: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3739 - accuracy: 0.8546 - val_loss: 0.4682 - val_accuracy: 0.7632
Epoch 41/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3800 - accuracy: 0.8511
Epoch 41: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3813 - accuracy: 0.8486 - val_loss: 0.4649 - val_accuracy: 0.7895
Epoch 42/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3727 - accuracy: 0.8617
Epoch 42: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3712 - accuracy: 0.8645 - val_loss: 0.4675 - val_accuracy: 0.7632
Epoch 43/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3803 - accuracy: 0.8468
Epoch 43: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3830 - accuracy: 0.8486 - val_loss: 0.4672 - val_accuracy: 0.7632
Epoch 44/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3648 - accuracy: 0.8745
Epoch 44: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3698 - accuracy: 0.8705 - val_loss: 0.4708 - val_accuracy: 0.7632
Epoch 45/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3742 - accuracy: 0.8489
Epoch 45: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 33ms/step - loss: 0.3695 - accuracy: 0.8546 - val_loss: 0.4683 - val_accuracy: 0.7632
Epoch 46/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3792 - accuracy: 0.8447
Epoch 46: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 33ms/step - loss: 0.3878 - accuracy: 0.8406 - val_loss: 0.4706 - val_accuracy: 0.7632
Epoch 47/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3681 - accuracy: 0.8745
Epoch 47: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3639 - accuracy: 0.8745 - val_loss: 0.4708 - val_accuracy: 0.7632
Epoch 48/50
15/16 [===========================>..] - ETA: 0s - loss: 0.3771 - accuracy: 0.8489
Epoch 48: val_accuracy did not improve from 0.80263
16/16 [==============================] - 1s 34ms/step - loss: 0.3778 - accuracy: 0.8506 - val_loss: 0.4729 - val_accuracy: 0.7632
Epoch 48: early stopping
1/1 [==============================] - 0s 100ms/step

十、模型评估

python 复制代码
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(loss))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

十一、指定图片预测

python 复制代码
# 加载效果最好的模型权重
model.load_weights('best_model.h5')
from PIL import Image
import numpy as np


img = Image.open("./data/test/nike/1.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])

img_array = tf.expand_dims(image, 0) #/255.0  # 记得做归一化处理(与训练集处理方式保持一致)

predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

运行结果:

python 复制代码
预测结果为: nike

十二、总结

本周通过学习TensorFlow实现运动鞋品牌识别;首先学习设置动态学习率,在训练神经网络时动态地降低学习率,可以帮助优化算法更有效地收敛到全局最小值,从而提高模型的性能。其次就是学习早停与保存最佳模型参数,模型在指定epoch次都没有提升的情况下,可以提前停止训练。

相关推荐
小哥谈1 小时前
论文解析篇 | YOLOv12:以注意力机制为核心的实时目标检测算法
人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
水龙吟啸1 小时前
从零开始搭建深度学习大厦系列-2.卷积神经网络基础(5-9)
人工智能·pytorch·深度学习·cnn·mxnet
饕餮争锋2 小时前
设计模式笔记_创建型_建造者模式
笔记·设计模式·建造者模式
HollowKnightZ2 小时前
论文阅读笔记:VI-Net: Boosting Category-level 6D Object Pose Estimation
人工智能·深度学习·计算机视觉
yzx9910132 小时前
AI大模型平台
大数据·人工智能·深度学习·机器学习
萝卜青今天也要开心2 小时前
2025年上半年软件设计师考后分享
笔记·学习
吃货界的硬件攻城狮3 小时前
【STM32 学习笔记】SPI通信协议
笔记·stm32·学习
蓝染yy3 小时前
Apache
笔记
Better Rose4 小时前
人工智能与机器学习暑期科研项目招募(可发表论文)
人工智能·深度学习·机器学习·论文撰写
lxiaoj1114 小时前
Python文件操作笔记
笔记·python