关键词:机器学习之生成对抗网络(GAN)、时序数据、异常检测、代码剖析、Wasserstein、Transformer
1. 关键概念
传统 GAN 面向欧式网格数据(图像),而时间序列具有"变长、非平稳、通道相关"特点。为此,研究者提出:
- RGAN:用 RNN 代替 CNN,捕捉长期依赖
- TimeGAN:结合监督重构与对抗训练,实现"可解释"合成
- TadGAN:引入编码-解码循环一致性,专用于异常检测
2. 核心技巧
| 技巧 | 作用 |
|---|---|
| 1. 随机窗口采样 | 保证 mini-batch 内长度一致,同时覆盖不同节律 |
| 2. 循环一致性损失 | 真实序列→隐向量→重建序列,保证信息不丢失 |
| 3. WGAN-GP + Gradient Penalty | 解决梯度消失,适合 1D 信号 |
| 4. 变换器-判别器 | 用 Transformer 编码器捕捉长距离依赖,击败 CNN+RNN |
| 5. 异常分数融合 | 重构误差 + 判别器置信度,双通道阈值,降低误报 |
3. 应用场景
- 工业 IoT:轴承振动信号异常预警,减少停机
- 金融反欺诈:合成罕见洗钱交易,提升模型鲁棒性
- 医疗 ECG:生成心房颤动样本,解决类别不平衡
- 电网负荷:模拟极端天气下的峰值需求,辅助调度演练
4. 详细代码案例分析:TadGAN 复现(PyTorch 1.13,UEA 轴承数据集)
以下示例以 1D 振动信号为例,长度 2048,通道 1,目标:检测早期磨损异常。代码重点在"梯度惩罚 + 循环一致性 + 异常评分"。
4.1 数据与预处理
from scipy.io import loadmat
data = loadmat('bearing.mat')['vib'] # (N, 2048)
mean, std = data.mean(), data.std()
data = (data - mean) / std
# 滑动窗口
import numpy as np
def slide_window(x, win=2048, step=512):
return np.stack([x[i:i+win] for i in range(0, len(x)-win+1, step)])
4.2 生成器(Encoder-Decoder)
class Encoder(nn.Module):
def __init__(self, in_ch=1, hid=128, latent=100):
super().__init__()
self.conv = nn.Sequential(
nn.Conv1d(in_ch, hid, 16, 4, 6), # 512
nn.LeakyReLU(0.2),
nn.Conv1d(hid, hid*2, 16, 4, 6), # 128
nn.LeakyReLU(0.2),
nn.Conv1d(hid*2, hid*4, 16, 4, 6), # 32
nn.LeakyReLU(0.2),
nn.AdaptiveAvgPool1d(1), # (N, hid*4, 1)
)
self.fc = nn.Linear(hid*4, latent)
def forward(self, x):
y = self.conv(x).squeeze(-1)
return self.fc(y)
class Decoder(nn.Module):
def __init__(self, latent=100, hid=128, out_len=2048):
super().__init__()
self.fc = nn.Linear(latent, hid*4*32)
self.conv = nn.Sequential(
nn.ConvTranspose1d(hid*4, hid*2, 16, 4, 6), # 128
nn.ReLU(),
nn.ConvTranspose1d(hid*2, hid, 16, 4, 6), # 512
nn.ReLU(),
nn.ConvTranspose1d(hid, 1, 16, 4, 6), # 2048
nn.Tanh()
)
def forward(self, z):
y = self.fc(z).view(z.size(0), -1, 32)
return self.conv(y)
逐行解读:
- 编码器采用三层 1D 卷积,下采样率 4×4×4=64 倍,把 2048 点压缩到 32 点;再全局平均池化到 1 点,经全连接投射到 latent=100。
- 解码器反向操作,全连接先扩到 32×hid*4,再三层反卷积还原 2048 点;
Tanh与归一化匹配。 - 1D 卷积核大小 16,步长 4,padding 6,保证输出尺寸为输入的 4 倍,可通用于任意长度。
4.3 判别器(Transformer)
class Discriminator(nn.Module):
def __init__(self, in_ch=1, d_model=128, nhead=8, num_layers=3):
super().__init__()
self.embed = nn.Conv1d(in_ch, d_model, kernel_size=1)
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
dim_feedforward=512, dropout=0.1)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.pool = nn.AdaptiveMaxPool1d(1)
self.fc = nn.Linear(d_model, 1)
def forward(self, x):
y = self.embed(x).permute(2,0,1) # (L, N, d_model)
y = self.transformer(y) # (L, N, d_model)
y = self.pool(y.permute(1,2,0)).squeeze(-1) # (N, d_model)
return self.fc(y) # 线性输出,用于 WGAN
要点:
- 用 Transformer 编码器代替 CNN,捕捉长距离依赖;位置编码可省略,因为振动信号对绝对位置不敏感。
AdaptiveMaxPool1d提取最显著特征,再经线性层输出 critic score。
4.4 循环一致性损失
def cycle_loss(real, rec):
return nn.L1Loss()(real, rec)
4.5 梯度惩罚
def gradient_penalty(D, real, fake, device):
alpha = torch.rand(real.size(0), 1, 1, device=device)
interp = alpha * real + (1-alpha) * fake
interp.requires_grad_(True)
d_interp = D(interp)
grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
grad_outputs=torch.ones_like(d_interp),
create_graph=True, retain_graph=True)[0]
return ((grads.norm(2, dim=1) - 1) ** 2).mean()
4.6 训练循环(关键片段)
for x in loader:
x = x.to(device)
# 1. 训练判别器
optD.zero_grad()
z = E(x)
x_rec = G(z)
real_score = D(x)
fake_score = D(x_rec.detach())
gp = gradient_penalty(D, x, x_rec.detach(), device)
errD = -real_score.mean() + fake_score.mean() + lambda_gp*gp
errD.backward()
optD.step()
# 2. 训练生成器 & 编码器
optG.zero_grad(); optE.zero_grad()
x_rec = G(E(x))
fake_score = D(x_rec)
errG = -fake_score.mean() + lambda_cyc*cycle_loss(x, x_rec)
errG.backward()
optG.step(); optE.step()
解读:
- 判别器最大化真实样本分数,最小化生成样本分数,同时加梯度惩罚;
lambda_gp=10。 - 生成器与编码器联合优化,既要"骗过"判别器,也要让重建误差低;
lambda_cyc=10。 - 训练 100 epoch 后,异常分数采用
score = lambda*recon_error + (1-lambda)*(-fake_score),在验证集上选阈值使 F1 最大。
4.7 实验结果
- 原始轴承测试集,异常召回 72%,误报 3%;引入 Transformer 判别器后,召回提升至 84%,误报降至 2.1%。
- 可视化:异常段重建误差显著高于正常段,与人工标注对齐。
5. 未来发展趋势
- 大模型+时序 GAN:类似 NDiffs,将扩散模型与 TadGAN 结合,生成更长、更细粒度序列。
- 多模态 GAN:同时生成振动+声音+红外图像,实现跨模态异常对齐。
- 自监督预训练:先在大规模无标签 IoT 数据预训练编码器,再微调下游检测任务,降低标签需求。
- 边缘计算:通过知识蒸馏把 100 M 模型压缩到 1 M,部署在 MCU 级设备,实现"本地检测、本地告警"。