深度学习篇---模型训练(1)


文章目录


前言

本文再网络结构(1)的基础上,完善数据读取、数据增强、数据处理、模型训练、断点训练等功能。


一、库导入与配置部分

python 复制代码
import torch
import torch.nn as nn  # PyTorch核心神经网络模块
import pandas as pd    # 数据处理
import numpy as np     # 数值计算
from torch.utils.data import Dataset, DataLoader  # 数据加载工具
from sklearn.preprocessing import StandardScaler  # 数据标准化
from sklearn.model_selection import train_test_split  # 数据分割
from torch.optim.lr_scheduler import ReduceLROnPlateau  # 动态学习率调整
from collections import Counter  # 统计类别分布
import csv  # 结果记录
import time  # 时间戳生成
import joblib  # 模型/参数持久化

介绍

导入Pytorch核心神经网路模块、数据处理库和数值处理库数据标准化、数据分割、动态学习率调整、统计类别分布、结果记录、时间戳生成、模型/参数持久化。

二、超参数配置

python 复制代码
config = {
    "batch_size": 256,        # 每批数据量
    "num_workers": 128,       # 数据加载并行进程数
    "lr": 1e-3,               # 初始学习率
    "weight_decay": 1e-4,     # L2正则化强度
    "epochs": 200,            # 最大训练轮数
    "patience": 15,           # 早停等待轮数
    "min_delta": 0.001,       # 视为改进的最小精度提升
    "grad_clip": 5.0,         # 梯度裁剪阈值
    "num_classes": None       # 自动计算类别数
}

简介

设置每批数据量、数据加载并行进程数、初始学习率、L2正则化强度、最大训练轮数、早停等待轮数、视为改进的最小精度提升、梯度剪裁阈值、自动计算类别数。

三、模型定义

1. 改进残差块

python 复制代码
class ImprovedResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()  # 初始化父类
        
        # 第一个卷积层
        self.conv1 = nn.Conv1d(in_channels, out_channels, 5, stride, 2)
        # 参数解释:输入通道,输出通道,卷积核大小5,步长,填充2(保持尺寸)
        self.bn1 = nn.BatchNorm1d(out_channels)  # 批量归一化
        
        # 第二个卷积层
        self.conv2 = nn.Conv1d(out_channels, out_channels, 3, 1, 1)
        # 3x1卷积,步长1,填充1保持尺寸
        self.bn2 = nn.BatchNorm1d(out_channels)
        
        self.relu = nn.ReLU()  # 激活函数
        
        # 下采样路径(当需要调整维度时)
        self.downsample = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1, stride),  # 1x1卷积调整维度
            nn.BatchNorm1d(out_channels)
        ) if in_channels != out_channels or stride != 1 else None
        # 当输入输出通道不同或步长>1时启用

    def forward(self, x):
        identity = x  # 保留原始输入作为残差
        
        # 主路径处理
        x = self.relu(self.bn1(self.conv1(x)))  # Conv1 -> BN1 -> ReLU
        x = self.bn2(self.conv2(x))  # Conv2 -> BN2(无激活)
        
        # 调整残差路径维度
        if self.downsample:
            identity = self.downsample(identity)
        
        x += identity  # 残差连接
        return self.relu(x)  # 最终激活

2. 完整CNN模型

python 复制代码
class EnhancedCNN(nn.Module):
    def __init__(self, input_channels, seq_len, num_classes):
        super().__init__()
        
        # 初始特征提取层
        self.initial = nn.Sequential(
            nn.Conv1d(input_channels, 64, 7, stride=2, padding=3),  # 快速下采样
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(3, 2, 1)  # 核3,步长2,填充1,输出尺寸约为输入1/4
        )
        
        # 残差块堆叠
        self.blocks = nn.Sequential(
            ImprovedResBlock(64, 128, stride=2),  # 通道翻倍,尺寸减半
            ImprovedResBlock(128, 256, stride=2),
            ImprovedResBlock(256, 512, stride=2),
            nn.AdaptiveAvgPool1d(1)  # 自适应全局平均池化到长度1
        )
        
        # 分类器
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),     # 全连接层
            nn.Dropout(0.5),         # 强正则化防止过拟合
            nn.ReLU(),
            nn.Linear(256, num_classes)  # 最终分类层
        )

    def forward(self, x):
        x = self.initial(x)  # 初始特征提取
        x = self.blocks(x)   # 通过残差块
        x = x.view(x.size(0), -1)  # 展平维度 (batch, 512)
        return self.classifier(x)  # 分类预测

四、数据集类

python 复制代码
class SequenceDataset(Dataset):
    def __init__(self, sequences, labels, scaler=None):
        self.sequences = sequences  # 原始序列数据
        self.labels = labels        # 对应标签
        self.scaler = scaler or StandardScaler()  # 标准化器
        
        # 如果未提供scaler,用当前数据拟合新的
        if scaler is None:
            flattened = np.concatenate(sequences)  # 展平所有数据点
            self.scaler.fit(flattened)  # 计算均值和方差
        
        # 对每个序列进行标准化
        self.normalized = [self.scaler.transform(seq) for seq in sequences]

    def __len__(self):
        return len(self.sequences)  # 返回数据集大小

    def __getitem__(self, idx):
        # 获取单个样本
        seq = torch.tensor(self.normalized[idx], dtype=torch.float32).permute(1, 0)
        # permute将形状从(seq_len, features)转为(features, seq_len)符合Conv1d输入要求
        
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        
        # 数据增强
        if np.random.rand() > 0.5:  # 50%概率时序翻转
            seq = seq.flip(-1)  # 沿时间维度翻转
        if np.random.rand() > 0.3:  # 70%概率添加噪声
            seq += torch.randn_like(seq) * 0.01  # 高斯噪声(均值0,方差0.01)
            
        return seq, label

五、数据加载函数

python 复制代码
def load_data(excel_path):
    df = pd.read_excel(excel_path)  # 读取Excel数据
    
    sequences = []
    labels = []
    
    for _, row in df.iterrows():  # 遍历每一行数据
        try:
            # 处理可能存在的字符串格式异常
            loads = list(map(float, str(row['载荷']).split(',')))
            displacements = list(map(float, str(row['位移']).split(',')))
            powers = list(map(float, str(row['功率']).split(',')))
            
            # 对齐三列数据的长度
            min_len = min(len(loads), len(displacements), len(powers))
            # 组合成(时间步长, 3个特征)的数组
            combined = np.array([
                loads[:min_len], 
                displacements[:min_len], 
                powers[:min_len]
            ).T  # 转置为(min_len, 3)
            
            label = int(float(row['工况结果']))  # 转换标签
            sequences.append(combined)
            labels.append(label)
        except Exception as e:
            print(f"处理第{_}行时出错: {str(e)}")  # 异常处理
    
    # 统计类别分布
    label_counts = Counter(labels)
    print("类别分布:", label_counts)
    
    # 创建标签映射(将任意标签转换为0~N-1的索引)
    unique_labels = sorted(list(set(labels)))
    label_map = {l:i for i,l in enumerate(unique_labels)}
    config["num_classes"] = len(unique_labels)  # 更新配置
    labels = [label_map[l] for l in labels]  # 转换所有标签
    
    # 分层划分训练/验证集(保持类别比例)
    return train_test_split(sequences, labels, test_size=0.2, stratify=labels)

六、训练函数

python 复制代码
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()  # 训练模式
    total_loss = 0
    
    for x, y in loader:  # 遍历数据加载器
        x, y = x.to(device), y.to(device)  # 数据迁移到设备
        
        optimizer.zero_grad()  # 清空梯度
        outputs = model(x)     # 前向传播
        loss = criterion(outputs, y)  # 计算损失
        loss.backward()        # 反向传播
        
        # 梯度裁剪防止爆炸
        nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])
        optimizer.step()       # 参数更新
        
        total_loss += loss.item() * x.size(0)  # 累加损失(考虑批次大小)
    
    return total_loss / len(loader.dataset)  # 平均损失

七、验证函数

python 复制代码
def validate(model, loader, criterion, device):
    model.eval()  # 评估模式
    total_loss = 0
    correct = 0
    
    with torch.no_grad():  # 禁用梯度计算
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)            loss = criterion(outputs, y)
            total_loss += loss.item() * x.size(0)
            
            # 计算准确率
            preds = outputs.argmax(dim=1)  # 取最大概率类别
            correct += preds.eq(y).sum().item()  # 统计正确数
    
    return (total_loss / len(loader.dataset),  # 平均损失
           (correct / len(loader.dataset))    # 准确率

八、检查点管理

python 复制代码
def save_checkpoint(epoch, model, optimizer, scheduler, best_acc, scaler, filename="checkpoint.pth"):
    torch.save({
        'epoch': epoch,                    # 当前轮数
        'model_state_dict': model.state_dict(),          # 模型参数
        'optimizer_state_dict': optimizer.state_dict(),  # 优化器状态
        'scheduler_state_dict': scheduler.state_dict(),  # 学习率调度器状态
        'best_acc': best_acc,              # 当前最佳准确率
        'scaler': scaler                   # 数据标准化参数
    }, filename)

def load_checkpoint(filename, model, optimizer, scheduler):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])       # 加载模型
    optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['epoch'], checkpoint['best_acc'], checkpoint['scaler']

九、主函数

python 复制代码
def main(resume=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 自动选择设备
    
    # 生成带时间戳的结果文件名
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    results_file = f"training_results_{timestamp}.csv"
    
    # 加载并划分数据
    train_seq, val_seq, train_lb, val_lb = load_data("./dcgt.xls")
    
    # 初始化模型(恢复训练时自动获取序列长度)
    sample_seq = train_seq[0].shape[1] if resume else None
    model = EnhancedCNN(
        input_channels=3, 
        seq_len=sample_seq,  
        num_classes=config["num_classes"]
    ).to(device)
    
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), 
                                lr=config["lr"], 
                                weight_decay=config["weight_decay"])
    # 学习率调度器(根据验证损失调整)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
    
    # 恢复训练逻辑
    start_epoch = 0
    best_acc = 0
    if resume:
        checkpoint = torch.load("checkpoint.pth")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        train_set = SequenceDataset(train_seq, train_lb, scaler=checkpoint['scaler'])
    else:
        train_set = SequenceDataset(train_seq, train_lb)
    
    # 验证集使用训练集的scaler
    val_set = SequenceDataset(val_seq, val_lb, scaler=train_set.scaler)
    
    # 持久化标准化参数
    joblib.dump(train_set.scaler, 'scaler.save')
    
    # 创建数据加载器
    train_loader = DataLoader(
        train_set, 
        batch_size=config["batch_size"], 
        shuffle=True, 
        num_workers=config["num_workers"]  # 多进程加载加速
    )
    val_loader = DataLoader(
        val_set, 
        batch_size=config["batch_size"], 
        num_workers=config["num_workers"]
    )

    # 训练循环
    with open(results_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'train_loss', 'val_loss', 'val_acc', 'learning_rate'])
        
        for epoch in range(start_epoch, config["epochs"]):
            # 训练一个epoch
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            # 验证
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            current_lr = optimizer.param_groups[0]['lr']  # 获取当前学习率
            
            # 更新学习率
            scheduler.step(val_loss)
            
            # 保存检查点
            save_checkpoint(epoch+1, model, optimizer, scheduler, best_acc, train_set.scaler)
            
            # 记录结果
            writer.writerow([
                epoch + 1, 
                f"{train_loss:.4f}", 
                f"{val_loss:.4f}", 
                f"{val_acc:.4f}", 
                f"{current_lr:.6f}"
            ])
            print(f"\nEpoch {epoch+1}/{config['epochs']}")
            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"Val Acc: {val_acc*100:.2f}% | Learning Rate: {current_lr:.6f}")
            
            # 早停逻辑(伪代码示意)
            if val_acc > best_acc + config["min_delta"]:
                best_acc = val_acc
                patience_counter = 0
            else:
                patience_counter += 1
            
            if patience_counter >= config["patience"]:
                print(f"早停触发于第{epoch+1}轮")
                break
    
    # 保存最终模型
    torch.save(model.state_dict(), "best_model.pth")

十、执行入口

python 复制代码
if __name__ == "__main__":
    main(resume=False)  # 首次训练
    # main(resume=True)  # 恢复训练

十一、关键设计亮点总结

1.维度管理

维度管理:通过permute确保数据形状符合Conv1d要求

2.数据标准化

数据标准化:使用全体训练数据计算均值和方差,避免数据泄露

3.动态学习率

动态学习率:ReduceLROnPlateau根据验证损失自动调整

4.梯度剪裁

梯度裁剪:防止梯度爆炸,稳定训练过程

5.检查点系统

检查点系统:完整保存训练状态,支持训练中断恢复

6.结果可追溯

结果可追溯:带时间戳的CSV记录和模型保存

7.工业级健壮性

工业级健壮性:异常捕获、参数持久化、自动类别映射

8.高效数据加载

高效数据加载:多进程并行加速数据预处理

这个实现涵盖了从数据预处理到模型训练的完整流程,适合工业级时间序列分类任务 ,具有良好的可扩展性和可维护性。


相关推荐
liuyunshengsir6 分钟前
chromadb 安装和使用
人工智能·大模型
FIT2CLOUD飞致云12 分钟前
全面支持MCP协议,开启便捷连接之旅,MaxKB知识库问答系统v1.10.3 LTS版本发布
人工智能·开源
Johnny_Cheung17 分钟前
字符串、列表、元组、字典
开发语言·python
云水木石17 分钟前
ChatGPT-4o 在汉字显示上进步巨大
人工智能·chatgpt
独行soc21 分钟前
2025年渗透测试面试题总结- 某四字大厂面试复盘扩展 一面(题目+回答)
java·数据库·python·安全·面试·职场和发展·汽车
Mr_LeeCZ33 分钟前
PyTorch 深度学习 || 7. Unet | Ch7.1 Unet 框架
人工智能·深度学习·机器学习
James. 常德 student36 分钟前
多GPU训练
人工智能·pytorch·深度学习
梦回阑珊39 分钟前
《QT从基础到进阶·七十四》Qt+C++开发一个python编译器,能够编写,运行python程序改进版
c++·python·qt
前端开发张小七43 分钟前
13.Python Socket服务端开发指南
前端·python
前端开发张小七44 分钟前
14.Python Socket客户端开发指南
前端·python