cnn训练并用grad-cam可视化

使用大米图片训练集,包含五个文件,分别是5种品牌的大米,使用cnn进行分类训练。

  • -Arborio/ :代表 Arborio 品种的大米图像数据,根据 Rice_Citation_Request.txt 文件可知,该数据集中包含 Arborio 品种的大米图像。
  • Basmati/ :代表 Basmati 品种的大米图像数据,同样是数据集中 Basmati 品种大米的图像集合。
  • Ipsala/ :代表 Ipsala 品种的大米图像数据,该文件夹下存储了大量 Ipsala 品种大米的图像文件。
  • Jasmine/ :代表 Jasmine 品种的大米图像数据,是 Jasmine 品种大米的图像数据集。
  • Karacadag/ :代表 Karacadag 品种的大米图像数据,包含该品种大米的相关图像。
python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.models import resnet18
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集
data_dir = 'e:/2025_python/Rice_Image_Dataset'
train_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 定义 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.fc2 = nn.Linear(128, len(train_dataset.classes))

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')

# Grad - CAM 实现
def grad_cam(model, img, target_layer):
    model.eval()
    img = img.unsqueeze(0)
    img.requires_grad_()

    feature_maps = []
    gradients = []

    def forward_hook(module, input, output):
        feature_maps.append(output)

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    hook = target_layer.register_forward_hook(forward_hook)
    hook_backward = target_layer.register_backward_hook(backward_hook)

    output = model(img)
    pred = torch.argmax(output, dim=1)
    output[0, pred].backward()

    hook.remove()
    hook_backward.remove()

    feature_map = feature_maps[0][0]
    gradient = gradients[0][0]

    weights = torch.mean(gradient, dim=(1, 2))
    cam = torch.zeros(feature_map.shape[1:], dtype=torch.float32)
    for i, w in enumerate(weights):
        cam += w * feature_map[i, :, :]

    cam = torch.relu(cam)
    cam = cam.detach().numpy()
    cam = cv2.resize(cam, (img.shape[3], img.shape[2]))
    cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam))
    return cam

# 选择一张图片进行 Grad - CAM 可视化
sample_img, _ = train_dataset[0]
cam = grad_cam(model, sample_img, model.conv2)

# 可视化结果
img_np = sample_img.permute(1, 2, 0).numpy()
img_np = (img_np - np.min(img_np)) / (np.max(img_np) - np.min(img_np))
cam = np.uint8(255 * cam)
heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
superimposed_img = cv2.addWeighted(np.uint8(255 * img_np), 0.6, heatmap, 0.4, 0)

plt.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
plt.axis('off')
plt.show()

@浙大疏锦行

相关推荐
jason成都1 分钟前
c#开发AI模型对话
人工智能·c#
m0_748250746 分钟前
STM32上部署AI的两个实用软件——Nanoedge AI Studio和STM32Cube AI
人工智能·stm32·嵌入式硬件
Oliverro11 分钟前
EasyRTC嵌入式音视频通信SDK音视频功能驱动视频业务多场景应用
人工智能·音视频
画江湖Test15 分钟前
乘用车自动驾驶和非乘用车(矿车,卡车)自动驾驶区别
人工智能·机器学习·自动驾驶·车载测试·汽车测试·座舱测试
清醒的兰15 分钟前
OpenCV 自带颜色表实现各种滤镜
人工智能·opencv·计算机视觉
亚马逊云开发者15 分钟前
使用 Amazon Q Developer CLI 调用 MCP Server 实现 Amazon Support 案例自动创建
人工智能
奔跑吧邓邓子23 分钟前
DeepSeek 赋能金融衍生品:定价与风险管理的智能革命
人工智能·金融衍生品·deepseek·金融市场·定价与风险管理
阿里云云原生25 分钟前
【发布实录】云原生+AI,助力企业全球化业务创新
人工智能·云原生·可观测·通义灵码
BFT白芙堂33 分钟前
涂胶协作机器人解决方案 | Kinova Link 6 Cobot在涂胶工业的方案应用与价值
人工智能·协作机器人·机器人解决方案·kinova·kinovalink6·bft机器人·工业涂胶
jndingxin39 分钟前
OpenCV CUDA模块图像处理------图像连通域标记接口函数connectedComponents()
图像处理·人工智能·opencv