打卡day52

简单cnn 借助调参指南进一步提高精度

基础CNN模型代码

python 复制代码
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

# 加载数据
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# 数据预处理
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

# 基础CNN模型
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(train_images, train_labels, 
                    epochs=10, 
                    batch_size=64,
                    validation_data=(test_images, test_labels))

改进方法

增加模型复杂度

python 复制代码
model = models.Sequential([
    layers.Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 3), padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    
    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    
    layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.BatchNormalization(),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])

优化器调参

python 复制代码
from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

数据增强

python 复制代码
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    zoom_range=0.1
)
datagen.fit(train_images)

history = model.fit(datagen.flow(train_images, train_labels, batch_size=64),
                    epochs=50,
                    validation_data=(test_images, test_labels))

早停和模型检查点

python 复制代码
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    ModelCheckpoint('best_model.h5', monitor='val_accuracy', save_best_only=True)
]

history = model.fit(..., callbacks=callbacks, epochs=100)
相关推荐
LateFrames18 分钟前
用 【C# + Winform + MediaPipe】 实现人脸468点识别
python·c#·.net·mediapipe
人工干智能3 小时前
科普:Python 中,字典的“动态创建键”特性
开发语言·python
一条星星鱼3 小时前
深度学习是如何收敛的?梯度下降算法原理详解
人工智能·深度学习·算法
开心-开心急了6 小时前
主窗口(QMainWindow)如何放入文本编辑器(QPlainTextEdit)等继承自QWidget的对象--(重构版)
python·ui·pyqt
moshumu18 小时前
局域网访问Win11下的WSL中的jupyter notebook
ide·python·深度学习·神经网络·机器学习·jupyter
计算机毕设残哥8 小时前
基于Hadoop+Spark的人体体能数据分析与可视化系统开源实现
大数据·hadoop·python·scrapy·数据分析·spark·dash
图学习的小张10 小时前
Windows安装mamba全流程(全网最稳定最成功)
人工智能·windows·深度学习·语言模型
编程指南针10 小时前
2026新选题-基于Python的老年病医疗数据分析系统的设计与实现(数据采集+可视化分析)
开发语言·python·病历分析·医疗病历分析
王哥儿聊AI11 小时前
告别人工出题!PromptCoT 2.0 让大模型自己造训练难题,7B 模型仅用合成数据碾压人工数据集效果!
人工智能·深度学习·算法·机器学习·软件工程
reasonsummer11 小时前
【办公类-116-01】20250929家长会PPT(Python快速批量制作16:9PPT相册,带文件名,照片横版和竖版)
java·数据库·python·powerpoint