python43天

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import create_data_loaders
from model import CNNModel
from train import train_model
from visualize import visualize_grad_cam
import os
import glob
import warnings

def main():
    # 配置参数
    config = {
        'data_dir': '/kaggle/input/dogs-vs-cats',  # Kaggle猫狗数据集路径
        'batch_size': 32,
        'num_epochs': 5,
        'learning_rate': 0.001,
        'img_size': 224,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'best_model_path': 'best_model.pth'
    }
    
    print(f"使用设备: {config['device']}")
    
    # 加载数据
    print("加载数据...")
    try:
        train_loader, test_loader, class_names = create_data_loaders(
            data_dir=config['data_dir'],
            batch_size=config['batch_size'],
            img_size=config['img_size']
        )
        print(f"发现 {len(class_names)} 个类别: {class_names}")
    except Exception as e:
        print(f"数据加载失败: {e}")
        return

    # 初始化模型
    print("初始化模型...")
    model = CNNModel(num_classes=len(class_names), pretrained=True)
    model = model.to(config['device'])
    
    # 设置优化器和损失函数
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # 训练模型
    print(f"开始训练 ({config['num_epochs']} 个周期)...")
    try:
        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=config['num_epochs'],
            device=config['device'],
            save_path=config['best_model_path']
        )
    except KeyboardInterrupt:
        print("训练被用户中断")
        return

    # 加载最佳模型
    if os.path.exists(config['best_model_path']):
        print("加载最佳模型...")
        trained_model.load_state_dict(torch.load(config['best_model_path']))
    else:
        warnings.warn("未找到最佳模型,使用最后训练的模型")
    
    # Grad-CAM可视化
    print("生成Grad-CAM可视化...")
    test_dir = os.path.join(config['data_dir'], 'test', '*')
    sample_images = []
    
    # 获取测试图像样本
    for class_name in class_names:
        class_dir = os.path.join(config['data_dir'], 'test', class_name)
        images = glob.glob(os.path.join(class_dir, '*.jp*g')) + \
                 glob.glob(os.path.join(class_dir, '*.png'))
        
        if images:
            sample_images.append(images[0])
            print(f"为类别 '{class_name}' 选择样本: {os.path.basename(images[0])}")
        else:
            print(f"警告: 类别 '{class_name}' 未找到测试图像")

    # 执行可视化
    for img_path in sample_images:
        try:
            visualize_grad_cam(
                img_path=img_path,
                model=trained_model,
                class_names=class_names,
                transform=train_loader.dataset.transform,
                device=config['device']
            )
        except Exception as e:
            print(f"处理图像 {img_path} 时出错: {e}")

if __name__ == "__main__":
    main()

@浙大疏锦行

相关推荐
骥龙1 小时前
1.2、实战准备:AI安全研究环境搭建与工具链
人工智能·python·安全
黄思搏1 小时前
Python + uiautomator2 手机自动化控制教程
python·智能手机·自动化
@LetsTGBot搜索引擎机器人1 小时前
Telegram 被封是什么原因?如何解决?(附 @letstgbot 搜索引擎重连技巧)
开发语言·python·搜索引擎·机器人·.net
AndrewHZ1 小时前
【图像处理基石】图像对比度增强入门:从概念到实战(Python+OpenCV)
图像处理·python·opencv·计算机视觉·cv·对比度增强·算法入门
XXX-X-XXJ1 小时前
Django 用户认证流程详解:从原理到实现
数据库·后端·python·django·sqlite
2401_841495643 小时前
【数据结构】基于Prim算法的最小生成树
java·数据结构·c++·python·算法·最小生成树·prim
昵称是6硬币4 小时前
YOLO26论文精读(逐段解析)
人工智能·深度学习·yolo·目标检测·计算机视觉·yolo26
数据村的古老师6 小时前
Python数据分析实战:基于25年黄金价格数据的特征提取与算法应用【数据集可下载】
开发语言·python·数据分析
小王不爱笑1326 小时前
Java 核心知识点查漏补缺(一)
java·开发语言·python
小冷爱读书6 小时前
F-INR: Functional Tensor Decomposition for Implicit Neural Representations
深度学习·inr·函数张量分解