"""MNIST CNN — train a simple convolutional network on folder-organized image data."""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Fix the random seed for reproducibility.
torch.manual_seed(42)

# Train on GPU when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")


# ============ 1. Data preparation ============

def get_custom_data(data_root, test_root, batch_size=32, img_size=(64, 64)):
    """Load custom image datasets organized as one sub-directory per class.

    Args:
        data_root: root of the training data; contains one sub-directory
            per class (torchvision ``ImageFolder`` layout).
        test_root: root of the test data, same layout as ``data_root``.
        batch_size: mini-batch size for both loaders.
        img_size: (height, width) every image is resized to.

    Returns:
        (train_loader, test_loader, num_classes, class_names)
    """
    # Preprocessing: resize, convert to tensor, then normalize each RGB
    # channel with mean/std 0.5 (maps pixel values to roughly [-1, 1]).
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # ImageFolder infers class labels from sub-directory names.
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
    testset = datasets.ImageFolder(root=test_root, transform=transform)

    # Separate train/test directories are used, so no random_split is needed.
    # Shuffle only the training loader.
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Class metadata comes from the training set's directory names.
    class_names = dataset.classes
    num_classes = len(class_names)

    print(f"数据集信息:")
    print(f" 训练集: {len(dataset)}")
    print(f" 测试集: {len(testset)}")
    print(f" 类别数: {num_classes}")
    print(f" 类别名称: {class_names}")

    return train_loader, test_loader, num_classes, class_names

# ============ 2. Modified SimpleCNN model ============

class CustomSimpleCNN(nn.Module):
    """SimpleCNN variant adapted for custom RGB image classification.

    Input:  [batch, input_channels, height, width]
    Output: [batch, num_classes] (raw logits; pair with CrossEntropyLoss)
    """

    def __init__(self, num_classes, input_channels=3):
        # NOTE: the original pasted code had `def init` / `super().init()`,
        # which never runs as a constructor; restored to the dunder names.
        super(CustomSimpleCNN, self).__init__()

        # Conv block 1: input_channels -> 32, spatial size halved.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Conv block 2: 32 -> 64, spatial size halved again.
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Conv block 3: 64 -> 128, spatial size halved again.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Adaptive pooling fixes the feature-map size at 4x4, so the
        # classifier works for any input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))

        # Classifier head: 128 * 4 * 4 = 2048 features -> num_classes logits.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.adaptive_pool(x)  # unify spatial size regardless of input
        x = self.fc(x)
        return x

# ============ 3. Training ============

def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    """Train the CNN and evaluate on the test set after every epoch.

    Args:
        model: network to train (expected to already be on `device`).
        train_loader: DataLoader over the training set.
        test_loader: DataLoader over the test set.
        epochs: number of passes over the training data.
        lr: Adam learning rate.

    Returns:
        (train_losses, test_accuracies): per-epoch average training loss
        and test accuracy in percent.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    test_accuracies = []

    print(f"\n开始训练 {epochs} 个epoch...")
    for epoch in range(epochs):
        # --- Training phase ---
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()              # clear gradients from last step
            output = model(data)               # forward pass
            loss = criterion(output, target)
            loss.backward()                    # backward pass
            optimizer.step()

            running_loss += loss.item()

            # Running training accuracy.
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}] '
                      f'Batch [{batch_idx}/{len(train_loader)}] '
                      f'Loss: {loss.item():.4f}')

        # Per-epoch statistics.
        avg_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total
        train_losses.append(avg_loss)

        # --- Evaluation phase ---
        test_accuracy = evaluate_model(model, test_loader)
        test_accuracies.append(test_accuracy)

        print(f'Epoch [{epoch+1}/{epochs}] 完成')
        print(f' 训练损失: {avg_loss:.4f}, 训练准确率: {train_accuracy:.2f}%')
        print(f' 测试准确率: {test_accuracy:.2f}%')

    return train_losses, test_accuracies

def evaluate_model(model, test_loader):
    """Return the model's top-1 accuracy (in percent) over test_loader."""
    model.eval()
    correct, total = 0, 0
    # No gradients needed for evaluation.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total

# ============ 4. Visualization ============

def visualize_results(model, test_loader, class_names, train_losses, test_accuracies):
    """Plot training curves, sample predictions, and a confusion matrix.

    Saves the combined figure to ./custom_cnn_results.png and shows it.

    Args:
        model: trained network (on `device`).
        test_loader: DataLoader over the test set.
        class_names: index -> human-readable class name.
        train_losses: per-epoch training losses from train_model().
        test_accuracies: per-epoch test accuracies from train_model().
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Training-loss curve.
    ax1 = axes[0, 0]
    ax1.plot(train_losses, 'b-o', linewidth=2, markersize=6)
    ax1.set_title('训练损失曲线', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True, alpha=0.3)

    # 2. Test-accuracy curve.
    ax2 = axes[0, 1]
    ax2.plot(test_accuracies, 'r-s', linewidth=2, markersize=6)
    ax2.set_title('测试准确率曲线', fontsize=12, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.grid(True, alpha=0.3)

    # 3. Sample test images with predicted vs. true labels.
    # NOTE(review): plt.subplot(3, 4, ...) below re-grids the CURRENT figure,
    # drawing the 12 thumbnails over the 2x2 axes created above; ax3 is
    # effectively unused. Consider plt.figure() here to give the samples
    # their own figure — kept as-is to preserve the original output.
    ax3 = axes[1, 0]
    data_iter = iter(test_loader)
    images, labels = next(data_iter)

    model.eval()
    with torch.no_grad():
        outputs = model(images.to(device))
        predictions = torch.max(outputs, 1)[1].cpu().numpy()

    # Show up to 12 samples.
    num_samples = min(12, len(images))
    for i in range(num_samples):
        ax = plt.subplot(3, 4, i + 1)
        # Undo Normalize((0.5,...), (0.5,...)) so colors display correctly.
        img = images[i].numpy().transpose((1, 2, 0))
        img = img * 0.5 + 0.5
        plt.imshow(img)
        true_label = class_names[labels[i]]
        pred_label = class_names[predictions[i]]
        color = 'green' if predictions[i] == labels[i] else 'red'
        plt.title(f'True: {true_label}\nPred: {pred_label}', color=color, fontsize=8)
        plt.axis('off')
    plt.suptitle('预测结果 (绿色=正确, 红色=错误)', fontsize=12, fontweight='bold')

    # 4. Confusion matrix over the entire test set.
    ax4 = axes[1, 1]
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            output = model(data)
            preds = torch.max(output, 1)[1].cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(target.numpy())

    cm = confusion_matrix(all_labels, all_preds)
    # Heatmap: rows = true labels, columns = predicted labels.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax4)
    ax4.set_title('混淆矩阵', fontsize=12, fontweight='bold')
    ax4.set_xlabel('预测标签')
    ax4.set_ylabel('真实标签')
    plt.setp(ax4.get_xticklabels(), rotation=45)
    plt.setp(ax4.get_yticklabels(), rotation=0)

    plt.tight_layout()
    plt.savefig('./custom_cnn_results.png', dpi=150, bbox_inches='tight')
    plt.show()

# ============ 5. Parameter report ============

def print_model_parameters(model):
    """Print name, shape, count, dtype, and device for every trainable parameter."""
    print("=" * 60)
    print("模型参数详情:")
    print("=" * 60)

    trainable_total = 0
    for pname, tensor in model.named_parameters():
        # Frozen parameters are excluded from the report and the total.
        if not tensor.requires_grad:
            continue
        count = tensor.numel()
        trainable_total += count
        print(f"参数名: {pname}")
        print(f"形状: {tensor.shape}")
        print(f"参数量: {count:,}")
        print(f"数据类型: {tensor.dtype}")
        print(f"设备: {tensor.device}")
        print("-" * 40)

    print(f"总可训练参数量: {trainable_total:,}")
    print("=" * 60)

# ============ 6. Main program ============

def main(data_path="./mnist_train_sorted", test_path="./mnist_test_sorted"):
    """Run the full pipeline: load data, build, train, visualize, save.

    Args:
        data_path: training-data root (one sub-directory per class).
        test_path: test-data root, same layout.

    Returns:
        (model, class_names) on success, (None, None) if data_path is missing.
    """
    # Validate the training-data path before doing any work.
    if not os.path.exists(data_path):
        print(f"❌ 错误: 数据路径 '{data_path}' 不存在")
        print("请创建目录结构:")
        print(f"{data_path}/")
        print("├── 类别1/")
        print("│ ├── image1.jpg")
        print("│ └── image2.jpg")
        print("├── 类别2/")
        print("└── 类别3/")
        # Return a 2-tuple so callers that unpack `model, class_names`
        # do not crash (a bare None here raised TypeError at the call site).
        return None, None

    # Load the data.
    print("加载自定义图像数据...")
    train_loader, test_loader, num_classes, class_names = get_custom_data(
        data_root=data_path,
        test_root=test_path,
        batch_size=32,
        img_size=(64, 64),  # adjust to trade speed vs. detail
    )

    # Build the model.
    print(f"\n创建CustomSimpleCNN模型,类别数: {num_classes}")
    model = CustomSimpleCNN(num_classes=num_classes, input_channels=3).to(device)
    print(model)

    # Total parameter count (trainable and frozen).
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n总参数量: {total_params:,}")

    # Train.
    train_losses, test_accuracies = train_model(
        model, train_loader, test_loader,
        epochs=15, lr=0.001,  # adjust epochs/lr as needed
    )

    # Visualize training curves and predictions.
    print("\n生成可视化结果...")
    visualize_results(model, test_loader, class_names, train_losses, test_accuracies)

    # Save weights together with label metadata so the model can be reloaded
    # without re-reading the dataset directory.
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_names': class_names,
        'num_classes': num_classes
    }, './custom_cnn_model.pth')
    print("\n模型已保存!")

    return model, class_names

# Script entry point.
if __name__ == "__main__":
    # Replace these with your own data directories.
    data_path = "./mnist_train_sorted"
    test_path = "./mnist_test_sorted"

    result = main(data_path, test_path)
    # main() may signal failure (missing data directory); unpack defensively
    # so this works whether it returns None or a (model, class_names) tuple.
    if result is not None:
        model, class_names = result
        if model is not None:
            print_model_parameters(model)

# (Removed: unrelated blog-platform "related posts" footer that was captured
# during copy-paste; it was not part of this program.)