python43天

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import create_data_loaders
from model import CNNModel
from train import train_model
from visualize import visualize_grad_cam
import os
import glob
import warnings

def main():
    # 配置参数
    config = {
        'data_dir': '/kaggle/input/dogs-vs-cats',  # Kaggle猫狗数据集路径
        'batch_size': 32,
        'num_epochs': 5,
        'learning_rate': 0.001,
        'img_size': 224,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'best_model_path': 'best_model.pth'
    }
    
    print(f"使用设备: {config['device']}")
    
    # 加载数据
    print("加载数据...")
    try:
        train_loader, test_loader, class_names = create_data_loaders(
            data_dir=config['data_dir'],
            batch_size=config['batch_size'],
            img_size=config['img_size']
        )
        print(f"发现 {len(class_names)} 个类别: {class_names}")
    except Exception as e:
        print(f"数据加载失败: {e}")
        return

    # 初始化模型
    print("初始化模型...")
    model = CNNModel(num_classes=len(class_names), pretrained=True)
    model = model.to(config['device'])
    
    # 设置优化器和损失函数
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # 训练模型
    print(f"开始训练 ({config['num_epochs']} 个周期)...")
    try:
        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=config['num_epochs'],
            device=config['device'],
            save_path=config['best_model_path']
        )
    except KeyboardInterrupt:
        print("训练被用户中断")
        return

    # 加载最佳模型
    if os.path.exists(config['best_model_path']):
        print("加载最佳模型...")
        trained_model.load_state_dict(torch.load(config['best_model_path']))
    else:
        warnings.warn("未找到最佳模型,使用最后训练的模型")
    
    # Grad-CAM可视化
    print("生成Grad-CAM可视化...")
    test_dir = os.path.join(config['data_dir'], 'test', '*')
    sample_images = []
    
    # 获取测试图像样本
    for class_name in class_names:
        class_dir = os.path.join(config['data_dir'], 'test', class_name)
        images = glob.glob(os.path.join(class_dir, '*.jp*g')) + \
                 glob.glob(os.path.join(class_dir, '*.png'))
        
        if images:
            sample_images.append(images[0])
            print(f"为类别 '{class_name}' 选择样本: {os.path.basename(images[0])}")
        else:
            print(f"警告: 类别 '{class_name}' 未找到测试图像")

    # 执行可视化
    for img_path in sample_images:
        try:
            visualize_grad_cam(
                img_path=img_path,
                model=trained_model,
                class_names=class_names,
                transform=train_loader.dataset.transform,
                device=config['device']
            )
        except Exception as e:
            print(f"处理图像 {img_path} 时出错: {e}")

if __name__ == "__main__":
    main()

@浙大疏锦行

相关推荐
大翻哥哥17 分钟前
Python 2025:异步革命与AI驱动下的开发新范式
开发语言·人工智能·python
hhzz24 分钟前
Pythoner 的Flask项目实践-在web页面实现矢量数据转换工具集功能(附源码)
前端·python·flask
学习的学习者1 小时前
CS课程项目设计19:基于DeepFace人脸识别库的课堂签到系统
人工智能·python·深度学习·人脸识别算法
悠哉悠哉愿意1 小时前
【数据结构与算法学习笔记】双指针
数据结构·笔记·python·学习·算法
MoRanzhi12031 小时前
5. Pandas 缺失值与异常值处理
数据结构·python·数据挖掘·数据分析·pandas·缺失值处理·异常值处理
程序员的奶茶馆2 小时前
Python 字典速查:键值对操作与高频函数
python·面试
tryCbest3 小时前
Python 使用 Redis 详细教程
redis·python·bootstrap
Francek Chen3 小时前
【深度学习计算机视觉】09:语义分割和数据集
人工智能·pytorch·深度学习·计算机视觉·数据集·语义分割
小小毛毛虫~3 小时前
使用Cursor遇到的问题(一):cursor使用conda虚拟环境
python·conda·cursor
wuli玉shell3 小时前
机器学习、数据科学、深度学习、神经网络的区别与联系
深度学习·神经网络·机器学习