第T8周:猫狗识别

第T8周:猫狗识别

tf.config.list_physical_devices("GPU"),用于检测当前系统是否有可用的 GPU,并将结果存入 gpus 变量。如果系统检测到 GPU,代码会选择第一块 GPU(gpu0 = gpus[0]),然后调用tf.config.experimental.set_memory_growth(gpu0, True) 来启用 GPU。

设定 batch_size=8,即每次训练时取 8 张图片进行计算。

设定图像大小 224x224,这样所有加载的图片都会被调整到该尺寸,以确保模型输入维度一致。

seed=12:固定随机种子,确保数据划分不会因多次运行而变化。image_size=(img_height, img_width):调整所有图片大小为 224x224。batch_size=batch_size:每批次加载 8 张图片。

数据集中共有3400 张图片,分别属于2个类别。其中,2720张作为训练集,680张作为验证集。

这个批次包含 8 张图片(因为 batch_size=8),每张图片的尺寸是 224x224。 图片有 3 个通道。

Label_batch是形状(8,)的张量,这些标签对应8张图片

cache(),将数据缓存在内存中,提高训练速度

shuffle(1000):打乱训练数据,缓冲区大小是 1000

prefetch(buffer_size=AUTOTUNE):异步加载数据,加速训练过程(AUTOTUNE 会自动选择合适的预取大小)

layers.Rescaling(1./255):把像素值从 [0, 255] 缩放到 [0.0, 1.0]

map(lambda x, y: ...):对数据集中的每张图片 x 应用归一化操作,标签 y 保持不变

从 val_ds 中取出一个 batch(默认是 (batch_size, height, width, 3))

多分类任务的 VGG16 卷积神经网络,保留了 VGG16 的经典结构(13 个卷积层 + 3 个全连接层),输出为 nb_classes 类的 softmax 结果。

输入图像的 shape 是 (img_width, img_height, 3),支持 RGB 彩图。

每个 block 都由若干个 3x3 卷积层(带 ReLU 激活),一个 2x2 最大池化层

每个卷积层都使用 'same' padding 保证输出尺寸一致,池化后尺寸减半。

Flatten(),把多维 feature map 展平成一维向量。

两个 Dense(4096) 层,经典的全连接层(重参数)。

Dense(nb_classes, activation='softmax'),输出最终的分类概率。

img_width, img_height 是图像的宽和高(为 224x224 默认的 VGG 输入尺寸)。


tqdm 是一个进度条库,显示每轮训练/验证的进度。

总共训练10 个 epochs,初始学习率设置为 0.0001。

每轮将学习率乘以 0.92,手动设置给模型的优化器。

遍历 train_ds,对每一个 batch 使用 model.train_on_batch() 进行训练。把最后一个 batch 的 loss 和 accuracy 存进历史列表。保存最后一个 batch 的验证指标。

Training and Validation Accuracy

随着训练的进行,两条线都持续上升,验证准确率与训练准确率接近,

模型在训练集与验证集上都学习得不错。

Training and Validation Loss

两条线都持续下降,并且非常接近,说明模型在两个数据集上都表现出良好的收敛趋势,没有过拟合现象。

相关推荐
西猫雷婶34 分钟前
python学智能算法(十九)|SVM基础概念-超平面
开发语言·人工智能·python·深度学习·算法·机器学习·支持向量机
斟的是酒中桃2 小时前
基于YOLOv8的火灾智能检测系统设计与实现
人工智能·深度学习·yolo·pyqt
算法_小学生8 小时前
Huber Loss(胡贝损失)详解:稳健回归的秘密武器 + Python实现
人工智能·python·深度学习·机器学习
绝顶大聪明10 小时前
【深度学习】神经网络 批量标准化-part6
人工智能·深度学习·神经网络
DAWN_T1711 小时前
神经网络——卷积层
人工智能·pytorch·深度学习·神经网络·机器学习·计算机视觉·cnn
墨尘游子12 小时前
3-大语言模型—理论基础:生成式预训练语言模型GPT(代码“活起来”)
人工智能·python·gpt·深度学习·神经网络·语言模型·自然语言处理
Blossom.11812 小时前
基于深度学习的语音识别:从音频信号到文本转录
人工智能·深度学习·线性代数·机器学习·计算机视觉·音视频·语音识别
格林威12 小时前
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现动物分类(C#源码,UI界面版)
人工智能·深度学习·数码相机·yolo·计算机视觉·c#·视觉检测
强哥之神12 小时前
深入解析 vLLM 分布式推理与部署策略
深度学习·语言模型·架构·llm·transformer·vllm
橡晟15 小时前
4.循环结构:让电脑做重复的事情
人工智能·python·深度学习·机器学习