import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.metrics import confusion_matrix
import seaborn as sns
# Set the random seed for reproducibility
torch.manual_seed(42)
# Use the GPU when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# ============ 1. Data preparation ============
def get_custom_data(data_root, test_root, batch_size=32, img_size=(64, 64)):
    """Load custom image datasets laid out one-subdirectory-per-class.

    Args:
        data_root: training data root directory, containing n class subdirectories.
        test_root: test data root directory with the same layout.
        batch_size: mini-batch size used by both loaders.
        img_size: (height, width) every image is resized to.

    Returns:
        (train_loader, test_loader, num_classes, class_names)
    """
    # Preprocessing: resize, convert to tensor, normalize each RGB channel to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Training and test sets come from separate directory trees,
    # so no random_split is needed here.
    dataset = datasets.ImageFolder(root=data_root, transform=transform)
    testset = datasets.ImageFolder(root=test_root, transform=transform)
    # Data loaders: shuffle only the training set
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    # Class info is derived from the training set's subdirectory names
    class_names = dataset.classes
    num_classes = len(class_names)
    print(f"数据集信息:")
    print(f" 训练集: {len(dataset)}")
    print(f" 测试集: {len(testset)}")
    print(f" 类别数: {num_classes}")
    print(f" 类别名称: {class_names}")
    return train_loader, test_loader, num_classes, class_names
# ============ 2. Modified SimpleCNN model ============
class CustomSimpleCNN(nn.Module):
    """
    Modified SimpleCNN for custom image classification.
    Input:  [batch, 3, height, width] (RGB images)
    Output: [batch, num_classes]
    """
    def __init__(self, num_classes, input_channels=3):
        # NOTE: the constructor must be named __init__ (was `init`, which
        # left the module without any layers).
        super(CustomSimpleCNN, self).__init__()
        # Conv block 1: 3 -> 32 channels, spatial size halved by pooling
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Conv block 2: 32 -> 64 channels, spatial size halved again
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Conv block 3: 64 -> 128 channels, spatial size halved again
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Adaptive pooling fixes the feature map at 4x4 so the classifier
        # works for any input image size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
        # Classifier head: flatten -> 256 hidden units -> num_classes logits
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),  # 128 * 4 * 4 = 2048
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Run the three conv blocks, pool to a fixed size, then classify."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.adaptive_pool(x)  # unify spatial size to 4x4
        x = self.fc(x)
        return x
# ============ 3. Training function ============
def train_model(model, train_loader, test_loader, epochs=10, lr=0.001):
    """Train the CNN with cross-entropy loss and Adam.

    Args:
        model: the network to train (already moved to `device`).
        train_loader: DataLoader yielding (images, labels) for training.
        test_loader: DataLoader used for per-epoch evaluation.
        epochs: number of passes over the training set.
        lr: Adam learning rate.

    Returns:
        (train_losses, test_accuracies): per-epoch average loss and test accuracy (%).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_losses = []
    test_accuracies = []
    print(f"\n开始训练 {epochs} 个epoch...")
    for epoch in range(epochs):
        # Training mode (enables dropout)
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            # Zero gradients, forward, backward, step
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Running training accuracy
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            if batch_idx % 10 == 0:
                print(f'Epoch [{epoch+1}/{epochs}] '
                      f'Batch [{batch_idx}/{len(train_loader)}] '
                      f'Loss: {loss.item():.4f}')
        # Per-epoch averages
        avg_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total
        train_losses.append(avg_loss)
        # Evaluate on the held-out test set after each epoch
        test_accuracy = evaluate_model(model, test_loader)
        test_accuracies.append(test_accuracy)
        print(f'Epoch [{epoch+1}/{epochs}] 完成')
        print(f' 训练损失: {avg_loss:.4f}, 训练准确率: {train_accuracy:.2f}%')
        print(f' 测试准确率: {test_accuracy:.2f}%')
    return train_losses, test_accuracies
def evaluate_model(model, test_loader):
    """Return the model's accuracy (%) over `test_loader`.

    Runs in eval mode with gradients disabled; uses the module-level `device`.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # argmax over class logits
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total
# ============ 4. Visualize results ============
def visualize_results(model, test_loader, class_names, train_losses, test_accuracies):
    """Plot training curves, sample predictions, and a confusion matrix.

    Saves the figure to ./custom_cnn_results.png and shows it.

    Args:
        model: trained network (on `device`).
        test_loader: DataLoader for the test set.
        class_names: index -> class-name list from ImageFolder.
        train_losses: per-epoch training losses from train_model.
        test_accuracies: per-epoch test accuracies from train_model.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    # 1. Training loss curve
    ax1 = axes[0, 0]
    ax1.plot(train_losses, 'b-o', linewidth=2, markersize=6)
    ax1.set_title('训练损失曲线', fontsize=12, fontweight='bold')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True, alpha=0.3)
    # 2. Test accuracy curve
    ax2 = axes[0, 1]
    ax2.plot(test_accuracies, 'r-s', linewidth=2, markersize=6)
    ax2.set_title('测试准确率曲线', fontsize=12, fontweight='bold')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.grid(True, alpha=0.3)
    # 3. Show a batch of test samples with predictions
    # NOTE(review): the plt.subplot(3, 4, ...) calls below draw over the
    # existing 2x2 grid rather than into axes[1, 0] — confirm this layout
    # is intended.
    ax3 = axes[1, 0]
    data_iter = iter(test_loader)
    images, labels = next(data_iter)
    # Predict on the batch
    model.eval()
    with torch.no_grad():
        outputs = model(images.to(device))
        predictions = torch.max(outputs, 1)[1].cpu().numpy()
    # Display up to 12 samples
    num_samples = min(12, len(images))
    for i in range(num_samples):
        ax = plt.subplot(3, 4, i + 1)
        # Un-normalize for display (inverse of Normalize(0.5, 0.5))
        img = images[i].numpy().transpose((1, 2, 0))
        img = img * 0.5 + 0.5
        plt.imshow(img)
        true_label = class_names[labels[i]]
        pred_label = class_names[predictions[i]]
        color = 'green' if predictions[i] == labels[i] else 'red'
        plt.title(f'True: {true_label}\nPred: {pred_label}', color=color, fontsize=8)
        plt.axis('off')
    plt.suptitle('预测结果 (绿色=正确, 红色=错误)', fontsize=12, fontweight='bold')
    # 4. Confusion matrix over the whole test set
    ax4 = axes[1, 1]
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            output = model(data)
            preds = torch.max(output, 1)[1].cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(target.numpy())
    cm = confusion_matrix(all_labels, all_preds)
    # Heatmap via seaborn
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax4)
    ax4.set_title('混淆矩阵', fontsize=12, fontweight='bold')
    ax4.set_xlabel('预测标签')
    ax4.set_ylabel('真实标签')
    plt.setp(ax4.get_xticklabels(), rotation=45)
    plt.setp(ax4.get_yticklabels(), rotation=0)
    plt.tight_layout()
    plt.savefig('./custom_cnn_results.png', dpi=150, bbox_inches='tight')
    plt.show()
# ============ 5. Print model parameters ============
def print_model_parameters(model):
    """Print name, shape, count, dtype, and device of every trainable parameter."""
    print("=" * 60)
    print("模型参数详情:")
    print("=" * 60)
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"参数名: {name}")
            print(f"形状: {param.shape}")
            print(f"参数量: {param.numel():,}")
            print(f"数据类型: {param.dtype}")
            print(f"设备: {param.device}")
            print("-" * 40)
            total_params += param.numel()
    print(f"总可训练参数量: {total_params:,}")
    print("=" * 60)
# ============ 6. Main program ============
def main(data_path="./mnist_train_sorted", test_path="./mnist_test_sorted"):
    """Run the full pipeline: load data, build, train, visualize, and save.

    Args:
        data_path: training data root (one subdirectory per class).
        test_path: test data root with the same layout.

    Returns:
        (model, class_names) on success, or None when data_path is missing.
    """
    # Verify the data directory exists before doing anything expensive
    if not os.path.exists(data_path):
        print(f"❌ 错误: 数据路径 '{data_path}' 不存在")
        print("请创建目录结构:")
        print(f"{data_path}/")
        print("├── 类别1/")
        print("│ ├── image1.jpg")
        print("│ └── image2.jpg")
        print("├── 类别2/")
        print("└── 类别3/")
        return None
    # Load the data
    print("加载自定义图像数据...")
    train_loader, test_loader, num_classes, class_names = get_custom_data(
        data_root=data_path,
        test_root=test_path,
        batch_size=32,
        img_size=(64, 64)  # image size is adjustable
    )
    # Build the model
    print(f"\n创建CustomSimpleCNN模型,类别数: {num_classes}")
    model = CustomSimpleCNN(num_classes=num_classes, input_channels=3).to(device)
    print(model)
    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n总参数量: {total_params:,}")
    # Train
    train_losses, test_accuracies = train_model(
        model, train_loader, test_loader,
        epochs=15, lr=0.001  # epochs is adjustable
    )
    # Visualize
    print("\n生成可视化结果...")
    visualize_results(model, test_loader, class_names, train_losses, test_accuracies)
    # Save checkpoint with class metadata for later inference
    torch.save({
        'model_state_dict': model.state_dict(),
        'class_names': class_names,
        'num_classes': num_classes
    }, './custom_cnn_model.pth')
    print("\n模型已保存!")
    return model, class_names
# Run as a script
if __name__ == "__main__":
    # Replace with your own data paths
    data_path = "./mnist_train_sorted"  # change to your training data directory
    test_path = "./mnist_test_sorted"   # change to your test data directory
    model, class_names = main(data_path, test_path)
    if model is not None:
        print_model_parameters(model)