# MNIST CNN

import torch

import torch.nn as nn

import torch.optim as optim

import torchvision

from torchvision import transforms, datasets

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import numpy as np

import os

from sklearn.metrics import confusion_matrix

import seaborn as sns

# Seed PyTorch's global random number generator.
torch.manual_seed(42)

# Use CUDA when torch reports it as available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"使用设备: {device}")

# ============ 1. Data preparation ============

def get_custom_data(data_root, test_root, batch_size=32, img_size=(64, 64)):
    """Build training and test DataLoaders from two image directory trees.

    Both directories are read with ``datasets.ImageFolder`` and share one
    preprocessing pipeline (resize, tensor conversion, normalisation).

    Args:
        data_root: Root directory of the training images.
        test_root: Root directory of the test images.
        batch_size: Number of images per batch for both loaders.
        img_size: Target size passed to ``transforms.Resize``.

    Returns:
        A tuple ``(train_loader, test_loader, num_classes, class_names)``
        where ``class_names`` comes from the training set.

    Raises:
        ValueError: If the test set's class list differs from the training
            set's, because the integer labels of the two sets would then
            not refer to the same classes.
    """
    # Shared preprocessing for both splits.
    transform = transforms.Compose([
        transforms.Resize(img_size),  # resize every image
        transforms.ToTensor(),  # convert to a tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # per-channel normalisation
    ])

    # The two splits live in separate directories, so no random split is needed.
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
    testset = datasets.ImageFolder(root=test_root, transform=transform)

    # Labels are indices into each dataset's own class list; refuse to continue
    # when the two lists disagree instead of evaluating against shifted labels.
    if testset.classes != dataset.classes:
        raise ValueError(
            f"class folders differ between '{data_root}' ({dataset.classes}) "
            f"and '{test_root}' ({testset.classes})"
        )

    # Only the training data is shuffled.
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Class information is taken from the training set.
    class_names = dataset.classes
    num_classes = len(class_names)

    print("数据集信息:")
    print(f" 训练集: {len(dataset)}")
    print(f" 测试集: {len(testset)}")
    print(f" 类别数: {num_classes}")
    print(f" 类别名称: {class_names}")

    return train_loader, test_loader, num_classes, class_names

# ============ 2. Modified SimpleCNN model ============

class CustomSimpleCNN(nn.Module):
    """Small three-stage CNN for image classification.

    Input:  ``[batch, input_channels, height, width]``
    Output: ``[batch, num_classes]`` (raw scores, no softmax applied)
    """

    def __init__(self, num_classes, input_channels=3):
        """Create the layers.

        Args:
            num_classes: Size of the final linear layer's output.
            input_channels: Number of channels of the input images.
        """
        super().__init__()

        # Convolution stage 1: input_channels -> 32, then halve the spatial size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Convolution stage 2: 32 -> 64, halve again.
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Convolution stage 3: 64 -> 128, halve again.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Adaptive pooling to 4x4 so the classifier input size does not
        # depend on the input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))

        # Classifier head.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),  # 128 * 4 * 4 = 2048
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Run the three convolution stages, the pooling and the classifier."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.adaptive_pool(x)  # fixed 4x4 spatial size
        x = self.fc(x)
        return x

# ============ 3. Training functions ============

def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    """Train ``model`` with Adam and cross-entropy, evaluating after each epoch.

    Args:
        model: The network to train; it is updated in place.
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        test_loader: Iterable of ``(data, target)`` batches passed to
            ``evaluate_model`` after every epoch.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate for the Adam optimizer.

    Returns:
        ``(train_losses, test_accuracies)``: two lists with one entry per
        epoch, the mean batch loss and the test accuracy in percent.

    Raises:
        ValueError: If ``train_loader`` has no batches (the per-epoch
            averages would otherwise divide by zero).
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Move batches to wherever the model's weights are, so the two always match.
    run_device = next(model.parameters()).device

    train_losses = []
    test_accuracies = []

    print(f"\n开始训练 {epochs} 个epoch...")

    for epoch in range(epochs):
        # Training mode.
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(run_device), target.to(run_device)

            # Reset gradients left over from the previous step.
            optimizer.zero_grad()

            # Forward pass.
            output = model(data)
            loss = criterion(output, target)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss

            # Running training accuracy.
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}] '
                      f'Batch [{batch_idx}/{num_batches}] '
                      f'Loss: {batch_loss:.4f}')

        # Mean batch loss and accuracy for this epoch.
        avg_loss = running_loss / num_batches
        train_accuracy = 100 * correct / total
        train_losses.append(avg_loss)

        # Evaluate on the test data.
        test_accuracy = evaluate_model(model, test_loader)
        test_accuracies.append(test_accuracy)

        print(f'Epoch [{epoch+1}/{epochs}] 完成')
        print(f' 训练损失: {avg_loss:.4f}, 训练准确率: {train_accuracy:.2f}%')
        print(f' 测试准确率: {test_accuracy:.2f}%')

    return train_losses, test_accuracies

def evaluate_model(model, test_loader):
    """Return the classification accuracy of ``model`` in percent.

    Args:
        model: The network to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples.
    """
    model.eval()

    # Send batches to the device holding the model's weights; a model without
    # parameters falls back to the module-level ``device``.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(run_device), target.to(run_device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return 100 * correct / total

# ============ 4. Visualising the results ============

def visualize_results(model, test_loader, class_names, train_losses, test_accuracies):
    """Plot training curves, sample predictions and a confusion matrix.

    The figure is a 2x2 grid: training loss, test accuracy, up to 12 images
    from the first test batch with their true/predicted labels, and the
    confusion matrix over the whole test loader. It is saved to
    ``./custom_cnn_results.png`` and then shown.

    Args:
        model: Trained network; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches.
        class_names: Class name for each label index.
        train_losses: Per-epoch training loss values.
        test_accuracies: Per-epoch test accuracy values in percent.
    """
    # Send batches to the device holding the model's weights; a model without
    # parameters falls back to the module-level ``device``.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Training loss curve.
    ax1 = axes[0, 0]
    ax1.plot(train_losses, 'b-o', linewidth=2, markersize=6)
    ax1.set_title('训练损失曲线', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True, alpha=0.3)

    # 2. Test accuracy curve.
    ax2 = axes[0, 1]
    ax2.plot(test_accuracies, 'r-s', linewidth=2, markersize=6)
    ax2.set_title('测试准确率曲线', fontsize=12, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.grid(True, alpha=0.3)

    # 3. Sample test images with predictions.
    # The bottom-left cell only carries the panel title; the thumbnails go in
    # a 3x4 grid nested inside that cell so they do not cover the other panels.
    ax3 = axes[1, 0]
    ax3.set_title('预测结果 (绿色=正确, 红色=错误)', fontsize=12, fontweight='bold', pad=30)
    ax3.axis('off')
    sample_grid = ax3.get_subplotspec().subgridspec(3, 4, hspace=0.6)

    images, labels = next(iter(test_loader))

    model.eval()
    with torch.no_grad():
        outputs = model(images.to(run_device))
        predictions = outputs.argmax(dim=1).cpu().numpy()

    # Show at most 12 samples.
    num_samples = min(12, len(images))
    for i in range(num_samples):
        ax = fig.add_subplot(sample_grid[i // 4, i % 4])

        # Channels-last layout for imshow, then undo Normalize(0.5, 0.5).
        img = images[i].cpu().numpy().transpose((1, 2, 0))
        img = np.clip(img * 0.5 + 0.5, 0.0, 1.0)  # keep imshow's float range
        ax.imshow(img)

        true_idx = int(labels[i])
        pred_idx = int(predictions[i])
        color = 'green' if pred_idx == true_idx else 'red'
        ax.set_title(f'True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}',
                     color=color, fontsize=8)
        ax.axis('off')

    # 4. Confusion matrix over the whole test loader.
    ax4 = axes[1, 1]
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.to(run_device))
            all_preds.extend(output.argmax(dim=1).cpu().numpy())
            all_labels.extend(target.cpu().numpy())

    # Pass every label index explicitly so the matrix always has one row and
    # column per entry of class_names, even for classes absent from the data.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

    # Heat map drawn with seaborn.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax4)
    ax4.set_title('混淆矩阵', fontsize=12, fontweight='bold')
    ax4.set_xlabel('预测标签')
    ax4.set_ylabel('真实标签')
    plt.setp(ax4.get_xticklabels(), rotation=45)
    plt.setp(ax4.get_yticklabels(), rotation=0)

    plt.tight_layout()
    plt.savefig('./custom_cnn_results.png', dpi=150, bbox_inches='tight')
    plt.show()

# ============ 5. Printing the model parameters ============

def print_model_parameters(model):
    """Print details of every trainable parameter of ``model``.

    For each parameter with ``requires_grad`` set, the name, shape, element
    count, dtype and device are printed, followed by the overall total.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        The total number of trainable parameter elements.
    """
    print("=" * 60)
    print("模型参数详情:")
    print("=" * 60)

    total_params = 0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # Frozen parameters are neither listed nor counted.
            continue
        count = param.numel()
        print(f"参数名: {name}")
        print(f"形状: {param.shape}")
        print(f"参数量: {count:,}")
        print(f"数据类型: {param.dtype}")
        print(f"设备: {param.device}")
        print("-" * 40)
        total_params += count

    print(f"总可训练参数量: {total_params:,}")
    print("=" * 60)
    return total_params

# ============ 6. Main program ============

def main(data_path="./mnist_train_sorted", test_path="./mnist_test_sorted",
         epochs=15, lr=0.001, batch_size=32, img_size=(64, 64)):
    """Run the full pipeline: load data, build, train, plot and save the model.

    Args:
        data_path: Directory with the training images.
        test_path: Directory with the test images.
        epochs: Number of training epochs.
        lr: Learning rate passed to ``train_model``.
        batch_size: Batch size for both data loaders.
        img_size: Size every image is resized to.

    Returns:
        ``(model, class_names)`` after training, or ``(None, None)`` when
        either directory does not exist, so callers can always unpack the
        result into two values.
    """
    # Check both data directories before doing any work.
    for path in (data_path, test_path):
        if not os.path.exists(path):
            print(f"❌ 错误: 数据路径 '{path}' 不存在")
            print("请创建目录结构:")
            print(f"{path}/")
            print("├── 类别1/")
            print("│ ├── image1.jpg")
            print("│ └── image2.jpg")
            print("├── 类别2/")
            print("└── 类别3/")
            return None, None

    # Load the data.
    print("加载自定义图像数据...")
    train_loader, test_loader, num_classes, class_names = get_custom_data(
        data_root=data_path,
        test_root=test_path,
        batch_size=batch_size,
        img_size=img_size,
    )

    # Build the model.
    print(f"\n创建CustomSimpleCNN模型,类别数: {num_classes}")
    model = CustomSimpleCNN(num_classes=num_classes, input_channels=3).to(device)
    print(model)

    # Count all parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n总参数量: {total_params:,}")

    # Train.
    train_losses, test_accuracies = train_model(
        model, train_loader, test_loader,
        epochs=epochs, lr=lr,
    )

    # Visualise.
    print("\n生成可视化结果...")
    visualize_results(model, test_loader, class_names, train_losses, test_accuracies)

    # Save the weights together with the class information.
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_names': class_names,
        'num_classes': num_classes,
    }, './custom_cnn_model.pth')
    print("\n模型已保存!")

    return model, class_names

# Script entry point.
if __name__ == "__main__":
    # Replace these with your own dataset directories.
    data_path = "./mnist_train_sorted"  # change to your training data directory
    test_path = "./mnist_test_sorted"  # change to your test data directory

    # main() may return None instead of a (model, class_names) pair when the
    # data directory is missing, so guard the unpacking.
    result = main(data_path, test_path)
    model, class_names = result if result is not None else (None, None)

    if model is not None:
        print_model_parameters(model)

相关推荐
二哈赛车手8 小时前
新人笔记---ApiFox的一些常见使用出错
java·笔记·spring
xian_wwq10 小时前
【学习笔记】AGC协调控制系统概述
笔记·学习
x_yeyue11 小时前
三角形数
笔记·算法·数论·组合数学
憧憬成为java架构高手的小白11 小时前
docker学习笔记(基于b站多个视频学习)【未完结】
笔记·学习
RainCity13 小时前
Java Swing 自定义组件库分享(七)
java·笔记·后端
東隅已逝,桑榆非晚13 小时前
字符函数和字符串函数
c语言·笔记
Upsy-Daisy14 小时前
AI Agent 项目学习笔记(七):RAG 高级扩展——过滤检索、PgVector 与云知识库
人工智能·笔记·学习
智者知已应修善业15 小时前
【51单片机LED闪烁10次数码管显示0-9】2023-12-14
c++·经验分享·笔记·算法·51单片机
智者知已应修善业15 小时前
【51单片机2按键控制1个敞亮LED灯闪烁和熄灭】2023-11-3
c++·经验分享·笔记·算法·51单片机
w20180017 小时前
二年级下册语文看图写话作文:蛋壳的奇妙之旅
笔记