三、CV_VGGnet

三、VGGnet

1.VGG网络架构

VGG可以看成是加深版的AlexNet,整个网络由卷积层和全连接层叠加而成,和AlexNet不同的是,VGG中使用的都是小尺寸的卷积核(3 ×\times× 3)。

VGGNet使用的全部都是3 ×\times× 3的小卷积核和2 ×\times× 2的池化核,通过不断加深网络来提升性能。VGG可通过重复使用简单的基础块来构建深度模型

在tf.keras中实现VGG模型,首先来实现VGG块,它的组成规律是:连续使用多个相同的填充为1、卷积核大小为3 ×\times× 3的卷积层后接上一个步幅为2,窗口形状为2 ×\times× 2的最大池化层。卷积层保持输入高的宽不变,而池化层则对其减半。我们使用vgg_block函数来实现这个基础的VGG块,它可以指定卷积层的数量num_convs和每层的卷积核个数num_filters.

python 复制代码
# 定义VGG网络中的卷积块:卷积层的个数,卷积层中卷积核的个数
def vgg_block(num_convs, num_filters):
    blk = tf.keras.models.Squential()
    for _ in range(num_convs):
        blk.add(tf.keras.layers.Conv2D(num_filters, kernel_size = 3, padding = 'same', activation = 'relu'))

        # 卷积块最后一个是最大池化,窗口大小为2*2,步长为2
        blk.add(tf.keras.layers.MaxPool2D(pool_size = 2, strides = 2))
        return blk

VGG16网络有五个卷积块,前2块使用两个卷积层,而后三块使用三个卷积层。第一块的输出通道是64,之后每次对输出通道数翻倍,直到变为512

python 复制代码
# 定义5个卷积块,指明每个卷积层个数及相应的卷积核个数
conv_arch = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))

这个网络使用了13个卷积层和3个全连接层,通过指定conv_arch得到模型架构后构建VGG16

python 复制代码
def vgg(conv_arch):
    # 构建序列模型
    net = tf.keras.models.Squential()
    # 根据conv_arch生成卷积部分
    for (num_convs, num_filters) in conv_arch:
        net.add(vgg_block(num_convs, num_filters))

    # 卷积块序列后添加全连接层
    net.add(tf.keras.models.Squential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(4096, activation = 'relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(4096, activation = 'relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10, activation = 'Softmax')
    ]))

    return net

# 网络实例化
net = vgg(conv_arch)

构造一个高宽均为224的单通道数据样本来看一下模型架构

python 复制代码
X = tf.random.uniform((1, 224, 224, 1))
y = net(X)

net.summary() # 查看网络形状

2.手写数字识别

读取数据时需将图像高和宽扩大到VggNet使用的图像高和宽224,这个通过tf.image.resize_with_pad来实现

(1)数据读取

获取数据并进行维度调整

python 复制代码
import numpy as np
(train_images,  train_labels), (test_images, test_labels) = mnist.load_data()

train_images = np.reshape(train_images, (train_images.shape[0], train_images.shape[1], train_images.shape[2], 1))
test_images = np.reshape(test_images, (test_images.shape[0], test_images.shape[1], test_images.shape[2], 1))

定义两个方法获取部分数据,并将图像调整为224*224大小,进行模型训练

python 复制代码
# 定义两个样本随机抽取部分样本演示
# 获取训练集数据
def get_train(size):
    index = np.random.randint(0, np.shape(train_images)[0], size)

    resize_images = tf.image.resize_with_pad(train_images[index], 224, 224, )

    return resize_images.numpy() ,train_labels[index]

# 获取测试集数据
def get_test(size):
    index = np.random.randint(0, np.shape(test_images)[0], size)

    resize_images = tf.image.resize_witn_pad(test_images[index], 224, 224, )

    return resize_images.numpy(), test_labels[index]

调用上述两个方法,获取参与模型训练及测试的数据集

python 复制代码
train_images, train_labels = get_train(256)
test_images, test_labels = get_test(128)

(2)模型编译

python 复制代码
# 指定优化器,损失函数和评价指标
optimizer = tf.keras.optimizers.SGD(learning_rate = 0.01, momentum = 0.0)

net.compile(
    optimizer = optimizer,
    loss = 'sparse_categorical_crossentropy',
    metrics = ['accuracy']
)

(3)模型训练

python 复制代码
net.fit(train_images, traim_labels, batch_size = 128, epoch = 3, verbose = 1, validation_split = 0.1)

(4)模型评估

python 复制代码
net.evaluate(test_images, test_labels, verbose = 1)
相关推荐
带娃的IT创业者1 分钟前
US Cities Are Axing Flock Safety Surveillance Technology: 当监控之眼被蒙上,我们在守护什么?
人工智能·智慧城市·数据治理·公共安全·隐私保护·监控技术·技术伦理
愚公搬代码2 分钟前
【愚公系列】《AI漫剧创作一本通》004-剧本拆解,把小说改编为可落地的脚本(爆款AI漫剧,从选择合适的小说开始)
人工智能·ai漫剧
玩转单片机与嵌入式5 分钟前
学习嵌入式AI(TInyML),只需掌握这点python基础即可!
人工智能·python·学习
mit6.8249 分钟前
从 Vibe Coding 到 Agentic Engineering
人工智能
kay_54514 分钟前
YOLO26改进| 主干网络 | 提升长距离特征建模与全局上下文理解能力【CVPR】
人工智能·目标检测·计算机视觉·目标跟踪·论文·yolo26·yolo26改进
ting945200017 分钟前
Huddle01 VMs 支持 AI 助手一键部署,MCP 协议重塑云基础设施管理
人工智能
地理探险家18 分钟前
我整理了一份动物数据集合集,做深度学习的直接省掉80%时间(附使用建议)
人工智能·深度学习·数据集·图像·动物
硅谷秋水24 分钟前
语言智体的Harness工程:Harness层作为控制、智体和运行时
人工智能·深度学习·机器学习·语言模型
老黄编程27 分钟前
大型工地实时数据处理与三维重构系统方案(极简中心化部署版)
人工智能·数码相机·计算机视觉·大数据处理·入侵检测·三维重构
狙击主力投资工具34 分钟前
26年5月4日本周复盘总结,好票机会,下周大盘方向,热门板块方向,操作建议,实用干货
人工智能·区块链