# MNIST CNN

import torch

import torch.nn as nn

import torch.optim as optim

import torchvision

from torchvision import transforms, datasets

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

import numpy as np

import os

from sklearn.metrics import confusion_matrix

import seaborn as sns

# Seed PyTorch's global random number generator so runs are repeatable.
torch.manual_seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# ============ 1. Data preparation ============

def get_custom_data(data_root, test_root, batch_size=32, img_size=(64, 64)):
    """Build training and test data loaders from two image-folder trees.

    Args:
        data_root: Root directory of the training images, containing one
            sub-directory per class.
        test_root: Root directory of the test images.
        batch_size: Number of samples per batch, used for both loaders.
        img_size: Target size handed to ``transforms.Resize``.

    Returns:
        A tuple ``(train_loader, test_loader, num_classes, class_names)``.
        ``class_names`` and ``num_classes`` are taken from the training
        dataset only.

    NOTE(review): the test dataset derives its own class indices from the
    sub-directories of ``test_root``; this presumably relies on both roots
    having identical class sub-directories -- confirm against the data.
    """
    # Preprocessing shared by both datasets.
    transform = transforms.Compose([
        transforms.Resize(img_size),  # resize every image
        transforms.ToTensor(),  # convert to a tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize
    ])

    # Training and test data live in separate directory trees, so no
    # random split is performed.
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
    testset = datasets.ImageFolder(root=test_root, transform=transform)

    # Only the training loader is shuffled.
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Class information comes from the training dataset.
    class_names = dataset.classes
    num_classes = len(class_names)

    print(f"数据集信息:")
    print(f" 训练集: {len(dataset)}")
    print(f" 测试集: {len(testset)}")
    print(f" 类别数: {num_classes}")
    print(f" 类别名称: {class_names}")

    return train_loader, test_loader, num_classes, class_names

# ============ 2. Modified SimpleCNN model ============

class CustomSimpleCNN(nn.Module):
    """Small three-stage CNN for custom image classification.

    Input:  ``[batch, input_channels, height, width]`` (RGB by default).
    Output: ``[batch, num_classes]`` unnormalized class scores.
    """

    def __init__(self, num_classes, input_channels=3):
        """Create the layers.

        Args:
            num_classes: Size of the final output layer.
            input_channels: Number of channels in the input images.
        """
        super().__init__()

        # Convolution stage 1.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # halves the spatial size
        )

        # Convolution stage 2.
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # halves the spatial size again
        )

        # Convolution stage 3.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # halves the spatial size again
        )

        # Adaptive average pooling to a fixed 4x4 map, so the classifier
        # input size does not depend on the input image size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))

        # Fully connected classifier head.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),  # 128 * 4 * 4 = 2048
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return the class scores for the image batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.adaptive_pool(x)  # unify the feature-map size
        x = self.fc(x)
        return x

# ============ 3. Training functions ============

def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    """Train ``model`` with Adam and cross-entropy loss.

    After every epoch the model is evaluated on ``test_loader`` via
    ``evaluate_model`` and a summary is printed.

    Args:
        model: Network to train; batches are moved to the module-level
            ``device`` before being passed to it.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Iterable of ``(data, target)`` batches for evaluation.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate for the Adam optimizer.

    Returns:
        A tuple ``(train_losses, test_accuracies)`` with one entry per
        epoch: the mean batch loss and the test accuracy in percent.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    test_accuracies = []

    print(f"\n开始训练 {epochs} 个epoch...")

    for epoch in range(epochs):
        # evaluate_model() switches to eval mode at the end of each epoch,
        # so training mode has to be re-enabled here every time.
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            # Forward pass.
            output = model(data)
            loss = criterion(output, target)

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Running training accuracy.
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            if batch_idx % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}] '
                      f'Batch [{batch_idx}/{len(train_loader)}] '
                      f'Loss: {loss.item():.4f}')

        # Mean batch loss and accuracy for this epoch.
        avg_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total
        train_losses.append(avg_loss)

        # Evaluate on the test data.
        test_accuracy = evaluate_model(model, test_loader)
        test_accuracies.append(test_accuracy)

        print(f'Epoch [{epoch+1}/{epochs}] 完成')
        print(f' 训练损失: {avg_loss:.4f}, 训练准确率: {train_accuracy:.2f}%')
        print(f' 测试准确率: {test_accuracy:.2f}%')

    return train_losses, test_accuracies

def evaluate_model(model, test_loader):
    """Return the classification accuracy of ``model`` in percent.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        test_loader: Iterable of ``(data, target)`` batches. Batches are
            moved to the module-level ``device``.

    Returns:
        ``100 * correct / total`` over all samples, or ``0.0`` when the
        loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return 100 * correct / total

# ============ 4. Result visualization ============

def visualize_results(model, test_loader, class_names, train_losses, test_accuracies):
    """Plot training curves, sample predictions and a confusion matrix.

    The 2x2 figure is saved to ``./custom_cnn_results.png`` and then shown.

    Args:
        model: Trained network; inputs are moved to the module-level
            ``device`` before being passed to it.
        test_loader: Iterable of ``(data, target)`` test batches.
        class_names: Class labels, indexed by class id.
        train_losses: Per-epoch training losses.
        test_accuracies: Per-epoch test accuracies in percent.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Training-loss curve.
    ax1 = axes[0, 0]
    ax1.plot(train_losses, 'b-o', linewidth=2, markersize=6)
    ax1.set_title('训练损失曲线', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True, alpha=0.3)

    # 2. Test-accuracy curve.
    ax2 = axes[0, 1]
    ax2.plot(test_accuracies, 'r-s', linewidth=2, markersize=6)
    ax2.set_title('测试准确率曲线', fontsize=12, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.grid(True, alpha=0.3)

    # 3. Sample predictions from the first test batch.
    ax3 = axes[1, 0]
    images, labels = next(iter(test_loader))

    model.eval()
    with torch.no_grad():
        outputs = model(images.to(device))
    predictions = torch.max(outputs, 1)[1].cpu().numpy()
    true_labels = labels.cpu().numpy()

    # The thumbnails go into a 3x4 grid nested inside this panel's slot.
    # A figure-wide plt.subplot(3, 4, ...) grid would be drawn on top of
    # the other three panels.
    ax3.axis('off')
    ax3.set_title('预测结果 (绿色=正确, 红色=错误)', fontsize=12,
                  fontweight='bold', pad=30)
    sample_grid = ax3.get_subplotspec().subgridspec(3, 4)

    # Show up to 12 samples.
    num_samples = min(12, len(images))
    for i in range(num_samples):
        ax = fig.add_subplot(sample_grid[i // 4, i % 4])

        # Undo the (0.5, 0.5) normalization for display.
        img = images[i].cpu().numpy().transpose((1, 2, 0))
        img = img * 0.5 + 0.5
        ax.imshow(img)

        true_label = class_names[true_labels[i]]
        pred_label = class_names[predictions[i]]
        color = 'green' if predictions[i] == true_labels[i] else 'red'
        ax.set_title(f'True: {true_label}\nPred: {pred_label}', color=color, fontsize=8)
        ax.axis('off')

    # 4. Confusion matrix over the whole test set.
    ax4 = axes[1, 1]
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            output = model(data)
            preds = torch.max(output, 1)[1].cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(target.numpy())

    # Pin the label set so the matrix always has one row/column per entry
    # of class_names, even if a class never occurs in the test data.
    cm = confusion_matrix(all_labels, all_preds,
                          labels=list(range(len(class_names))))

    # Draw the matrix as a seaborn heatmap.
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax4)
    ax4.set_title('混淆矩阵', fontsize=12, fontweight='bold')
    ax4.set_xlabel('预测标签')
    ax4.set_ylabel('真实标签')
    plt.setp(ax4.get_xticklabels(), rotation=45)
    plt.setp(ax4.get_yticklabels(), rotation=0)

    plt.tight_layout()
    plt.savefig('./custom_cnn_results.png', dpi=150, bbox_inches='tight')
    plt.show()

# ============ 5. Print model parameters ============

def print_model_parameters(model):
    """Print details of every trainable parameter of ``model``.

    For each parameter with ``requires_grad`` set, the name, shape, element
    count, dtype and device are printed, followed by the total number of
    trainable elements.

    Args:
        model: Object providing ``named_parameters()``.
    """
    print("=" * 60)
    print("模型参数详情:")
    print("=" * 60)

    total_params = 0
    for name, param in model.named_parameters():
        # Frozen parameters are neither listed nor counted.
        if param.requires_grad:
            print(f"参数名: {name}")
            print(f"形状: {param.shape}")
            print(f"参数量: {param.numel():,}")
            print(f"数据类型: {param.dtype}")
            print(f"设备: {param.device}")
            print("-" * 40)
            total_params += param.numel()

    print(f"总可训练参数量: {total_params:,}")
    print("=" * 60)

# ============ 6. Main program ============

def main(data_path="./mnist_train_sorted", test_path="./mnist_test_sorted"):
    """Run the full pipeline: load data, train, visualize and save.

    Args:
        data_path: Root directory of the training images.
        test_path: Root directory of the test images.

    Returns:
        ``(model, class_names)`` on success. If either directory does not
        exist, a hint is printed and ``(None, None)`` is returned, so that
        callers unpacking two values keep working.
    """
    # Check both data directories before doing any work.
    for path in (data_path, test_path):
        if not os.path.exists(path):
            print(f"❌ 错误: 数据路径 '{path}' 不存在")
            print("请创建目录结构:")
            print(f"{path}/")
            print("├── 类别1/")
            print("│ ├── image1.jpg")
            print("│ └── image2.jpg")
            print("├── 类别2/")
            print("└── 类别3/")
            return None, None

    # Load the data.
    print("加载自定义图像数据...")
    train_loader, test_loader, num_classes, class_names = get_custom_data(
        data_root=data_path,
        test_root=test_path,
        batch_size=32,
        img_size=(64, 64),  # the image size can be adjusted here
    )

    # Create the model.
    print(f"\n创建CustomSimpleCNN模型,类别数: {num_classes}")
    model = CustomSimpleCNN(num_classes=num_classes, input_channels=3).to(device)
    print(model)

    # Count all parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n总参数量: {total_params:,}")

    # Train.
    train_losses, test_accuracies = train_model(
        model, train_loader, test_loader,
        epochs=15, lr=0.001,  # the number of epochs can be adjusted here
    )

    # Visualize.
    print("\n生成可视化结果...")
    visualize_results(model, test_loader, class_names, train_losses, test_accuracies)

    # Save the weights together with the class metadata.
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_names': class_names,
        'num_classes': num_classes
    }, './custom_cnn_model.pth')
    print("\n模型已保存!")

    return model, class_names

# Script entry point.
if __name__ == "__main__":
    # Replace these with the locations of your own data.
    data_path = "./mnist_train_sorted"  # training-data directory
    test_path = "./mnist_test_sorted"  # test-data directory

    # main() returns early without a model when a data directory is
    # missing; guard against a bare None so the unpacking cannot fail.
    result = main(data_path, test_path)
    model, class_names = result if result is not None else (None, None)

    if model is not None:
        print_model_parameters(model)

相关推荐
张同学031 天前
220V 转 12V/5V 电源输入电路设计笔记
笔记·嵌入式硬件·硬件工程
深蓝海拓1 天前
S7-1500PLC学习笔记:MOVE_BLK、MOVE_BLK_VARIANT、BLKMOV的区别
笔记·学习·plc
雨浓YN1 天前
OPC UA 通讯开发笔记 - 基于本地dll文件
windows·笔记
深蓝海拓1 天前
S7-1500学习笔记:用户自定义数据类型(UDT)
笔记·学习·plc
罗罗攀1 天前
PyTorch学习笔记|神经网络的损失函数
人工智能·pytorch·笔记·神经网络·学习
tq10861 天前
价值:社会对劳动所产生的效用增量形成的局部共识
笔记
A923A1 天前
【小兔鲜电商前台 | 项目笔记】第八天
前端·vue.js·笔记·项目·小兔鲜
猹叉叉(学习版)1 天前
【系统分析师_知识点整理】 15.数学计算与知识产权
笔记·软考·知识产权·系统分析师
風清掦1 天前
【江科大STM32学习笔记-10】I2C通信协议 - 10.1 软件I2C读写MPU6050
笔记·stm32·单片机·嵌入式硬件·物联网·学习
MwEUwQ3Gx1 天前
常见Linux权限提升笔记
linux·运维·笔记