TensorFlow 2.0笔记 - 使用卷积神经网络训练CIFAR100数据集(类VGG13)

本笔记记录用CNN训练CIFAR100数据集的相关内容。代码中使用了类似VGG13的网络结构(10个卷积层 + 3个全连接层),分别用两个Sequential构建卷积部分和全连接部分;没有使用Flatten层,而是用reshape操作在CNN和全连接层之间转换张量形状。由于网络层次较深,参数量比之前的网络多了不少,因此只训练了10个epoch(RTX4090),没有继续训练下去,最终准确率约为33.8%。

复制代码
import os

# Set the native-logging filter BEFORE importing TensorFlow: TensorFlow's C++
# runtime reads TF_CPP_MIN_LOG_LEVEL when it starts logging, so assigning it
# after `import tensorflow` cannot suppress messages emitted during import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import time  # noqa: E402  (imports intentionally follow the env setup above)
import tensorflow as tf  # noqa: E402
from tensorflow import keras  # noqa: E402
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input  # noqa: E402

# For reproducible runs, uncomment: tf.random.set_seed(12345)
# A bare `tf.__version__` expression only displays in a notebook; print it so
# the version is also visible when this file runs as a script.
print("TensorFlow version:", tf.__version__)

# If the download is slow, fetch the archive manually with any download
# manager (e.g. Xunlei/Thunder) using the official URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
# then place cifar-100-python.tar.gz in the .keras/datasets directory
# (the author's environment: C:\Users\Administrator\.keras\datasets).
# Reference: https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)

def preprocess(x, y):
    """Normalize one (image, label) pair for training.

    Args:
        x: image tensor; cast to float32 and divided by 255.
        y: label tensor; cast to int32.

    Returns:
        Tuple ``(images, labels)`` with the converted tensors.
    """
    images = tf.cast(x, dtype=tf.float32)
    images = images / 255.
    labels = tf.cast(y, dtype=tf.int32)
    return images, labels

# Drop axis 1 from the label arrays (load_data returns labels with an extra
# trailing axis, as the shape prints above show).
y_train, y_test = tf.squeeze(y_train, axis=1), tf.squeeze(y_test, axis=1)

batch_size = 128

# Training pipeline: shuffle with a 1000-element buffer, normalize, batch.
train_db = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .map(preprocess)
    .batch(batch_size)
)

# Test pipeline: same preprocessing and batching, but no shuffling.
test_db = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(preprocess)
    .batch(batch_size)
)

# Sanity check on one training batch: shapes plus the pixel value range.
sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape,
      tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))


# VGG13-style feature extractor made of 5 units. Each unit is two 3x3
# 'same'-padded ReLU convolutions followed by a 2x2, stride-2 max pool, so
# the spatial size halves per unit: 32 -> 16 -> 8 -> 4 -> 2 -> 1.
# The comprehension creates the layers in exactly the order a hand-written
# list would (conv, conv, pool for each unit in turn).
cnn_layers = [
    unit_layer
    for filters in (64, 128, 256, 512, 512)
    for unit_layer in (
        layers.Conv2D(filters, kernel_size=[3, 3], padding='same', activation='relu'),
        layers.Conv2D(filters, kernel_size=[3, 3], padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=[2, 2], strides=2),
    )
]


def _forward(cnn_net, fc_net, x):
    """Run a batch through the conv stack and the dense head; return logits.

    ``cnn_net`` maps [b, 32, 32, 3] -> [b, 1, 1, 512] and ``fc_net`` maps
    [b, 512] -> [b, 100], matching the shapes they are built with in ``main``.
    """
    out = cnn_net(x)
    # Flatten [b, 1, 1, 512] -> [b, 512] (used instead of a Flatten layer).
    out = tf.reshape(out, [-1, 512])
    return fc_net(out)


def _train_step(cnn_net, fc_net, optimizer, variables, x, y):
    """Apply one optimizer update on batch ``(x, y)`` and return its mean loss."""
    with tf.GradientTape() as tape:
        logits = _forward(cnn_net, fc_net, x)
        # One-hot encode the integer labels for categorical cross-entropy.
        y_onehot = tf.one_hot(y, depth=100)
        loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
        loss = tf.reduce_mean(loss)
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss


def _evaluate(cnn_net, fc_net, dataset):
    """Return top-1 accuracy over every batch of ``dataset`` (0.0 if it is empty)."""
    total_samples = 0
    total_correct = 0
    for x, y in dataset:
        logits = _forward(cnn_net, fc_net, x)
        # Softmax is monotonic, so argmax over raw logits gives the same class
        # without the extra op, and avoids artificial ties when small
        # probabilities underflow to 0 in float32.
        pred = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)
        correct = tf.reduce_sum(tf.cast(tf.equal(pred, y), dtype=tf.int32))
        total_samples += x.shape[0]
        total_correct += int(correct)
    if not total_samples:
        return 0.0
    return total_correct / total_samples


def main(num_epochs=10):
    """Build the VGG13-style model, train on ``train_db``, report test accuracy.

    Args:
        num_epochs: number of passes over ``train_db`` (default 10). After
            each epoch, accuracy on ``test_db`` is printed.
    """
    # Convolutional part: [b, 32, 32, 3] => [b, 1, 1, 512]
    cnn_net = Sequential(cnn_layers)
    cnn_net.build(input_shape=[None, 32, 32, 3])

    # Fully connected head producing 100-class logits.
    fc_net = Sequential([
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(100, activation=None),
    ])
    fc_net.build(input_shape=[None, 512])

    optimizer = optimizers.Adam(learning_rate=1e-4)

    # Both sub-networks are optimized jointly: list concatenation gathers all
    # trainable variables into one list for the tape and the optimizer.
    variables = cnn_net.trainable_variables + fc_net.trainable_variables

    for epoch in range(num_epochs):
        for step, (x, y) in enumerate(train_db):
            loss = _train_step(cnn_net, fc_net, optimizer, variables, x, y)
            if step % 100 == 0:
                print("Epoch[", epoch + 1, "/", num_epochs, "]: step-", step, " loss:", float(loss))

        acc = _evaluate(cnn_net, fc_net, test_db)
        print("Epoch[", epoch + 1, "/", num_epochs, "]: accuracy:", acc)


if __name__ == '__main__':
    main()

运行结果:

相关推荐
测试员周周6 小时前
【Appium 系列】第16节-WebView-H5上下文切换 — 混合应用的自动化难点
运维·开发语言·人工智能·功能测试·appium·自动化·测试用例
测试19986 小时前
软件测试 - 单元测试总结
自动化测试·软件测试·python·测试工具·职场和发展·单元测试·测试用例
K姐研究社8 小时前
怎么用AI制作电商口播视频,开拍APP一键生成
人工智能·音视频
LaughingZhu8 小时前
Product Hunt 每日热榜 | 2026-05-21
前端·人工智能·经验分享·chatgpt·html
辰海Coding8 小时前
MiniSpring框架学习笔记-解决循环依赖的简化IoC容器
笔记·学习
曲幽8 小时前
我用了FastApiAdmin后,连夜把踩过的坑都整理出来了
redis·python·postgresql·vue3·fastapi·web·sqlalchemy·admin·fastapiadmin
晓梦林8 小时前
cp520靶场学习笔记
android·笔记·学习
传说故事8 小时前
【论文阅读】MotuBrain: An Advanced World Action Model for Robot Control
论文阅读·人工智能·具身智能·wam
北京耐用通信9 小时前
全域适配工业场景耐达讯自动化Modbus TCP 转 PROFIBUS 网关轻松实现以太网与现场总线互通
网络·人工智能·网络协议·自动化·信息与通信
火山引擎开发者社区9 小时前
TRAE × 火山引擎 Supabase:为你的 AI 应用装上“数据引擎”
人工智能