# mnist cnn

import torch

import torch.nn as nn

import torch.optim as optim

import torchvision

from torchvision import transforms, datasets

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import numpy as np

import os

from sklearn.metrics import confusion_matrix

import seaborn as sns

# Set the random seed so runs are reproducible.
torch.manual_seed(42)

# Select GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# ============ 1. 数据准备 ============

def get_custom_data(data_root, test_root, batch_size=32, img_size=(64, 64)):
    """Load custom image datasets from class-per-subdirectory folders.

    Args:
        data_root: Training-set root directory; one subdirectory per class.
        test_root: Test-set root directory, same layout as data_root.
        batch_size: Mini-batch size for both loaders.
        img_size: (height, width) every image is resized to.

    Returns:
        (train_loader, test_loader, num_classes, class_names)
    """
    # Preprocessing: resize, convert to tensor, then normalize each RGB
    # channel to roughly [-1, 1] (mean 0.5, std 0.5 per channel).
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # ImageFolder maps each subdirectory name to a class label.
    # Train and test sets come from separate directories, so no
    # random_split is needed here.
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
    testset = datasets.ImageFolder(root=test_root, transform=transform)

    # Shuffle only the training data; keep test order deterministic.
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Class names are taken from the training directory layout.
    class_names = dataset.classes
    num_classes = len(class_names)

    print(f"数据集信息:")
    print(f" 训练集: {len(dataset)}")
    print(f" 测试集: {len(testset)}")
    print(f" 类别数: {num_classes}")
    print(f" 类别名称: {class_names}")

    return train_loader, test_loader, num_classes, class_names

# ============ 2. 修改的SimpleCNN模型 ============

class CustomSimpleCNN(nn.Module):
    """A small CNN for custom image classification.

    Input:  [batch, input_channels, height, width] (RGB images by default)
    Output: [batch, num_classes] raw logits (no softmax applied)
    """

    def __init__(self, num_classes, input_channels=3):
        super().__init__()
        # Conv block 1: input_channels -> 32, spatial size halved by pooling.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Conv block 2: 32 -> 64, spatial size halved again.
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Conv block 3: 64 -> 128, spatial size halved again.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Adaptive pooling pins the feature map at 4x4 so the classifier
        # head works for any input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
        # Classifier head: 128 * 4 * 4 = 2048 features -> 256 -> num_classes.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.adaptive_pool(x)  # normalize spatial size before the head
        x = self.fc(x)
        return x

# ============ 3. 训练函数 ============

def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    """Train a CNN with Adam and cross-entropy loss.

    Args:
        model: Network to train (expected to already live on `device`).
        train_loader: DataLoader yielding (images, labels) for training.
        test_loader: DataLoader used to measure accuracy after each epoch.
        epochs: Number of full passes over the training data.
        lr: Adam learning rate.

    Returns:
        (train_losses, test_accuracies): per-epoch average training loss
        and test accuracy in percent.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    test_accuracies = []

    print(f"\n开始训练 {epochs} 个epoch...")
    for epoch in range(epochs):
        # Training phase.
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            # Standard step: clear gradients, forward, backward, update.
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Track running training accuracy.
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}] '
                      f'Batch [{batch_idx}/{len(train_loader)}] '
                      f'Loss: {loss.item():.4f}')

        # Epoch-level training metrics.
        avg_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total
        train_losses.append(avg_loss)

        # Evaluation phase (evaluate_model switches to eval mode).
        test_accuracy = evaluate_model(model, test_loader)
        test_accuracies.append(test_accuracy)

        print(f'Epoch [{epoch+1}/{epochs}] 完成')
        print(f' 训练损失: {avg_loss:.4f}, 训练准确率: {train_accuracy:.2f}%')
        print(f' 测试准确率: {test_accuracy:.2f}%')

    return train_losses, test_accuracies

def evaluate_model(model, test_loader):
    """Return the model's classification accuracy (%) over test_loader."""
    model.eval()
    hits, seen = 0, 0
    # Inference only: no gradients needed.
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = torch.max(model(images).data, 1)
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
    return 100 * hits / seen

# ============ 4. 可视化结果 ============

def visualize_results(model, test_loader, class_names, train_losses, test_accuracies):
    """Plot training curves, sample predictions, and a confusion matrix.

    Saves the figure to ./custom_cnn_results.png and shows it.

    Args:
        model: Trained network (on `device`).
        test_loader: DataLoader for the test set.
        class_names: Index-to-name mapping for the labels.
        train_losses: Per-epoch training losses from train_model.
        test_accuracies: Per-epoch test accuracies from train_model.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Training loss curve.
    ax1 = axes[0, 0]
    ax1.plot(train_losses, 'b-o', linewidth=2, markersize=6)
    ax1.set_title('训练损失曲线', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True, alpha=0.3)

    # 2. Test accuracy curve.
    ax2 = axes[0, 1]
    ax2.plot(test_accuracies, 'r-s', linewidth=2, markersize=6)
    ax2.set_title('测试准确率曲线', fontsize=12, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.grid(True, alpha=0.3)

    # 3. A batch of test samples with their predictions.
    # NOTE(review): plt.subplot(3, 4, ...) below draws over the 2x2 grid
    # rather than into axes[1, 0]; kept as-is to preserve the original
    # output — confirm the intended layout.
    ax3 = axes[1, 0]
    data_iter = iter(test_loader)
    images, labels = next(data_iter)

    model.eval()
    with torch.no_grad():
        outputs = model(images.to(device))
        predictions = torch.max(outputs, 1)[1].cpu().numpy()

    # Show up to 12 samples.
    num_samples = min(12, len(images))
    for i in range(num_samples):
        ax = plt.subplot(3, 4, i+1)
        # Undo the (0.5, 0.5) normalization for display.
        img = images[i].numpy().transpose((1, 2, 0))
        img = img * 0.5 + 0.5
        plt.imshow(img)
        true_label = class_names[labels[i]]
        pred_label = class_names[predictions[i]]
        color = 'green' if predictions[i] == labels[i] else 'red'
        plt.title(f'True: {true_label}\nPred: {pred_label}', color=color, fontsize=8)
        plt.axis('off')
    plt.suptitle('预测结果 (绿色=正确, 红色=错误)', fontsize=12, fontweight='bold')

    # 4. Confusion matrix over the whole test set.
    ax4 = axes[1, 1]
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            output = model(data)
            preds = torch.max(output, 1)[1].cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(target.numpy())

    cm = confusion_matrix(all_labels, all_preds)
    # Heatmap of the confusion matrix.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax4)
    ax4.set_title('混淆矩阵', fontsize=12, fontweight='bold')
    ax4.set_xlabel('预测标签')
    ax4.set_ylabel('真实标签')
    plt.setp(ax4.get_xticklabels(), rotation=45)
    plt.setp(ax4.get_yticklabels(), rotation=0)

    plt.tight_layout()
    plt.savefig('./custom_cnn_results.png', dpi=150, bbox_inches='tight')
    plt.show()

# ============ 5. 打印模型参数 ============

def print_model_parameters(model):
    """Print name, shape, element count, dtype and device of every trainable parameter."""
    print("=" * 60)
    print("模型参数详情:")
    print("=" * 60)

    total_params = 0
    for name, param in model.named_parameters():
        # Skip frozen parameters; only count what the optimizer can update.
        if not param.requires_grad:
            continue
        count = param.numel()
        print(f"参数名: {name}")
        print(f"形状: {param.shape}")
        print(f"参数量: {count:,}")
        print(f"数据类型: {param.dtype}")
        print(f"设备: {param.device}")
        print("-" * 40)
        total_params += count

    print(f"总可训练参数量: {total_params:,}")
    print("=" * 60)

# ============ 6. 主程序 ============

def main(data_path="./mnist_train_sorted", test_path="./mnist_test_sorted"):
    """Run the full pipeline: load data, build the model, train, visualize, save.

    Args:
        data_path: Training images root (one subdirectory per class).
        test_path: Test images root, same layout.

    Returns:
        (model, class_names) on success, or None when data_path is missing.
    """
    # Bail out early with setup instructions when the data is absent.
    if not os.path.exists(data_path):
        print(f"❌ 错误: 数据路径 '{data_path}' 不存在")
        print("请创建目录结构:")
        print(f"{data_path}/")
        print("├── 类别1/")
        print("│ ├── image1.jpg")
        print("│ └── image2.jpg")
        print("├── 类别2/")
        print("└── 类别3/")
        return None

    # Load the datasets.
    print("加载自定义图像数据...")
    train_loader, test_loader, num_classes, class_names = get_custom_data(
        data_root=data_path,
        test_root=test_path,
        batch_size=32,
        img_size=(64, 64)  # adjust to your image resolution
    )

    # Build the model and move it to the selected device.
    print(f"\n创建CustomSimpleCNN模型,类别数: {num_classes}")
    model = CustomSimpleCNN(num_classes=num_classes, input_channels=3).to(device)
    print(model)

    # Report the total parameter count.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n总参数量: {total_params:,}")

    # Train.
    train_losses, test_accuracies = train_model(
        model, train_loader, test_loader,
        epochs=15, lr=0.001  # adjust epochs/lr as needed
    )

    # Visualize.
    print("\n生成可视化结果...")
    visualize_results(model, test_loader, class_names, train_losses, test_accuracies)

    # Persist the weights together with the label metadata.
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_names': class_names,
        'num_classes': num_classes
    }, './custom_cnn_model.pth')
    print("\n模型已保存!")

    return model, class_names

# 运行

if __name__ == "__main__":
    # Replace with your own data directories.
    data_path = "./mnist_train_sorted"
    test_path = "./mnist_test_sorted"

    # main() returns None when the data directory is missing; unpacking
    # its result directly would raise TypeError in that case.
    result = main(data_path, test_path)
    if result is not None:
        model, class_names = result
        print_model_parameters(model)

# --- Unrelated site residue from the original source page (kept as a comment) ---
# 相关推荐
# 杨小扩: OpenAI Codex CLI 命令行参考笔记 (人工智能·笔记)
# 做cv的小昊: 大语言模型系统:【CMU 11-868】课程学习笔记06——Transformer学习 (笔记·学习·语言模型)
# Vae_Mars: 华睿MVP:C#脚本的应用一 (笔记·c#)
# _muffinman: Java学习笔记-第2章 运算和语句 (java·笔记·学习)
# 六元七角八分: 学习笔记一《JavaScript基础语法》 (javascript·笔记·学习)
# 风酥糖: 在Termux中运行Siyuan笔记服务 (android·linux·服务器·笔记)
# 跃龙客: C++写文件笔记 (c++·笔记)
# 宵时待雨: C++笔记归纳11:多态 (开发语言·c++·笔记)
# 李昊哲小课: NumPy 完整学习笔记 (笔记·python·学习·数据分析·numpy)
# Jasminee: SSH 服务攻防实战 (笔记·安全)