年龄预测识别模型训练 Python 代码

涵盖断点续训、早停机制、定期保存检查点等

  1. 检查点保存逻辑

检查点保存分为两种:

最新检查点:每轮(epoch)训练结束后都会保存为 latest.pth,用于恢复训练。

最佳模型:仅在验证损失达到新低时保存为 model_best.pth(只包含模型权重,不含优化器等训练状态)。

定期保存(每 checkpoint_interval 轮保存一份 epoch_N.pth)也确保了即使训练中断,也能恢复到最近的状态。

  2. 早停机制

当验证损失连续 early_stop_patience 轮未改善时,触发早停,避免过拟合或浪费计算资源。

这是一个非常实用的功能,特别是在超参数调试阶段。

  3. 命令行参数支持

使用 argparse 支持通过命令行指定恢复训练的检查点路径,提升了脚本的灵活性。

  4. CUDA基准模式

启用 torch.backends.cudnn.benchmark = True 可以加速卷积操作,尤其是在输入尺寸固定的情况下。

复制代码
"""
一个黑客创业者:年龄预测模型完整训练(支持CPU/GPU、断点续训、早停机制)
执行方式:
1. 训练:python train_age.py
2. 恢复:python train_age.py --resume ./checkpoints/latest.pth
"""

import os
import time
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm

# 配置参数
# Global configuration read by the data, training and checkpoint code below.
CONFIG = {
    # Data paths: one age-list text file and one image directory per split
    "train_age_list": r"D:\daku\性别\megaage_asian\list\train_age.txt",
    "val_age_list": r"D:\daku\性别\megaage_asian\list\test_age.txt",
    "train_image_dir": r"D:\daku\性别\megaage_asian\train",
    "val_image_dir": r"D:\daku\性别\megaage_asian\val",

    # Training hyper-parameters
    "batch_size": 64,
    "num_workers": 4 if torch.cuda.is_available() else 2,
    "learning_rate": 3e-4,
    "num_epochs": 100,
    "input_size": 224,

    # System / bookkeeping parameters
    "checkpoint_dir": "./checkpoints",
    "checkpoint_interval": 1,
    "early_stop_patience": 7,
    "use_amp": torch.cuda.is_available(),  # mixed precision is enabled exactly when CUDA is available
    "resume": None
}


class AgeDataset(Dataset):
    """Dataset pairing numerically named image files with a list of ages.

    Line ``i`` of the age list is paired with the ``i``-th image once the image
    file names have been sorted by their integer stem. A line whose age cannot
    be parsed, or lies outside ``[0, 120]``, is dropped *together with its
    image*, so the remaining ages and images stay aligned.

    Attributes:
        image_dir: Directory containing the images.
        transform: Optional callable applied to each loaded RGB image.
        ages: Accepted ages (floats), aligned index-by-index with ``image_files``.
        image_files: Image file names kept after filtering.
        num_samples: Number of usable (image, age) pairs.
    """

    def __init__(self, age_list_path, image_dir, transform=None):
        """Read the age list and the image directory and build aligned pairs.

        Args:
            age_list_path: Text file with one age per line.
            image_dir: Directory of ``.jpg``/``.jpeg``/``.png`` files whose
                stems are integers.
            transform: Optional callable applied to each image on access.

        Raises:
            ValueError: If an image file name has a non-integer stem.
            OSError: If the age list or the directory cannot be read.
        """
        self.image_dir = image_dir
        self.transform = transform

        # Parse one age per line. Rejected lines are kept as None placeholders
        # so that line positions still correspond to image positions.
        parsed_ages = []
        with open(age_list_path, 'r') as f:
            for line_count, line in enumerate(f, start=1):
                line = line.strip()
                try:
                    age = float(line)
                except ValueError:
                    print(f"行 {line_count}: 无效年龄值 '{line}',已跳过")
                    parsed_ages.append(None)
                    continue
                if 0 <= age <= 120:
                    parsed_ages.append(age)
                else:
                    print(f"行 {line_count}: 异常年龄值 {age},已过滤")
                    parsed_ages.append(None)

        # Sort images numerically (1, 2, 10) rather than lexically (1, 10, 2).
        image_files = sorted(
            [f for f in os.listdir(image_dir)
             if f.lower().endswith(('.jpg', '.jpeg', '.png'))],
            key=lambda x: int(os.path.splitext(x)[0])
        )

        if len(parsed_ages) != len(image_files):
            usable = min(len(parsed_ages), len(image_files))
            print(
                f"警告: 年龄数({len(parsed_ages)})与图片数({len(image_files)})不一致,使用前{usable}个样本")

        # zip() truncates to the shorter sequence; pairs with a rejected age
        # are removed as a unit so no later label shifts onto another image.
        pairs = [(name, age) for name, age in zip(image_files, parsed_ages)
                 if age is not None]
        self.image_files = [name for name, _ in pairs]
        self.ages = [age for _, age in pairs]
        self.num_samples = len(pairs)

    def __len__(self):
        """Return the number of usable (image, age) pairs."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, age)`` for ``idx``, skipping unreadable images.

        If the image at ``idx`` cannot be loaded, the following samples are
        tried in order (wrapping around) until one loads.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no image in the dataset can be loaded.
        """
        total = len(self)
        current = idx
        # Bounded loop instead of recursion: at most one attempt per sample.
        for _ in range(max(total, 1)):
            # Plain indexing keeps IndexError for out-of-range indices.
            img_name = self.image_files[current]
            img_path = os.path.join(self.image_dir, img_name)

            try:
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
            except Exception as e:
                # Decoding can fail in many ways; report it and move on.
                print(f"图片加载失败: {img_path},错误: {str(e)}")
                current = (current + 1) % total
                continue

            age = torch.tensor(self.ages[current], dtype=torch.float32)

            if self.transform:
                image = self.transform(image)

            return image, age

        raise RuntimeError(f"数据集中没有可读取的图片: {self.image_dir}")


def create_data_loaders():
    """Build the training and validation ``DataLoader`` objects from ``CONFIG``.

    The crop size comes from ``CONFIG["input_size"]``; the resize size keeps
    the 256:224 ratio, so the default ``input_size`` of 224 gives resize 256
    and crop 224.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    crop_size = CONFIG["input_size"]
    # Scale the resize target with the crop size (224 -> 256).
    resize_size = round(crop_size * 256 / 224)
    # Channel mean/std commonly used with ImageNet-pretrained torchvision models.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Training: random crop and horizontal flip as augmentation.
    train_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation: deterministic centre crop.
    val_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = AgeDataset(CONFIG["train_age_list"], CONFIG["train_image_dir"], train_transform)
    val_set = AgeDataset(CONFIG["val_age_list"], CONFIG["val_image_dir"], val_transform)

    print("\n数据集统计:")
    print(f"训练样本: {len(train_set)} | 验证样本: {len(val_set)}")

    use_cuda = torch.cuda.is_available()
    num_workers = CONFIG["num_workers"]

    train_loader = DataLoader(
        train_set,
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        num_workers=num_workers,
        pin_memory=use_cuda,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=use_cuda and num_workers > 0
    )

    val_loader = DataLoader(
        val_set,
        batch_size=CONFIG["batch_size"],
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_cuda
    )

    return train_loader, val_loader


class AgeRegressor(nn.Module):
    """Age regression network: a ResNet-50 trunk followed by a small MLP head.

    The head maps the trunk's features to a single scalar per input image.
    """

    def __init__(self, pretrained=True):
        """Build the trunk and the regression head.

        Args:
            pretrained: When true, load ``ResNet50_Weights.DEFAULT`` into the
                backbone; otherwise the backbone is created without weights.
        """
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        backbone = models.resnet50(weights=weights)
        head_in = backbone.fc.in_features

        # Keep every child module except the last one
        # (presumably the backbone's own `fc` classifier).
        trunk_layers = list(backbone.children())
        self.feature_extractor = nn.Sequential(*trunk_layers[:-1])

        self.regressor = nn.Sequential(
            nn.Linear(head_in, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        """Return one prediction per batch element."""
        pooled = self.feature_extractor(x)
        flat = torch.flatten(pooled, 1)
        prediction = self.regressor(flat)
        # Drop the trailing size-1 dimension produced by the final Linear(512, 1).
        return prediction.squeeze(1)


def initialize_training(resume_path=None):
    """Create model, optimizer, loss, AMP scaler and LR scheduler.

    If ``resume_path`` points to an existing checkpoint, the saved states and
    counters are restored from it. If the path is given but does not exist, a
    warning is printed and training starts from scratch.

    Args:
        resume_path: Optional path of a checkpoint to resume from.

    Returns:
        dict: Keys ``device``, ``model``, ``optimizer``, ``criterion``,
        ``scaler`` (``None`` without CUDA), ``scheduler``, ``start_epoch``
        (0-based), ``best_loss`` and ``no_improve``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n训练设备: {device}")

    model = AgeRegressor().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"], weight_decay=1e-4)
    criterion = nn.HuberLoss()

    # A gradient scaler is only created when CUDA is available.
    scaler = torch.cuda.amp.GradScaler(enabled=CONFIG["use_amp"]) if torch.cuda.is_available() else None
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)

    # Defaults for a fresh run.
    start_epoch = 0
    best_loss = float('inf')
    no_improve = 0

    if resume_path:
        if os.path.exists(resume_path):
            checkpoint = torch.load(resume_path, map_location=device)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            start_epoch = checkpoint['epoch'] + 1
            best_loss = checkpoint['best_loss']
            no_improve = checkpoint['no_improve']
            # The stored scaler state may be None or empty (e.g. a checkpoint
            # written without AMP); loading that into a scaler would raise.
            scaler_state = checkpoint.get('scaler')
            if scaler and scaler_state:
                scaler.load_state_dict(scaler_state)
            # start_epoch is a 0-based index; report the 1-based epoch number.
            print(f"成功恢复训练状态,从第 {start_epoch + 1} 轮开始")
        else:
            # Make a mistyped --resume path visible instead of silently
            # restarting from scratch.
            print(f"警告: 未找到检查点 {resume_path},将从头开始训练")

    return {
        "device": device,
        "model": model,
        "optimizer": optimizer,
        "criterion": criterion,
        "scaler": scaler,
        "scheduler": scheduler,
        "start_epoch": start_epoch,
        "best_loss": best_loss,
        "no_improve": no_improve
    }


def train_epoch(model, device, train_loader, optimizer, criterion, scaler):
    """Run one training pass over ``train_loader``.

    Args:
        model: Network to optimise; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable applied to ``(outputs, labels)``.
        scaler: Gradient scaler; when falsy, a plain backward/step is used.

    Returns:
        float: Sum of per-batch loss times batch size, divided by
        ``len(train_loader.dataset)``.
    """
    model.train()
    autocast_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    running_loss = 0.0

    progress = tqdm(train_loader, desc="训练", unit="batch")
    with progress:
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device, non_blocking=True)
            batch_labels = batch_labels.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            # Forward pass under autocast (active only if CONFIG["use_amp"]).
            with torch.autocast(device_type=autocast_device, enabled=CONFIG["use_amp"]):
                predictions = model(batch_images)
                loss = criterion(predictions, batch_labels)

            if not scaler:
                loss.backward()
                optimizer.step()
            else:
                # Scaled backward pass, optimizer step, then scale update.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            batch_loss = loss.item()
            # Weight by batch size so a shorter final batch is averaged correctly.
            running_loss += batch_loss * batch_images.size(0)
            progress.set_postfix(loss=batch_loss)

    return running_loss / len(train_loader.dataset)


def validate(model, device, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        device: Device the batches are moved to.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.

    Returns:
        float: Sum of per-batch loss times batch size, divided by
        ``len(val_loader.dataset)``.
    """
    model.eval()
    accumulated = 0.0

    with torch.no_grad():
        with tqdm(val_loader, desc="验证", unit="batch") as bar:
            for batch_images, batch_labels in bar:
                batch_images = batch_images.to(device, non_blocking=True)
                batch_labels = batch_labels.to(device, non_blocking=True)

                batch_loss = criterion(model(batch_images), batch_labels).item()

                # Weight by batch size so the result is a per-sample mean.
                accumulated += batch_loss * batch_images.size(0)
                bar.set_postfix(loss=batch_loss)

    return accumulated / len(val_loader.dataset)


def _atomic_torch_save(obj, path):
    """Serialize ``obj`` to ``path`` via a temporary file and an atomic rename."""
    tmp_path = f"{path}.tmp"
    try:
        torch.save(obj, tmp_path)
        # os.replace swaps the file in one step, so an interrupted write never
        # leaves a truncated checkpoint under the final name.
        os.replace(tmp_path, path)
    finally:
        # Only present if saving or renaming failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_checkpoint(state, filename, is_best=False):
    """Write a training checkpoint into ``CONFIG["checkpoint_dir"]``.

    Args:
        state: Full training state; must contain a ``"model"`` entry when
            ``is_best`` is true.
        filename: File name (inside the checkpoint directory) for the full state.
        is_best: When true, additionally write ``state["model"]`` alone to
            ``model_best.pth``.
    """
    os.makedirs(CONFIG["checkpoint_dir"], exist_ok=True)
    filepath = os.path.join(CONFIG["checkpoint_dir"], filename)

    # Full state, used to resume training.
    _atomic_torch_save(state, filepath)

    # Best model so far: weights only.
    if is_best:
        best_path = os.path.join(CONFIG["checkpoint_dir"], "model_best.pth")
        _atomic_torch_save(state["model"], best_path)


def main():
    """Run the full training loop with checkpointing and early stopping.

    Each epoch trains, validates, steps the LR scheduler on the validation
    loss, writes checkpoints, and stops early once the validation loss has
    failed to improve for ``CONFIG["early_stop_patience"]`` epochs.
    """
    train_loader, val_loader = create_data_loaders()
    env = initialize_training(CONFIG["resume"])

    device, model = env["device"], env["model"]
    optimizer, criterion = env["optimizer"], env["criterion"]
    scaler, scheduler = env["scaler"], env["scheduler"]
    best_loss, no_improve = env["best_loss"], env["no_improve"]

    for epoch in range(env["start_epoch"], CONFIG["num_epochs"]):
        epoch_no = epoch + 1  # 1-based number used for display and file names
        print(f"\nEpoch {epoch_no}/{CONFIG['num_epochs']}")
        tic = time.time()

        train_loss = train_epoch(model, device, train_loader, optimizer, criterion, scaler)
        val_loss = validate(model, device, val_loader, criterion)
        scheduler.step(val_loss)

        elapsed = time.time() - tic
        minutes, seconds = elapsed // 60, elapsed % 60
        current_lr = optimizer.param_groups[0]['lr']
        print(f"耗时: {minutes:.0f}m{seconds:.0f}s | LR: {current_lr:.1e} | "
              f"训练损失: {train_loss:.4f} | 验证损失: {val_loss:.4f}")

        # Track the best validation loss and the no-improvement streak.
        improved = val_loss < best_loss
        if not improved:
            no_improve += 1
        else:
            best_loss = val_loss
            no_improve = 0

        snapshot = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'scaler': scaler.state_dict() if scaler else None,
            'best_loss': best_loss,
            'no_improve': no_improve,
            'config': CONFIG
        }

        # Per-epoch file on improvement or at the configured interval.
        if improved or epoch_no % CONFIG["checkpoint_interval"] == 0:
            save_checkpoint(snapshot, f"epoch_{epoch_no}.pth", improved)

        # Always refresh the checkpoint used for resuming.
        save_checkpoint(snapshot, "latest.pth")

        if no_improve >= CONFIG["early_stop_patience"]:
            print(f"\n早停触发: 验证损失连续 {CONFIG['early_stop_patience']} 轮未提升")
            break


if __name__ == "__main__":
    # Optional --resume argument overrides CONFIG["resume"] when provided.
    cli = argparse.ArgumentParser()
    cli.add_argument('--resume', help='恢复训练的检查点路径')
    resume_arg = cli.parse_args().resume
    if resume_arg:
        CONFIG["resume"] = resume_arg

    # Turn on cuDNN benchmark mode when CUDA is available.
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.backends.cudnn.benchmark = True

    # Start training.
    main()
相关推荐
Xxtaoaooo1 分钟前
OCR文字识别前沿:PaddleOCR/DBNet++的端到端文本检测与识别
人工智能·ai·ocr·文本检测·dbnet++
taxunjishu2 分钟前
DeviceNet 转 MODBUS TCP:倍福 CX 系列 PLC 与 MES 系统在 SMT 回流焊温度曲线监控的通讯配置案例
运维·人工智能·物联网·自动化·区块链
小烤箱9 分钟前
自动驾驶工程师面试(定位、感知向)
人工智能·面试·自动驾驶
IT_陈寒11 分钟前
《Redis性能翻倍的7个冷门技巧,90%开发者都不知道!》
前端·人工智能·后端
龙俊杰的读书笔记24 分钟前
《小白学随机过程》第一章:随机过程——定义和形式 (附录1 探究随机变量)
人工智能·机器学习·概率论·随机过程和rl
长空任鸟飞_阿康35 分钟前
在 Vue 3.5 中优雅地集成 wangEditor,并定制“AI 工具”下拉菜单(总结/润色/翻译)
前端·vue.js·人工智能
滑水滑成滑头41 分钟前
**发散创新:多智能体系统的探索与实践**随着人工智能技术的飞速发展,多智能体系统作为当今研究的热点领域,正受到越来越多关注
java·网络·人工智能·python
盼小辉丶41 分钟前
PyTorch实战(9)——从零开始实现Transformer
pytorch·深度学习·transformer
云布道师1 小时前
阿里云 OSS MetaQuery 全面升级——新增内容和语义的检索能力,助力 AI 应用快速落地
人工智能·阿里云·云计算
m0_650108241 小时前
【论文精读】FlowVid:驯服不完美的光流,实现一致的视频到视频合成
人工智能·计算机视觉·扩散模型·视频编辑·视频生成·论文精读·不完美光流