DAY 43 复习日

作业:

在 Kaggle 上找到一个图像数据集,用 CNN 网络进行训练,并用 Grad-CAM 做可视化

进阶:将代码拆分成多个文件

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
from tqdm import tqdm
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

# Preprocessing: resize every image to 224x224, convert it to a tensor and
# normalise each channel with the widely used ImageNet mean/std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Kaggle flower dataset, read through ImageFolder from ./flower_data.
dataset = datasets.ImageFolder(root='./flower_data', transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Seed the split so the 80/20 train/validation partition is the same on every
# run; otherwise validation accuracy and the sample picked for Grad-CAM
# change from one execution to the next.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Shuffle only the training batches; validation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# CNN model
class FlowerCNN(nn.Module):
    """Three-stage convolutional classifier for RGB images.

    ``features`` stacks three Conv2d -> ReLU -> MaxPool2d stages
    (3 -> 64 -> 128 -> 256 channels, each pool halving height and width).
    ``classifier`` flattens a 256 x 28 x 28 feature map and maps it to
    ``num_classes`` logits through one hidden layer of 512 units.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1),  # index 6: target layer for Grad-CAM
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        # Parameter-free pooling that always hands the classifier a 28x28 map.
        # For 224x224 inputs the feature map is already 28x28, so this is an
        # identity; for other resolutions it prevents a shape mismatch in the
        # first Linear layer. Kept outside ``classifier`` so the indices (and
        # state_dict keys) of the existing layers do not move.
        self.spatial_pool = nn.AdaptiveAvgPool2d((28, 28))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.features(x)
        x = self.spatial_pool(x)
        x = self.classifier(x)
        return x

# Training function
def train_model(model, epochs=5):
    """Train ``model`` on the module-level ``train_loader``, validating every epoch.

    Uses cross-entropy loss and Adam (lr=1e-4). After each epoch the average
    training loss and the accuracy on the module-level ``val_loader`` are
    printed.

    Args:
        model: The ``nn.Module`` to optimise. It is moved in place to CUDA
            when available, otherwise to the CPU.
        epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` instance that was passed in.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_seen = 0
        with tqdm(train_loader, desc=f'Epoch {epoch+1}') as pbar:
            for imgs, labels in pbar:
                imgs, labels = imgs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once; each .item() forces a device sync.
                batch_loss = loss.item()
                batch_count = imgs.size(0)
                # Weight by batch size so a smaller final batch is not over-counted.
                train_loss += batch_loss * batch_count
                train_seen += batch_count
                pbar.set_postfix(loss=batch_loss)

        # Validation pass: no gradients, model in eval mode.
        model.eval()
        val_correct = 0
        val_seen = 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_seen += labels.size(0)

        # Divide by the samples actually seen (guarded against empty loaders)
        # instead of relying on the length of a separate global dataset.
        print(f'Train Loss: {train_loss / max(train_seen, 1):.4f}')
        print(f'Val Acc: {val_correct / max(val_seen, 1):.4f}')

    return model

# Build a classifier with one logit per dataset class and train it.
num_classes = len(dataset.classes)
model = train_model(FlowerCNN(num_classes))

# Grad-CAM visualisation
def visualize_grad_cam(model, img_tensor, target_layer):
    """Display a Grad-CAM heat-map overlaid on a single image.

    Args:
        model: Trained network to explain.
        img_tensor: One image as a (C, H, W) tensor, as produced by the
            dataset transform. It may live on any device.
        target_layer: The layer of ``model`` whose activations and gradients
            drive the heat-map.

    Returns:
        None. The overlay is shown with matplotlib.
    """
    # Run the CAM on whatever device the model currently lives on; training
    # may have moved it to CUDA while dataset samples are CPU tensors.
    device = next(model.parameters()).device
    input_batch = img_tensor.unsqueeze(0).to(device)

    try:
        # Current pytorch-grad-cam releases take a list of layers.
        cam = GradCAM(model=model, target_layers=[target_layer])
    except TypeError:
        # Legacy releases only know the singular keyword.
        cam = GradCAM(model=model, target_layer=target_layer)
    grayscale_cam = cam(input_tensor=input_batch)
    grayscale_cam = grayscale_cam[0, :]

    # Convert to an HWC numpy image and rescale to [0, 1] for the overlay.
    img = img_tensor.detach().cpu().permute(1, 2, 0).numpy()
    img = img - img.min()
    span = img.max()
    if span > 0:  # a constant image stays all-zero instead of dividing by 0
        img = img / span

    visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)
    plt.imshow(visualization)
    plt.axis('off')
    plt.show()

# Visualise Grad-CAM for the first sample of the validation split.
target_layer = model.features[6]  # the last Conv2d in model.features
img, label = val_dataset[0]
visualize_grad_cam(model=model, img_tensor=img, target_layer=target_layer)

@浙大疏锦行

相关推荐
亓才孓5 分钟前
[Class的应用]获取类的信息
java·开发语言
开开心心就好13 分钟前
AI人声伴奏分离工具,离线提取伴奏K歌用
java·linux·开发语言·网络·人工智能·电脑·blender
Never_Satisfied16 分钟前
在JavaScript / HTML中,关于querySelectorAll方法
开发语言·javascript·html
B站_计算机毕业设计之家31 分钟前
豆瓣电影数据采集分析推荐系统 | Python Vue Flask框架 LSTM Echarts多技术融合开发 毕业设计源码 计算机
vue.js·python·机器学习·flask·echarts·lstm·推荐算法
渣渣苏38 分钟前
Langchain实战快速入门
人工智能·python·langchain
3GPP仿真实验室40 分钟前
【Matlab源码】6G候选波形:OFDM-IM 增强仿真平台 DM、CI
开发语言·matlab·ci/cd
devmoon44 分钟前
在 Polkadot 上部署独立区块链Paseo 测试网实战部署指南
开发语言·安全·区块链·polkadot·erc-20·测试网·独立链
lili-felicity44 分钟前
CANN流水线并行推理与资源调度优化
开发语言·人工智能
沐知全栈开发1 小时前
CSS3 边框:全面解析与实战技巧
开发语言
lili-felicity1 小时前
CANN模型量化详解:从FP32到INT8的精度与性能平衡
人工智能·python