【机器学习实战】kaggle 欺诈检测---使用生成对抗网络(GAN)解决欺诈数据中正负样本极度不平衡问题

【机器学习实战】kaggle 欺诈检测---如何解决欺诈数据中正负样本极度不平衡问题https://blog.csdn.net/2302_79308082/article/details/145177242

本篇文章是基于上次文章中提到的对抗生成网络,通过对抗生成网络生成少数类样本,平衡欺诈数据中正类样本极少的问题。

本人主页:机器学习小小白

机器学习专栏:机器学习实战

PyTorch入门专栏:PyTorch入门

深度学习实战:深度学习

ok,话不多说,我们进入正题吧

1. 引言

生成对抗网络(Generative Adversarial Networks,简称GAN)是由Ian Goodfellow等人于2014年提出的一种深度学习模型。它在计算机视觉、自然语言处理、音频生成等领域得到了广泛应用。GAN的核心思想是通过两个神经网络之间的博弈关系来生成新的、仿真的数据。自从GAN提出以来,它已经成为生成模型领域的突破性进展,深刻改变了生成式模型的研究和应用。

2. GAN的基本原理

生成对抗网络的结构包括两个主要部分:生成器(Generator)和判别器(Discriminator)。这两个网络分别充当"对手",并在训练过程中互相博弈:

  • 生成器(Generator):该网络的目的是通过学习数据分布来生成尽可能接近真实数据的虚假样本。生成器从一个随机的噪声(通常是高维的向量)出发,逐步生成样本。

  • 判别器(Discriminator):该网络的任务是判断一个样本是真实的(来自训练数据)还是虚假的(来自生成器)。判别器输出一个概率值,表示输入样本为真实数据的概率。

3. GAN的训练过程

GAN的训练过程是一个"博弈"过程,生成器和判别器不断互相对抗,从而提升各自的性能。这个过程可以通过以下的数学公式来表示:

  • 判别器的目标:判别器的目标是最大化其对于真实数据的判断概率(即预测为1的概率),同时最小化对生成数据的错误分类(即预测为0的概率)。可以通过以下的交叉熵损失函数表示:

其中:

  • ​ 是从真实数据分布中采样的数据。

  • 是生成器生成的样本,是从潜在空间中采样的噪声。

  • 是判别器对样本 的判别输出,表示其为真实数据的概率。

  • 生成器的目标:生成器的目标是使判别器无法区分生成数据与真实数据,因此它通过最大化判别器对生成数据为真实的概率来进行训练:

  • 其中: 是生成器生成的虚假样本,是判别器对生成样本的输出,表示其为真实数据的概率。

在训练过程中,生成器和判别器会交替优化这两个损失函数。理想的结果是生成器能够生成与真实数据分布相似的样本,而判别器则无法有效地区分生成数据与真实数据。

4. GAN的应用

GAN具有强大的生成能力,广泛应用于多个领域,以下是一些典型的应用场景:

  • 图像生成:GAN可以用于生成高度逼真的图像,如人脸、风景或艺术作品。典型的例子包括DeepArt和StyleGAN,后者能够生成几乎无法与真实人脸区分的图像。

  • 图像到图像的转换:例如,利用GAN进行图像风格转换(如将照片转化为油画风格)、超分辨率重建(如提高图像的分辨率)、图像修复(如填补丢失部分)等任务。

  • 文本生成:结合自然语言处理技术,GAN也可用于生成文本数据,如诗歌、故事生成等,尤其是文本生成和对话系统中的对抗训练。

  • 音频生成:GAN被广泛应用于音频生成,如音乐生成、语音合成等。

  • 数据增强:GAN可以用于数据增强,特别是在医疗图像领域,生成具有一定变异的图像样本,以增强训练数据集。

  • 模型训练中的对抗样本生成:GAN可以生成对抗样本,即通过对训练数据进行微小扰动,生成能够误导模型的样本,这对提升模型的鲁棒性非常重要。

5. GAN的变种

GAN作为一种框架,已经发展出了多种变种,以满足不同应用的需求。以下是几种常见的GAN变种:

  • CGAN(Conditional GAN):在生成器和判别器中都加入了条件变量,使得生成的样本可以根据某些条件(如标签信息)进行控制。

  • WGAN(Wasserstein GAN):解决了传统GAN在训练过程中可能出现的梯度消失和模式崩溃问题。WGAN使用了Wasserstein距离作为生成器和判别器的损失函数。

  • DCGAN(Deep Convolutional GAN):使用卷积神经网络(CNN)来构建生成器和判别器,增强了GAN在图像生成任务中的表现。

  • CycleGAN:用于无监督学习场景,特别是在图像到图像的转换中,例如将一张照片转换成另一种风格(如马到斑马转换)。

6. 使用生成对抗网络(GAN)生成欺诈数据中少数类数据

1. 数据预处理与特征提取

python 复制代码
import pandas as pd
import numpy as np

train_df = pd.read_csv('/kaggle/input/credit-card-fraud-prediction/train.csv')
test_df = pd.read_csv('/kaggle/input/credit-card-fraud-prediction/test.csv')

def time_feature(df):
    df['Time'] = pd.to_datetime(df['Time'], unit='s')  # 将时间戳转为 datetime 格式
    # 提取时间特征
    df['hour'] = df['Time'].dt.hour
    df['minute'] = df['Time'].dt.minute 
    return df 

train_df = time_feature(train_df)
test_df = time_feature(test_df)

在欺诈检测任务中,时间特征(如交易发生的小时和分钟)通常是重要的,因为欺诈交易往往具有不同的时间模式。例如,欺诈交易可能集中在某些特定的时间段。

  • 这里我们通过pd.to_datetime()Time列从Unix时间戳格式转换为日期时间格式。然后,我们提取了小时和分钟作为新的特征,用于训练模型。
python 复制代码
train_feature = train_df.drop(columns=['id','IsFraud','Time'])
test_feature = test_df.drop(columns=['id','Time'])

label = train_df['IsFraud']

train_feature 是用于训练的特征数据,删除了 id, IsFraudTime 列。IsFraud 是标签列,表示交易是否为欺诈交易;而 idTime 列不包含有用的特征信息,因此可以去掉。

2. 标准化数据

python 复制代码
from sklearn.preprocessing import StandardScaler

# 标准化特征数据
scaler = StandardScaler()
train_feature_scaled = scaler.fit_transform(train_feature)
  • 标准化(Standardization)是机器学习中常用的预处理步骤。它通过减去均值并除以标准差,使特征数据具有零均值和单位方差。标准化能够加速模型的收敛过程,尤其是在使用像神经网络这样的梯度优化模型时。

  • 这里使用 StandardScaler 来对训练数据进行标准化,以确保所有特征在同一个量级。

3. 生成器与判别器的构建

生成器(Generator)

python 复制代码
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Input

def build_generator(latent_dim, input_dim):
    model = Sequential()
    model.add(Input(shape=(latent_dim,)))  # 使用 Input 层来指定输入维度
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(input_dim, activation='tanh'))  # 输出层与原数据同维度
    return model

生成器(Generator)是GAN的核心部分,它通过接收随机噪声向量(潜在空间中的点),然后经过一系列的全连接层和激活函数,生成与原始数据分布相似的虚假数据。

  • 在此,我们使用了 LeakyReLU 激活函数,它允许梯度通过负半轴流动,解决了传统ReLU可能出现的"死神经元"问题。BatchNormalization 用于加速网络的训练,并帮助改善模型的稳定性。

判别器(Discriminator)

python 复制代码
def build_discriminator(input_dim):
    model = Sequential()
    model.add(Input(shape=(input_dim,)))  # 使用 Input 层来指定输入维度
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(Dense(1, activation='sigmoid'))  # 输出真假判定
    return model

判别器(Discriminator)的任务是判断输入数据是真实的还是由生成器生成的。它是一个二分类模型,输出是一个概率值,表示输入数据为真实的概率。

  • 这里使用 sigmoid 激活函数,输出一个概率值。判别器学习将真实数据和生成数据区分开来。

4. GAN模型的组合与训练

python 复制代码
def build_gan(generator, discriminator):
    discriminator.trainable = False  # 在训练GAN时冻结判别器
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model

# 定义优化器
optimizer = Adam()

# 定义输入维度和潜在维度
latent_dim = 100  # 随机噪声的维度
input_dim = 31  # 输入数据的维度,例如欺诈检测数据的特征数

# 创建并编译模型
generator = build_generator(latent_dim, input_dim)
discriminator = build_discriminator(input_dim)
gan = build_gan(generator, discriminator)

# 编译判别器和GAN模型
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
  • 生成对抗训练(Adversarial Training)是GAN的关键。生成器和判别器在一个博弈过程中互相优化。在训练过程中,生成器通过"欺骗"判别器来优化其生成数据的能力,而判别器则不断学习区分真实和生成数据。

  • 在训练过程中,我们冻结判别器的参数,只训练生成器,这样可以避免在训练生成器时更新判别器的权重。

5. GAN训练函数

python 复制代码
def train_gan(generator, discriminator, gan, fraud_data_scaled, epochs=10000, batch_size=64):
    valid = np.ones((batch_size, 1))  # 真数据标签
    fake = np.zeros((batch_size, 1))  # 假数据标签

    for epoch in range(epochs):
        # 随机选择真实欺诈数据
        idx = np.random.randint(0, fraud_data_scaled.shape[0], batch_size)
        real_data = fraud_data_scaled[idx]

        # 生成虚拟数据
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        generated_data = generator.predict(noise)

        # 训练判别器
        d_loss_real = discriminator.train_on_batch(real_data, valid)
        d_loss_fake = discriminator.train_on_batch(generated_data, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, valid)

        # 输出训练过程的损失
        if epoch % 1000 == 0:
            print(f'{epoch}/{epochs} [D loss: {d_loss[0]}] [G loss: {g_loss}]')
  • 训练过程:在每个训练周期中,首先更新判别器的权重(通过训练它区分真实数据和生成数据),然后训练生成器(通过训练它欺骗判别器)。

  • 损失函数 :我们使用了 binary_crossentropy 损失函数,它用于二分类任务。在判别器的训练中,我们分别计算真实数据和生成数据的损失,然后平均得到判别器的总损失。生成器的损失则是通过GAN模型进行计算的。

6. 生成虚拟数据

python 复制代码
def generate_fake_data(generator, num_samples):
    noise = np.random.normal(0, 1, (num_samples, latent_dim))  # 随机噪声
    generated_data = generator.predict(noise)  # 生成虚拟数据
    # 将生成的数据转换回原始空间
    generated_data_original = scaler.inverse_transform(generated_data)

    # 获取原始负样本数据的列名(去除 'id', 'IsFraud', 'Time' 列)
    feature_columns = [col for col in train_df.columns if col not in ['id', 'IsFraud', 'Time']]

    # 将生成的数据与原始负样本数据(即非欺诈数据)结合,作为新的训练数据
    augmented_data = np.concatenate([train_df[train_df['IsFraud'] == 0].drop(columns=['id', 'IsFraud', 'Time']),
                                     generated_data_original], axis=0)

    augmented_label = np.concatenate([np.zeros(train_df[train_df['IsFraud'] == 0].shape[0]), 
                                     np.ones(generated_data_original.shape[0])], axis=0)

    # 创建包含生成数据和标签的 DataFrame
    augmented_df = pd.DataFrame(augmented_data, columns=feature_columns)
    augmented_df['IsFraud'] = augmented_label

    return augmented_df

在这个函数中,我们使用训练好的生成器来生成新的虚拟欺诈数据,并将它们与真实的非欺诈数据结合,以增强数据集。然后,我们通过逆标准化将生成的数据转换回原始数据空间。

本次例子为了缩短训练时间,只生成了100条虚拟的正样本数据。

相关推荐
深图智能2 分钟前
PyTorch使用教程(4)-torch.nn
人工智能·pytorch·深度学习
smartcat20108 分钟前
Lora理解&QLoRA
深度学习
Channing Lewis13 分钟前
Python 3.9及以上版本支持的新的字符串函数 str.removeprefix()
服务器·python
yuanbenshidiaos15 分钟前
【大数据】机器学习----------集成学习
大数据·机器学习·集成学习
伊一大数据&人工智能学习日志33 分钟前
机器学习经典无监督算法——聚类K-Means算法
人工智能·算法·机器学习
唐BiuBiu1 小时前
python如何解析word文件格式(.docx)
python·word
测试秃头怪1 小时前
银行测试:第三方支付平台业务流,功能/性能/安全测试方法
自动化测试·软件测试·python·功能测试·测试工具·测试用例·安全性测试
游王子1 小时前
机器学习(3):逻辑回归
人工智能·机器学习·逻辑回归
laopeng3012 小时前
4.Spring AI Prompt:与大模型进行有效沟通
人工智能·spring·prompt
dwjf3212 小时前
神经网络基础-正则化方法
人工智能·深度学习·神经网络