python43天

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import create_data_loaders
from model import CNNModel
from train import train_model
from visualize import visualize_grad_cam
import os
import glob
import warnings

def main():
    # 配置参数
    config = {
        'data_dir': '/kaggle/input/dogs-vs-cats',  # Kaggle猫狗数据集路径
        'batch_size': 32,
        'num_epochs': 5,
        'learning_rate': 0.001,
        'img_size': 224,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'best_model_path': 'best_model.pth'
    }
    
    print(f"使用设备: {config['device']}")
    
    # 加载数据
    print("加载数据...")
    try:
        train_loader, test_loader, class_names = create_data_loaders(
            data_dir=config['data_dir'],
            batch_size=config['batch_size'],
            img_size=config['img_size']
        )
        print(f"发现 {len(class_names)} 个类别: {class_names}")
    except Exception as e:
        print(f"数据加载失败: {e}")
        return

    # 初始化模型
    print("初始化模型...")
    model = CNNModel(num_classes=len(class_names), pretrained=True)
    model = model.to(config['device'])
    
    # 设置优化器和损失函数
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # 训练模型
    print(f"开始训练 ({config['num_epochs']} 个周期)...")
    try:
        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=config['num_epochs'],
            device=config['device'],
            save_path=config['best_model_path']
        )
    except KeyboardInterrupt:
        print("训练被用户中断")
        return

    # 加载最佳模型
    if os.path.exists(config['best_model_path']):
        print("加载最佳模型...")
        trained_model.load_state_dict(torch.load(config['best_model_path']))
    else:
        warnings.warn("未找到最佳模型,使用最后训练的模型")
    
    # Grad-CAM可视化
    print("生成Grad-CAM可视化...")
    test_dir = os.path.join(config['data_dir'], 'test', '*')
    sample_images = []
    
    # 获取测试图像样本
    for class_name in class_names:
        class_dir = os.path.join(config['data_dir'], 'test', class_name)
        images = glob.glob(os.path.join(class_dir, '*.jp*g')) + \
                 glob.glob(os.path.join(class_dir, '*.png'))
        
        if images:
            sample_images.append(images[0])
            print(f"为类别 '{class_name}' 选择样本: {os.path.basename(images[0])}")
        else:
            print(f"警告: 类别 '{class_name}' 未找到测试图像")

    # 执行可视化
    for img_path in sample_images:
        try:
            visualize_grad_cam(
                img_path=img_path,
                model=trained_model,
                class_names=class_names,
                transform=train_loader.dataset.transform,
                device=config['device']
            )
        except Exception as e:
            print(f"处理图像 {img_path} 时出错: {e}")

if __name__ == "__main__":
    main()

@浙大疏锦行

相关推荐
小白学大数据1 分钟前
实时监控 1688 商品价格变化的爬虫系统实现
javascript·爬虫·python
jifengzhiling3 分钟前
卡尔曼增益:动态权重,最优估计
人工智能·算法·机器学习
Darkershadow6 分钟前
Python学习之使用笔记本摄像头截屏
python·opencv·学习
Cathyqiii6 分钟前
序列建模模型原理及演进——从RNN、Transformer到SSM与Mamba
人工智能·rnn·深度学习·transformer
极客BIM工作室8 分钟前
大模型的发展历程: 从文本到音视频生成的技术演进
人工智能·机器学习
ekprada9 分钟前
Day 40 深度学习训练与测试的规范写法
人工智能·python
音视频牛哥13 分钟前
C#实战:如何开发设计毫秒级延迟、工业级稳定的Windows平台RTSP/RTMP播放器
人工智能·机器学习·机器人·c#·音视频·rtsp播放器·rtmp播放器
Blossom.1181 小时前
基于时序大模型+强化学习的虚拟电厂储能调度系统:从负荷预测到收益最大化的实战闭环
运维·人工智能·python·决策树·机器学习·自动化·音视频
深蓝海拓2 小时前
PySide6从0开始学习的笔记(四)QMainWindow
笔记·python·学习·pyqt
深蓝海拓2 小时前
PySide6 的 QSettings简单应用学习笔记
python·学习·pyqt