Day43 Python打卡训练营

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

选取Kaggle上的CIFAR-10数据集进行CNN训练,并使用Grad-CAM进行可视化,代码将拆分为多个文件以保持模块化。CIFAR-10是一个包含60,000张32x32彩色图像的数据集,分为10个类别。

项目结构

复制代码
cifar10_cnn_gradcam/
├── data_loader.py         # 数据加载和预处理
├── model.py              # CNN模型定义
├── gradcam.py            # Grad-CAM实现
├── train.py              # 模型训练逻辑
├── visualize.py          # 可视化Grad-CAM结果
├── main.py               # 主执行脚本
└── requirements.txt      # 依赖库

1. 数据加载(data_loader.py)

此文件负责加载和预处理CIFAR-10数据集,并进行训练、验证、测试集划分。

复制代码
import tensorflow as tf
from sklearn.model_selection import train_test_split

def load_cifar10_data():
    # 加载CIFAR-10数据集
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    
    # 归一化像素值到[0, 1]
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0
    
    # 将训练集进一步拆分为训练和验证集(80%训练,20%验证)
    x_train, x_val, y_train, y_val = train_test_split(
        x_train, y_train, test_size=0.2, random_state=42
    )
    
    # 类名
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
    
    return (x_train, y_train), (x_val, y_val), (x_test, y_test), class_names

2. 模型定义 (model.py)

此文件定义一个简单的CNN模型,适合CIFAR-10分类任务。

复制代码
import tensorflow as tf
from tensorflow.keras import layers, models

def build_cnn_model():
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3), padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(10, activation='softmax')
    ])
    
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    return model

3. Grad-CAM实现 (gradcam.py)

此文件实现Grad-CAM算法,用于生成CNN的注意力热图。

复制代码
import tensorflow as tf
import numpy as np
import cv2

class GradCAM:
    def __init__(self, model, layer_name):
        self.model = model
        self.layer_name = layer_name
        self.grad_model = tf.keras.models.Model(
            [model.inputs], [model.get_layer(layer_name).output, model.output]
        )

    def generate_heatmap(self, image, class_idx):
        image = tf.cast(image, tf.float32)
        with tf.GradientTape() as tape:
            conv_output, predictions = self.grad_model(image)
            loss = predictions[:, class_idx]

        grads = tape.gradient(loss, conv_output)
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
        conv_output = conv_output[0]
        heatmap = tf.reduce_mean(tf.multiply(conv_output, pooled_grads), axis=-1)
        heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
        return heatmap.numpy()

    def superimpose_heatmap(self, image, heatmap, alpha=0.4):
        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        image = np.uint8(255 * image)
        superimposed_img = heatmap * alpha + image
        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
        return superimposed_img

4. 模型训练 (train.py)

此文件包含训练逻辑,使用数据增强以提高模型鲁棒性。

复制代码
import tensorflow as tf
from tensorflow.keras import layers
from model import build_cnn_model

def train_model(x_train, y_train, x_val, y_val, epochs=25, batch_size=32):
    model = build_cnn_model()
    
    # 数据增强
    data_augmentation = tf.keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ])
    
    # 训练模型
    history = model.fit(
        data_augmentation(x_train), y_train,
        validation_data=(x_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        verbose=1
    )
    
    model.save('cifar10_cnn_model.h5')
    return model, history

5. 可视化Grad-CAM结果 (visualize.py)

此文件负责生成和保存Grad-CAM可视化结果。

复制代码
import numpy as np
import matplotlib.pyplot as plt
from gradcam import GradCAM

def visualize_gradcam(model, x_test, y_test, class_names, num_images=5):
    gradcam = GradCAM(model, layer_name='conv2d_2')  # 选择最后一层卷积层
    
    plt.figure(figsize=(15, 10))
    for i in range(num_images):
        img = x_test[i:i+1]
        true_label = y_test[i][0]
        pred = model.predict(img)
        pred_label = np.argmax(pred, axis=1)[0]
        
        # 生成热图
        heatmap = gradcam.generate_heatmap(img, pred_label)
        superimposed_img = gradcam.superimpose_heatmap(img[0], heatmap)
        
        # 可视化
        plt.subplot(num_images, 3, i*3 + 1)
        plt.imshow(img[0])
        plt.title(f'True: {class_names[true_label]}')
        plt.axis('off')
        
        plt.subplot(num_images, 3, i*3 + 2)
        plt.imshow(heatmap, cmap='jet')
        plt.title('Heatmap')
        plt.axis('off')
        
        plt.subplot(num_images, 3, i*3 + 3)
        plt.imshow(superimposed_img)
        plt.title(f'Pred: {class_names[pred_label]}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('gradcam_visualization.png')
    plt.close()

6. 主执行脚本 (main.py)

此文件协调整个流程,调用其他模块执行数据加载、训练和可视化。

复制代码
from data_loader import load_cifar10_data
from train import train_model
from visualize import visualize_gradcam

def main():
    # 加载数据
    (x_train, y_train), (x_val, y_val), (x_test, y_test), class_names = load_cifar10_data()
    
    # 训练模型
    model, history = train_model(x_train, y_train, x_val, y_val, epochs=25, batch_size=32)
    
    # 可视化Grad-CAM
    visualize_gradcam(model, x_test, y_test, class_names, num_images=5)
    
    # 评估模型
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
    print(f"Test accuracy: {test_acc:.4f}")

if __name__ == "__main__":
    main()

7. 依赖文件 (requirements.txt)

列出项目所需的Python库。

复制代码
tensorflow==2.10.0 numpy scikit-learn matplotlib opencv-python

@浙大疏锦行

相关推荐
QQ676580085 分钟前
基于 PyTorch 的 VGG16 深度学习人脸识别检测系统的实现+ui界面
人工智能·pytorch·python·深度学习·ui·人脸识别
木木黄木木6 分钟前
Python制作史莱姆桌面宠物!可爱的
开发语言·python·宠物
exploration-earth30 分钟前
本地优先的状态管理与工具选型策略
开发语言·前端·javascript
胖哥真不错34 分钟前
Python基于方差-协方差方法实现投资组合风险管理的VaR与ES模型项目实战
python·毕业设计·课程设计·方差-协方差方法·投资组合风险管理·var与es模型
慧一居士38 分钟前
flask功能使用总结和完整示例
python
苦学编程的谢1 小时前
Java网络编程API 1
java·开发语言·网络
大模型铲屎官1 小时前
【深度学习-Day 23】框架实战:模型训练与评估核心环节详解 (MNIST实战)
人工智能·pytorch·python·深度学习·大模型·llm·mnist
寒山李白1 小时前
Java 依赖注入、控制反转与面向切面:面试深度解析
java·开发语言·面试·依赖注入·控制反转·面向切面
梓仁沐白1 小时前
【Kotlin】数字&字符串&数组&集合
android·开发语言·kotlin
Java菜鸟、2 小时前
设计模式(代理设计模式)
java·开发语言·设计模式