python43天

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import create_data_loaders
from model import CNNModel
from train import train_model
from visualize import visualize_grad_cam
import os
import glob
import warnings

def main():
    # 配置参数
    config = {
        'data_dir': '/kaggle/input/dogs-vs-cats',  # Kaggle猫狗数据集路径
        'batch_size': 32,
        'num_epochs': 5,
        'learning_rate': 0.001,
        'img_size': 224,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'best_model_path': 'best_model.pth'
    }
    
    print(f"使用设备: {config['device']}")
    
    # 加载数据
    print("加载数据...")
    try:
        train_loader, test_loader, class_names = create_data_loaders(
            data_dir=config['data_dir'],
            batch_size=config['batch_size'],
            img_size=config['img_size']
        )
        print(f"发现 {len(class_names)} 个类别: {class_names}")
    except Exception as e:
        print(f"数据加载失败: {e}")
        return

    # 初始化模型
    print("初始化模型...")
    model = CNNModel(num_classes=len(class_names), pretrained=True)
    model = model.to(config['device'])
    
    # 设置优化器和损失函数
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # 训练模型
    print(f"开始训练 ({config['num_epochs']} 个周期)...")
    try:
        trained_model = train_model(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            num_epochs=config['num_epochs'],
            device=config['device'],
            save_path=config['best_model_path']
        )
    except KeyboardInterrupt:
        print("训练被用户中断")
        return

    # 加载最佳模型
    if os.path.exists(config['best_model_path']):
        print("加载最佳模型...")
        trained_model.load_state_dict(torch.load(config['best_model_path']))
    else:
        warnings.warn("未找到最佳模型,使用最后训练的模型")
    
    # Grad-CAM可视化
    print("生成Grad-CAM可视化...")
    test_dir = os.path.join(config['data_dir'], 'test', '*')
    sample_images = []
    
    # 获取测试图像样本
    for class_name in class_names:
        class_dir = os.path.join(config['data_dir'], 'test', class_name)
        images = glob.glob(os.path.join(class_dir, '*.jp*g')) + \
                 glob.glob(os.path.join(class_dir, '*.png'))
        
        if images:
            sample_images.append(images[0])
            print(f"为类别 '{class_name}' 选择样本: {os.path.basename(images[0])}")
        else:
            print(f"警告: 类别 '{class_name}' 未找到测试图像")

    # 执行可视化
    for img_path in sample_images:
        try:
            visualize_grad_cam(
                img_path=img_path,
                model=trained_model,
                class_names=class_names,
                transform=train_loader.dataset.transform,
                device=config['device']
            )
        except Exception as e:
            print(f"处理图像 {img_path} 时出错: {e}")

if __name__ == "__main__":
    main()

@浙大疏锦行

相关推荐
小鸡吃米…1 小时前
机器学习 - K - 中心聚类
人工智能·机器学习·聚类
沈浩(种子思维作者)2 小时前
真的能精准医疗吗?癌症能提前发现吗?
人工智能·python·网络安全·健康医疗·量子计算
MM_MS2 小时前
Halcon变量控制类型、数据类型转换、字符串格式化、元组操作
开发语言·人工智能·深度学习·算法·目标检测·计算机视觉·视觉检测
njsgcs2 小时前
ue python二次开发启动教程+ 导入fbx到指定文件夹
开发语言·python·unreal engine·ue
io_T_T3 小时前
迭代器 iteration、iter 与 多线程 concurrent 交叉实践(详细)
python
华研前沿标杆游学3 小时前
2026年走进洛阳格力工厂参观游学
python
Carl_奕然3 小时前
【数据挖掘】数据挖掘必会技能之:A/B测试
人工智能·python·数据挖掘·数据分析
AI小怪兽3 小时前
基于YOLOv13的汽车零件分割系统(Python源码+数据集+Pyside6界面)
开发语言·python·yolo·无人机
齐齐大魔王3 小时前
Pascal VOC 数据集
人工智能·深度学习·数据集·voc
wszy18094 小时前
新文章标签:让用户一眼发现最新内容
java·python·harmonyos