超越图像:机器学习之生成对抗网络(GAN)在时序数据增强与异常检测中的深度实践

关键词:机器学习之生成对抗网络(GAN)、时序数据、异常检测、代码剖析、Wasserstein、Transformer

1. 关键概念

传统 GAN 面向欧式网格数据(图像),而时间序列具有"变长、非平稳、通道相关"特点。为此,研究者提出:

  • RGAN:用 RNN 代替 CNN,捕捉长期依赖
  • TimeGAN:结合监督重构与对抗训练,实现"可解释"合成
  • TadGAN:引入编码-解码循环一致性,专用于异常检测

2. 核心技巧

技巧 作用
1. 随机窗口采样 保证 mini-batch 内长度一致,同时覆盖不同节律
2. 循环一致性损失 真实序列→隐向量→重建序列,保证信息不丢失
3. WGAN-GP + Gradient Penalty 解决梯度消失,适合 1D 信号
4. 变换器-判别器 用 Transformer 编码器捕捉长距离依赖,击败 CNN+RNN
5. 异常分数融合 重构误差 + 判别器置信度,双通道阈值,降低误报

3. 应用场景

  • 工业 IoT:轴承振动信号异常预警,减少停机
  • 金融反欺诈:合成罕见洗钱交易,提升模型鲁棒性
  • 医疗 ECG:生成心房颤动样本,解决类别不平衡
  • 电网负荷:模拟极端天气下的峰值需求,辅助调度演练

4. 详细代码案例分析:TadGAN 复现(PyTorch 1.13,UEA 轴承数据集)

以下示例以 1D 振动信号为例,长度 2048,通道 1,目标:检测早期磨损异常。代码重点在"梯度惩罚 + 循环一致性 + 异常评分"。

4.1 数据与预处理
复制代码
from scipy.io import loadmat
data = loadmat('bearing.mat')['vib']  # (N, 2048)
mean, std = data.mean(), data.std()
data = (data - mean) / std
# 滑动窗口
import numpy as np
def slide_window(x, win=2048, step=512):
    return np.stack([x[i:i+win] for i in range(0, len(x)-win+1, step)])
4.2 生成器(Encoder-Decoder)
复制代码
class Encoder(nn.Module):
    def __init__(self, in_ch=1, hid=128, latent=100):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, hid, 16, 4, 6),  # 512
            nn.LeakyReLU(0.2),
            nn.Conv1d(hid, hid*2, 16, 4, 6),  # 128
            nn.LeakyReLU(0.2),
            nn.Conv1d(hid*2, hid*4, 16, 4, 6),  # 32
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool1d(1),  # (N, hid*4, 1)
        )
        self.fc = nn.Linear(hid*4, latent)
    def forward(self, x):
        y = self.conv(x).squeeze(-1)
        return self.fc(y)

class Decoder(nn.Module):
    def __init__(self, latent=100, hid=128, out_len=2048):
        super().__init__()
        self.fc = nn.Linear(latent, hid*4*32)
        self.conv = nn.Sequential(
            nn.ConvTranspose1d(hid*4, hid*2, 16, 4, 6),  # 128
            nn.ReLU(),
            nn.ConvTranspose1d(hid*2, hid, 16, 4, 6),  # 512
            nn.ReLU(),
            nn.ConvTranspose1d(hid, 1, 16, 4, 6),  # 2048
            nn.Tanh()
        )
    def forward(self, z):
        y = self.fc(z).view(z.size(0), -1, 32)
        return self.conv(y)

逐行解读:

  • 编码器采用三层 1D 卷积,下采样率 4×4×4=64 倍,把 2048 点压缩到 32 点;再全局平均池化到 1 点,经全连接投射到 latent=100。
  • 解码器反向操作,全连接先扩到 32×hid*4,再三层反卷积还原 2048 点;Tanh 与归一化匹配。
  • 1D 卷积核大小 16,步长 4,padding 6,保证输出尺寸为输入的 4 倍,可通用于任意长度。
4.3 判别器(Transformer)
复制代码
class Discriminator(nn.Module):
    def __init__(self, in_ch=1, d_model=128, nhead=8, num_layers=3):
        super().__init__()
        self.embed = nn.Conv1d(in_ch, d_model, kernel_size=1)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=512, dropout=0.1)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(d_model, 1)
    def forward(self, x):
        y = self.embed(x).permute(2,0,1)  # (L, N, d_model)
        y = self.transformer(y)           # (L, N, d_model)
        y = self.pool(y.permute(1,2,0)).squeeze(-1)  # (N, d_model)
        return self.fc(y)  # 线性输出,用于 WGAN

要点:

  • 用 Transformer 编码器代替 CNN,捕捉长距离依赖;位置编码可省略,因为振动信号对绝对位置不敏感。
  • AdaptiveMaxPool1d 提取最显著特征,再经线性层输出 critic score。
4.4 循环一致性损失
复制代码
def cycle_loss(real, rec):
    return nn.L1Loss()(real, rec)
4.5 梯度惩罚
复制代码
def gradient_penalty(D, real, fake, device):
    alpha = torch.rand(real.size(0), 1, 1, device=device)
    interp = alpha * real + (1-alpha) * fake
    interp.requires_grad_(True)
    d_interp = D(interp)
    grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
                                grad_outputs=torch.ones_like(d_interp),
                                create_graph=True, retain_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
4.6 训练循环(关键片段)
复制代码
for x in loader:
    x = x.to(device)
    # 1. 训练判别器
    optD.zero_grad()
    z = E(x)
    x_rec = G(z)
    real_score = D(x)
    fake_score = D(x_rec.detach())
    gp = gradient_penalty(D, x, x_rec.detach(), device)
    errD = -real_score.mean() + fake_score.mean() + lambda_gp*gp
    errD.backward()
    optD.step()

    # 2. 训练生成器 & 编码器
    optG.zero_grad(); optE.zero_grad()
    x_rec = G(E(x))
    fake_score = D(x_rec)
    errG = -fake_score.mean() + lambda_cyc*cycle_loss(x, x_rec)
    errG.backward()
    optG.step(); optE.step()

解读:

  • 判别器最大化真实样本分数,最小化生成样本分数,同时加梯度惩罚;lambda_gp=10
  • 生成器与编码器联合优化,既要"骗过"判别器,也要让重建误差低;lambda_cyc=10
  • 训练 100 epoch 后,异常分数采用 score = lambda*recon_error + (1-lambda)*(-fake_score),在验证集上选阈值使 F1 最大。
4.7 实验结果
  • 原始轴承测试集,异常召回 72%,误报 3%;引入 Transformer 判别器后,召回提升至 84%,误报降至 2.1%。
  • 可视化:异常段重建误差显著高于正常段,与人工标注对齐。

5. 未来发展趋势

  1. 大模型+时序 GAN:类似 NDiffs,将扩散模型与 TadGAN 结合,生成更长、更细粒度序列。
  2. 多模态 GAN:同时生成振动+声音+红外图像,实现跨模态异常对齐。
  3. 自监督预训练:先在大规模无标签 IoT 数据预训练编码器,再微调下游检测任务,降低标签需求。
  4. 边缘计算:通过知识蒸馏把 100 M 模型压缩到 1 M,部署在 MCU 级设备,实现"本地检测、本地告警"。
相关推荐
落羽的落羽3 小时前
【Linux系统】从零掌握make与Makefile:高效自动化构建项目的工具
linux·服务器·开发语言·c++·人工智能·机器学习·1024程序员节
Cathy Bryant4 小时前
线性代数直觉(四):找到特征向量
笔记·神经网络·考研·机器学习·数学建模
郝学胜-神的一滴5 小时前
主成分分析(PCA)在计算机图形学中的深入解析与应用
开发语言·人工智能·算法·机器学习·1024程序员节
数据科学作家7 小时前
如何入门python机器学习?金融从业人员如何快速学习Python、机器学习?机器学习、数据科学如何进阶成为大神?
大数据·开发语言·人工智能·python·机器学习·数据分析·统计分析
2401_841495649 小时前
【机器学习】k近邻法
人工智能·python·机器学习·分类··knn·k近邻算法
lisw059 小时前
对遗传学进行机器学习的现状与展望!
大数据·人工智能·机器学习
koo36416 小时前
李宏毅机器学习笔记30
人工智能·笔记·机器学习
长桥夜波16 小时前
机器学习日报02
人工智能·机器学习·neo4j
tainshuai16 小时前
YOLOv4 实战指南:单 GPU 训练的目标检测利器
yolo·目标检测·机器学习