DAY43打卡

@浙大疏锦行
kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

python 复制代码
fruit_cnn_project/
├─ data/                # 存放数据集(需手动创建,后续放入图片)
│  ├─ train/            # 训练集图像
│  └─ val/              # 验证集图像
├─ models/              # 模型定义
│  └─ cnn_model.py      # CNN网络结构
├─ utils/               # 工具函数
│  ├─ dataset_utils.py  # 数据加载与预处理
│  ├─ grad_cam.py       # Grad-CAM可视化
│  └─ train_utils.py    # 训练与评估
├─ main.py              # 主程序
└─ requirements.txt     # 依赖列表(可选)
python 复制代码
# 第一部分:导入库
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# 第二部分:数据加载与预处理
def load_data():
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_dataset = datasets.ImageFolder(root='data/train', transform=data_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_dataset = datasets.ImageFolder(root='data/test', transform=data_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
    return train_loader, test_loader

# 第三部分:模型定义
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 第四部分:模型训练
train_loader, _ = load_data()
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

torch.save(model.state_dict(), 'trained_model.pth')

# 第五部分:模型测试
_, test_loader = load_data()
model = SimpleCNN()
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the test images: {100 * correct / total}%')

# 第六部分:Grad-CAM可视化(修复版)
def get_activation():
    activation = {}
    def hook(model, input, output):
        activation['target_layer'] = output.detach()
    return hook, activation

def grad_cam(model, image, target_class_index):
    hook, activation = get_activation()
    target_layer = model.conv2
    target_layer.register_forward_hook(hook)
    
    model.eval()
    image = image.unsqueeze(0)
    image.requires_grad_(True)
    
    output = model(image)
    one_hot = torch.zeros(1, output.size()[-1]).to(image.device)
    one_hot[0][target_class_index] = 1
    
    output.backward(gradient=one_hot, retain_graph=True)
    gradients = image.grad[0].cpu().numpy()
    
    # 从activation字典中获取激活图
    activation_map = activation['target_layer'].cpu().numpy()[0]
    
    weights = np.mean(gradients, axis=(1, 2))
    cam = np.zeros(activation_map.shape[1:], dtype=np.float32)
    
    for i, w in enumerate(weights):
        cam += w * activation_map[i]
    
    cam = np.maximum(cam, 0)
    cam = F.interpolate(
        torch.from_numpy(cam).unsqueeze(0).unsqueeze(0), 
        size=(224, 224), 
        mode='bilinear', 
        align_corners=False
    )[0][0].numpy()
    
    cam = (cam - cam.min()) / (cam.max() - cam.min())
    return cam

# 可视化前几张测试图片
dataiter = iter(test_loader)
images, labels = dataiter.next()

for i in range(5):  # 可视化前5张图片
    image = images[i]
    label = labels[i].item()
    cam = grad_cam(model, image, label)
    
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image.permute(1, 2, 0).numpy())
    plt.title(f'Original Image (Class: {label})')
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(image.permute(1, 2, 0).numpy())
    plt.imshow(cam, cmap='jet', alpha=0.5)
    plt.title('Grad-CAM Visualization')
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
相关推荐
hvinsion几秒前
【开源工具】基于PyQt5工作时长计算器工具开发全解析
开发语言·python·qt·开源·时间·time·工作时长计算
pianmian116 分钟前
使用ArcPy进行栅格数据分析
python
404.Not Found36 分钟前
Day43 Python打卡训练营
开发语言·python
油头少年_w42 分钟前
Python爬虫之数据提取
python
程序员的世界你不懂1 小时前
Appium+python自动化(九)- 定位元素工具
python·appium·自动化
才华是浅浅的耐心1 小时前
Facebook用户信息爬虫技术分析与实现详解
数据库·爬虫·python·facebook
胖墩会武术1 小时前
win32com.client模块 —— Python实现COM自动化控制与数据交互
python·自动化·交互·win32com
蹦蹦跳跳真可爱5892 小时前
计算机视觉处理----OpenCV(从摄像头采集视频、视频处理与视频录制)
人工智能·python·opencv·计算机视觉·音视频
一个天蝎座 白勺 程序猿3 小时前
Python爬虫(48)基于Scrapy-Redis与深度强化学习的智能分布式爬虫架构设计与实践
爬虫·python·scrapy
开开心心就好8 小时前
高效视频倍速播放插件推荐
python·学习·游戏·pdf·计算机外设·电脑·音视频