DAY 43 复习日

浙大疏锦行-CSDN博客
作业:

在 Kaggle 上找到一个图像数据集,用 CNN 网络进行训练,并用 Grad-CAM 做可视化

进阶:将代码拆分成多个文件

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import warnings
warnings.filterwarnings("ignore")
 
# Fix the NumPy and PyTorch seeds so repeated runs are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# Preprocessing pipeline: resize to 64x64, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load the custom image dataset from a local folder.
dataset_path = r"F:\Program Files\MyPythonProjects\day43\music_instruments"
dataset = ImageFolder(dataset_path, transform=transform)

# 80/20 train/test split; the test set takes whatever the rounding leaves.
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, lengths=[train_size, test_size]
)

# Human-readable class names.
# NOTE(review): presumably these must be in the same order as the label
# indices ImageFolder assigns -- confirm against dataset.classes.
classes = (
    'accordion',
    'banjo',
    'drum',
    'flute',
    'guitar',
    'harmonica',
    'saxophone',
    'sitar',
    'tabla',
    'violin',
)

# Build the network (SimpleCNN is defined outside this section).
model = SimpleCNN()
print("模型已创建")

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)
 
# Train the model
def train_model(model, epochs=10, *, batch_size=32, lr=0.001, num_workers=2):
    """Train ``model`` in place on the module-level ``train_dataset``.

    Optimises with Adam and cross-entropy loss, printing the mean loss of
    every 10 mini-batches. Each batch is moved to the module-level
    ``device`` before the forward pass, so the model is expected to already
    be on that device.

    Args:
        model: The ``nn.Module`` to train; its parameters are updated in
            place and it is left in training mode.
        epochs: Number of full passes over ``train_dataset``.
        batch_size: Mini-batch size handed to the ``DataLoader``.
        lr: Learning rate for the Adam optimiser.
        num_workers: Number of ``DataLoader`` worker processes.
            NOTE(review): on platforms that spawn worker processes (e.g.
            Windows) a value > 0 needs the launching script to be guarded
            by ``if __name__ == "__main__":``; pass 0 to load data in the
            main process instead.

    Returns:
        None.
    """
    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Dropout / batch-norm layers behave differently in eval mode; switch to
    # training mode explicitly in case the caller left the model in eval mode.
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Report the running mean every 10 batches, then reset it.
            if i % 10 == 9:
                print(f'[{epoch + 1}, {i + 1}] 损失: {running_loss / 10:.3f}')
                running_loss = 0.0

    print("训练完成")
 
# Load a saved checkpoint when one is usable; otherwise train from scratch
# and save the result.
try:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine instead of
    # failing and silently triggering a retrain that overwrites it.
    model.load_state_dict(
        torch.load('music_instruments_cnn.pth', map_location=device)
    )
except Exception as exc:
    # Deliberately not a bare ``except:`` -- Ctrl-C (KeyboardInterrupt) and
    # SystemExit must propagate rather than start a training run.
    print("无法加载预训练模型,使用未训练模型或训练新模型")
    # Surface the actual reason (missing file, shape mismatch, ...).
    print(f"加载失败原因: {type(exc).__name__}: {exc}")
    train_model(model, epochs=10)
    torch.save(model.state_dict(), 'music_instruments_cnn.pth')
else:
    print("已加载预训练模型")