python43天

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import create_data_loaders
from model import CNNModel
from train import train_model
from visualize import visualize_grad_cam
import os
import glob
import warnings

def main():
    # 配置参数
    config = {
        'data_dir': '/kaggle/input/dogs-vs-cats',  # Kaggle猫狗数据集路径
        'batch_size': 32,
        'num_epochs': 5,
        'learning_rate': 0.001,
        'img_size': 224,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'best_model_path': 'best_model.pth'
    }
    
    print(f"使用设备: {config['device']}")
    
    # 加载数据
    print("加载数据...")
    try:
        train_loader, test_loader, class_names = create_data_loaders(
            data_dir=config['data_dir'],
            batch_size=config['batch_size'],
            img_size=config['img_size']
        )
        print(f"发现 {len(class_names)} 个类别: {class_names}")
    except Exception as e:
        print(f"数据加载失败: {e}")
        return

    # 初始化模型
    print("初始化模型...")
    model = CNNModel(num_classes=len(class_names), pretrained=True)
    model = model.to(config['device'])
    
    # 设置优化器和损失函数
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # 训练模型
    print(f"开始训练 ({config['num_epochs']} 个周期)...")
    try:
        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=config['num_epochs'],
            device=config['device'],
            save_path=config['best_model_path']
        )
    except KeyboardInterrupt:
        print("训练被用户中断")
        return

    # 加载最佳模型
    if os.path.exists(config['best_model_path']):
        print("加载最佳模型...")
        trained_model.load_state_dict(torch.load(config['best_model_path']))
    else:
        warnings.warn("未找到最佳模型,使用最后训练的模型")
    
    # Grad-CAM可视化
    print("生成Grad-CAM可视化...")
    test_dir = os.path.join(config['data_dir'], 'test', '*')
    sample_images = []
    
    # 获取测试图像样本
    for class_name in class_names:
        class_dir = os.path.join(config['data_dir'], 'test', class_name)
        images = glob.glob(os.path.join(class_dir, '*.jp*g')) + \
                 glob.glob(os.path.join(class_dir, '*.png'))
        
        if images:
            sample_images.append(images[0])
            print(f"为类别 '{class_name}' 选择样本: {os.path.basename(images[0])}")
        else:
            print(f"警告: 类别 '{class_name}' 未找到测试图像")

    # 执行可视化
    for img_path in sample_images:
        try:
            visualize_grad_cam(
                img_path=img_path,
                model=trained_model,
                class_names=class_names,
                transform=train_loader.dataset.transform,
                device=config['device']
            )
        except Exception as e:
            print(f"处理图像 {img_path} 时出错: {e}")

if __name__ == "__main__":
    main()

@浙大疏锦行

相关推荐
我材不敲代码3 小时前
Python实现打包贪吃蛇游戏
开发语言·python·游戏
0思必得05 小时前
[Web自动化] Selenium处理动态网页
前端·爬虫·python·selenium·自动化
韩立学长5 小时前
【开题答辩实录分享】以《基于Python的大学超市仓储信息管理系统的设计与实现》为例进行选题答辩实录分享
开发语言·python
大山同学5 小时前
图片补全-Context Encoder
人工智能·机器学习·计算机视觉
qq_192779875 小时前
高级爬虫技巧:处理JavaScript渲染(Selenium)
jvm·数据库·python
薛定谔的猫19826 小时前
十七、用 GPT2 中文对联模型实现经典上联自动对下联:
人工智能·深度学习·gpt2·大模型 训练 调优
u0109272716 小时前
使用Plotly创建交互式图表
jvm·数据库·python
爱学习的阿磊6 小时前
Python GUI开发:Tkinter入门教程
jvm·数据库·python
Imm7776 小时前
中国知名的车膜品牌推荐几家
人工智能·python
tudficdew6 小时前
实战:用Python分析某电商销售数据
jvm·数据库·python