Saving and Reusing a Trained Convolutional Neural Network Model

In real-world applications, one important prerequisite for reaching a sufficiently high training accuracy is a large enough dataset. But when the dataset is very large, or when the model has to be deployed in the field, we cannot afford to spend a long time retraining on site. Instead, we need to save the trained model and load it again later.

Saving the model:

python
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.models import resnet18, ResNet18_Weights
import matplotlib.pyplot as plt
from PIL import Image
import warnings

warnings.filterwarnings('ignore')

# ====================== 1. Configuration (adjust as needed) ======================
DATA_ROOT = r"E:\pycharm\Py_Projects\food_dataset2\food_dataset2"
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = (224, 224)
MODEL_SAVE_PATH = "food_classification_model.pth"

# ====================== 2. Data preprocessing and loading ======================
train_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


# Custom dataset class, fixed to follow the ImageFolder interface
class FoodDataset(datasets.ImageFolder):
    # Return only classes and class_to_idx (as required by the parent class)
    def find_classes(self, directory):
        classes = []
        # Walk the directory tree and collect the bare food names
        for root, dirs, files in os.walk(directory, topdown=False):
            for dir_name in dirs:
                food_name = dir_name.split("_")[-1]
                if food_name not in classes:
                    classes.append(food_name)
        # Sort so that label indices stay stable
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    # Override make_dataset to build the custom sample-path/label mapping
    def make_dataset(self, directory, class_to_idx, extensions=None, is_valid_file=None):
        samples = []
        directory = os.path.expanduser(directory)
        # Supported image formats
        valid_extensions = ('.png', '.jpg', '.jpeg') if extensions is None else extensions

        for root, dirs, files in os.walk(directory):
            for file in files:
                # Keep only valid image files
                if file.lower().endswith(valid_extensions):
                    # Use the food name as the label
                    dir_name = os.path.basename(root)
                    food_name = dir_name.split("_")[-1]
                    if food_name in class_to_idx:
                        path = os.path.join(root, file)
                        item = (path, class_to_idx[food_name])
                        samples.append(item)
        return samples


# Load the training and test sets
print("[Stage 1: Load dataset] Parsing the dataset...")
train_dataset = FoodDataset(root=os.path.join(DATA_ROOT, "train"), transform=train_transforms)
test_dataset = FoodDataset(root=os.path.join(DATA_ROOT, "test"), transform=test_transforms)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Print dataset information
num_classes = len(train_dataset.classes)
print("✅ Dataset loaded!")
print(f"   - Training samples: {len(train_dataset)}")
print(f"   - Test samples: {len(test_dataset)}")
print(f"   - Number of food classes: {num_classes}")
print(f"   - Class list: {train_dataset.classes}")
print(f"   - Device: {DEVICE}")

# ====================== Remaining code (stages 3-8) unchanged ======================
# 3. Build the CNN model
print("\n[Stage 2: Build model] Initializing the ResNet18 model...")
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)


# 4. Training and validation functions
def train_one_epoch(model, loader, criterion, optimizer, epoch):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (images, labels) in enumerate(loader):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if batch_idx % 10 == 0:
            batch_acc = 100 * correct / total
            batch_loss = total_loss / (batch_idx + 1)
            print(
                f"   Epoch [{epoch + 1}/{EPOCHS}] | Batch [{batch_idx}/{len(loader)}] | Batch loss: {batch_loss:.4f} | Batch accuracy: {batch_acc:.2f}%")

    epoch_loss = total_loss / len(loader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc


def validate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = total_loss / len(loader)
    val_acc = 100 * correct / total
    return val_loss, val_acc


# 5. Start training
print("\n[Stage 3: Training]")
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []

best_val_acc = 0.0

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
    val_loss, val_acc = validate(model, test_loader, criterion)
    scheduler.step()

    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)

    print(f"\n✅ 训练轮次[{epoch + 1}/{EPOCHS}] 完成!")
    print(f"   - 训练损失:{train_loss:.4f} | 训练准确率:{train_acc:.2f}%")
    print(f"   - 验证损失:{val_loss:.4f} | 验证准确率:{val_acc:.2f}%")
    print(f"   - 当前学习率:{scheduler.get_last_lr()[0]:.6f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_acc': best_val_acc,
            'class_to_idx': train_dataset.class_to_idx,
            'classes': train_dataset.classes
        }, MODEL_SAVE_PATH)
        print(f"   🎉 最优模型已保存!当前最高验证准确率:{best_val_acc:.2f}%")
    print("-" * 80)

# 6. Visualize results
print("\n[Stage 4: Training complete] Plotting the training curves...")
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(train_loss_history, label='Training loss', color='blue')
plt.plot(val_loss_history, label='Validation loss', color='red')
plt.title('Training/Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(train_acc_history, label='Training accuracy', color='blue')
plt.plot(val_acc_history, label='Validation accuracy', color='red')
plt.title('Training/Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig("training_result.png")
plt.show()

print(f"\n📊 训练总结:")
print(f"   - 最终训练准确率:{train_acc:.2f}%")
print(f"   - 最终验证准确率:{val_acc:.2f}%")
print(f"   - 最高验证准确率:{best_val_acc:.2f}%")
print(f"   - 模型已保存至:{MODEL_SAVE_PATH}")


# 7. Prediction function
def predict_food(image_path, model_path=MODEL_SAVE_PATH):
    print(f"\n[Stage 5: Prediction] Classifying image: {image_path}")

    checkpoint = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    classes = checkpoint['classes']
    model.eval()

    try:
        image = Image.open(image_path).convert('RGB')
        transform = test_transforms
        image = transform(image).unsqueeze(0)
        image = image.to(DEVICE)

        with torch.no_grad():
            outputs = model(image)
            _, predicted_idx = torch.max(outputs, 1)
            predicted_class = classes[predicted_idx.item()]

        print(f"✅ 识别结果:{predicted_class}")
        return predicted_class
    except Exception as e:
        print(f"❌ 预测失败!错误信息:{e}")
        return None


# 8. Test the prediction function
if __name__ == "__main__":
    test_image_path = r"E:\pycharm\Py_Projects\food_dataset2\food_dataset2\train\厨余垃圾_巴旦木\img_巴旦木_1.jpeg"
    predict_food(test_image_path)

Results of a 20-epoch run:

As the training-curve figure shows, the accuracy is quite respectable.

First, the filename for the saved model is defined in the configuration section:

python
MODEL_SAVE_PATH = "food_classification_model.pth"

Next, in the training loop that calls the training and validation functions:

python
for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
    val_loss, val_acc = validate(model, test_loader, criterion)
    scheduler.step()

    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)

    print(f"\n✅ 训练轮次[{epoch + 1}/{EPOCHS}] 完成!")
    print(f"   - 训练损失:{train_loss:.4f} | 训练准确率:{train_acc:.2f}%")
    print(f"   - 验证损失:{val_loss:.4f} | 验证准确率:{val_acc:.2f}%")
    print(f"   - 当前学习率:{scheduler.get_last_lr()[0]:.6f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_acc': best_val_acc,
            'class_to_idx': train_dataset.class_to_idx,
            'classes': train_dataset.classes
        }, MODEL_SAVE_PATH)
        print(f"   🎉 最优模型已保存!当前最高验证准确率:{best_val_acc:.2f}%")
    print("-" * 80)

python
if val_acc > best_val_acc:
    best_val_acc = val_acc

In this comparison, if the current epoch's validation accuracy is higher than the best validation accuracy recorded so far, the newly trained model replaces the previously saved one.

Next, we call

torch.save({contents}, path/filename)

to write the model checkpoint to disk.
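
As a minimal, self-contained sketch of this save/load round trip (the tiny demo model and the filename demo_checkpoint.pth are placeholders for illustration, not part of the project above):

python
import torch
import torch.nn as nn

# A tiny stand-in model, only to illustrate the checkpoint round trip
demo_model = nn.Linear(4, 2)
demo_optimizer = torch.optim.Adam(demo_model.parameters(), lr=1e-3)

# Save: bundle everything needed to restore or resume training into one dict
torch.save({
    'epoch': 0,
    'model_state_dict': demo_model.state_dict(),
    'optimizer_state_dict': demo_optimizer.state_dict(),
}, "demo_checkpoint.pth")

# Load: read the dict back and push the weights into a model of the same architecture
checkpoint = torch.load("demo_checkpoint.pth", map_location="cpu")
demo_model.load_state_dict(checkpoint['model_state_dict'])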

Contents of the checkpoint dictionary:

Every key-value pair in this dictionary has a clear purpose, and none of them should be omitted:

  • epoch — the epoch at which the best model was saved (useful for reproducing the run or resuming training, since it records when the model peaked)
  • model_state_dict — the model's weight parameters (the core of the model, which determines its predictive ability; model.state_dict() returns all learnable parameters)
  • optimizer_state_dict — the optimizer's state (learning rate, momentum, etc.), needed if training is to be resumed later
  • best_val_acc — the highest validation accuracy this model achieved (so its quality can be checked later without recomputation)
  • class_to_idx — the mapping from class names to label indices (e.g. {"巴旦木": 0, "苹果": 1}), which keeps labels consistent at prediction time
  • classes — the full class list (e.g. ["巴旦木", "苹果", "香蕉"]), used at prediction time to look up the class name from the predicted index

As the for loop runs (20 epochs in my training run), this if comparison is made again and again. Because the filename never changes, each new best checkpoint overwrites the old .pth file, until best_val_acc finally reflects the best result across all 20 epochs.
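
If you later want to continue training from this checkpoint rather than only run inference, a minimal sketch could look like the following (it assumes the model, optimizer, scheduler, data loaders, and MODEL_SAVE_PATH defined above):

python
# Resume training from the saved checkpoint (sketch; reuses objects defined in the training script)
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])          # restore the weights
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # restore learning rate, momentum, etc.
start_epoch = checkpoint['epoch'] + 1                          # continue counting from the saved epoch
best_val_acc = checkpoint['best_val_acc']                      # keep comparing against the saved best

for epoch in range(start_epoch, EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
    val_loss, val_acc = validate(model, test_loader, criterion)
    scheduler.step()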

Then, in the prediction function, model_state_dict and classes are used:

  • torch.load(): reads the .pth file and restores the saved dictionary;
  • model.load_state_dict(): loads the saved weights into the model, giving it its predictive ability;
  • classes: recovers the class list from the saved dictionary, so that predictions can be reported as class names rather than numeric indices.
python
# 7. Prediction function
def predict_food(image_path, model_path=MODEL_SAVE_PATH):
    print(f"\n[Stage 5: Prediction] Classifying image: {image_path}")

    checkpoint = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    classes = checkpoint['classes']
    model.eval()

    try:
        image = Image.open(image_path).convert('RGB')
        transform = test_transforms
        image = transform(image).unsqueeze(0)
        image = image.to(DEVICE)

        with torch.no_grad():
            outputs = model(image)
            _, predicted_idx = torch.max(outputs, 1)
            predicted_class = classes[predicted_idx.item()]

        print(f"✅ 识别结果:{predicted_class}")
        return predicted_class
    except Exception as e:
        print(f"❌ 预测失败!错误信息:{e}")
        return None
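
Note that predict_food above reuses the model object that already exists in the training script. To run inference from a completely separate script, the architecture has to be rebuilt before the weights are loaded; a minimal sketch, assuming the same ResNet18 setup and checkpoint file as above:

python
import torch
import torch.nn as nn
from torchvision.models import resnet18

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the same architecture, then load the saved checkpoint into it
checkpoint = torch.load("food_classification_model.pth", map_location=DEVICE)
classes = checkpoint['classes']

model = resnet18(weights=None)                             # pretrained weights are not needed here
model.fc = nn.Linear(model.fc.in_features, len(classes))   # same output size as at training time
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE)
model.eval()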

Reusing the model

Inside another, smaller model that uses the same number of classes (20):

python
# ====================== 2. Load the new dataset and check the number of classes ======================
# Modified preprocessing with extra data augmentation
transform = transforms.Compose([
    transforms.Resize((36, 36)),  # enlarge first, then crop randomly
    transforms.RandomCrop((32, 32)),  # random crop down to the training size
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # random brightness/contrast jitter
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Path to the new dataset
new_data_root = r"E:\pycharm\Py_Projects\pythonProjectlearning3.7.9\.venv\food_dataset"
train_dataset = datasets.ImageFolder(root=new_data_root + "/train", transform=transform)

# Print the key dataset information
print("=" * 50)
print("New dataset validation info:")
print(f"Actual number of classes: {len(train_dataset.classes)}")
print(f"Class names & label indices: {train_dataset.class_to_idx}")
class_names = train_dataset.classes  # keep the class names for labelling predictions later
print("=" * 50)

# Automatically match the number of classes
new_num_classes = len(train_dataset.classes)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)

# ====================== 3. Load the original model and adapt it to the new number of classes ======================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=new_num_classes).to(DEVICE)

# Load the original model's weights
checkpoint = torch.load("food_classification_model.pth", map_location=DEVICE)
model_dict = model.state_dict()
pretrained_dict = checkpoint['model_state_dict']

# Improved weight-filtering logic
pretrained_dict_filtered = {}
for k, v in pretrained_dict.items():
    if k in model_dict and model_dict[k].shape == v.shape:
        pretrained_dict_filtered[k] = v
    else:
        print(
            f"Skipping mismatched layer: {k} (original shape: {v.shape}, new model shape: {model_dict[k].shape if k in model_dict else 'missing'})")

model_dict.update(pretrained_dict_filtered)
model.load_state_dict(model_dict, strict=False)

First the new dataset is loaded, and then the weights of the previously trained model are loaded into the new network. Only layers whose names and shapes match are copied over; mismatched layers are skipped and keep their fresh initialization, which is why load_state_dict is called with strict=False.
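
The snippet refers to a SimpleCNN class defined elsewhere in that smaller project; its exact layers are not shown here. Purely for illustration, a minimal CNN for 32×32 RGB inputs might look like this (the layer sizes are assumptions, not the original definition):

python
import torch.nn as nn

class SimpleCNN(nn.Module):
    """Illustrative minimal CNN for 32x32 RGB inputs; layer sizes are assumed, not the original."""
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 32x32 -> 16x16
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 16x16 -> 8x8
        )
        self.classifier = nn.Linear(32 * 8 * 8, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)
        return self.classifier(x)

With an architecture this different from ResNet18, the filter above would skip almost every layer, so the closer the new model's layer names and shapes are to the saved one, the more weights actually transfer.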
