连续变量的 交叉熵 如何计算 python tensorflow

连续变量的交叉熵通常在机器学习中的回归问题中使用,但它也可以用于分类问题,当概率分布是连续的时。连续变量的交叉熵计算公式如下:

设 \( p(x) \) 是真实概率密度函数,\( q(x) \) 是预测概率密度函数,交叉熵 \( H(p, q) \) 定义为:

\[

H(p, q) = -\int_{-\infty}^{\infty} p(x) \log q(x) \, dx

\]

在分类问题中,如果我们有 \( K \) 个类别,并且 \( p_k \) 是第 \( k \) 个类别的真实概率,\( q_k \) 是第 \( k \) 个类别的预测概率,交叉熵可以表示为:

\[

H(p, q) = -\sum_{k=1}^{K} p_k \log q_k

\]

在实际应用中,如果 \( p_k \) 是一个独热编码的向量(即只有一个类别是 1,其余都是 0),上述公式简化为:

\[

H(p, q) = -\log q_y

\]

其中 \( y \) 是真实类别的索引。

在回归问题中,如果我们有一个连续的目标变量,我们可以使用均方误差(MSE)或均方对数误差(MSLE)等其他损失函数,而不是交叉熵。然而,如果我们想要使用交叉熵,我们通常需要将问题转换为类似于分类问题的形式,例如通过将连续变量离散化或使用概率分布来建模连续变量。

下面是用梯度惩罚来实现 K L 散度最小化的实现,和交叉熵原理差不多

复制代码
# 定义计算损失的函数
def compute_loss(real_data):
    # 梯度惩罚权重
    gradient_penalty_weight = gradient_penalty_weight_lamda
    # x = tf.random.normal((batch_size, n), dtype=tf.dtypes.float32)
    # x_samp = x / tf.sqrt(2 * tf.reduce_mean(tf.square(x)))
    # x_gen = tf.concat(values=[w_generator(x_samp), x_samp], axis=1)
    # x = X_train
    # x_samp = X_train_samp
    x_samp = X_train
    # todo
    # 计算损失函数的时候 ,用z_score归一化计算?
    # todo
    # x_gen = w_generator(x_samp) + x_samp
    x_gen = w_generator(x_samp)
    logits_x = w_discriminator(tf.concat([y_train, X_train], axis=-1))
    logits_x_gen = w_discriminator(tf.concat([x_gen, X_train], axis=-1))
    d_regularizer = gradient_penalty(real_data, x_gen)
    disc_loss = (tf.reduce_mean(logits_x) - tf.reduce_mean(logits_x_gen) + d_regularizer * gradient_penalty_weight)
    gen_loss = tf.reduce_mean(logits_x_gen)
    return disc_loss, gen_loss


# 定义应用生成器梯度的函数
def apply_gen_gradients(gen_gradients):
    w_gen_optimizer.apply_gradients(zip(gen_gradients, w_generator.trainable_variables))


# 定义应用判别器梯度的函数
def apply_disc_gradients(disc_gradients):
    w_disc_optimizer.apply_gradients(zip(disc_gradients, w_discriminator.trainable_variables))


# 定义梯度惩罚函数
# def gradient_penalty(x, x_gen):
#     epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)
#     x_hat = epsilon * x + (1 - epsilon) * x_gen
#     with tf.GradientTape() as t:
#         t.watch(x_hat)
#         d_hat = w_discriminator(x_hat)
#     gradients = t.gradient(d_hat, x_hat)
#     ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2]))
#     d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
#     return d_regularizer


# 定义梯度惩罚函数
def gradient_penalty(x, x_gen):
    # 创建一个与真实样本 x 的批量大小相同的随机变量 epsilon,其值在0和1之间,用于在后续步骤中进行插值。
    # epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)
    epsilon = tf.random.uniform([x.shape[0], 1], 0.0, 1.0)
    # 计算插值样本 x_hat,它是真实样本 x 和生成样本 x_gen 的线性组合。
    # 这一步是为了在真实样本和生成样本之间创建一个连续的路径
    x_hat = epsilon * x + (1 - epsilon) * x_gen
    # print("Shape before discriminator:", x_hat.shape)

    # 创建一个 tf.GradientTape 上下文,用于记录对 x_hat 的操作,以便后续计算梯度。
    with tf.GradientTape() as t:
        # 告诉 tf.GradientTape 监控 x_hat,以便可以计算关于它的梯度
        t.watch(x_hat)
        # 使用判别器 w_discriminator 对插值样本 x_hat 进行评分,得到 d_hat。
        # print(x_hat.shape)
        d_hat = w_discriminator(tf.concat([x_hat, X_train], axis=-1))
    # 计算判别器输出 d_hat 关于插值样本 x_hat 的梯度
    gradients = t.gradient(d_hat, x_hat)
    # 计算梯度的L2范数,即对每个样本的梯度向量进行平方和,然后开方,得到每个样本的梯度范数。
    # 在你的代码中,gradients 张量的形状是 [100, 4],但你尝试在 axis=[1, 2]
    # 上进行 tf.reduce_sum 操作。由于张量只有两个维度,所以没有第三个维度可以进行求和。
    # ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2]))
    ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1]))
    # 算梯度惩罚项,它是梯度范数与1的差的平方的平均值。在WGAN中,我们希望梯度范数接近1,
    # 因此这个惩罚项会惩罚那些使梯度范数远离1的判别器。
    d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
    # print("gradient_penalty")
    return d_regularizer
相关推荐
看到我,请让我去学习5 分钟前
Qt— 布局综合项目(Splitter,Stacked,Dock)
开发语言·qt
GUET_一路向前18 分钟前
【C语言防御性编程】if条件常量在前,变量在后
c语言·开发语言·if-else·防御性编程
曳渔19 分钟前
UDP/TCP套接字编程简单实战指南
java·开发语言·网络·网络协议·tcp/ip·udp
三千道应用题36 分钟前
WPF&C#超市管理系统(6)订单详情、顾客注册、商品销售排行查询和库存提示、LiveChat报表
开发语言·c#·wpf
hqxstudying1 小时前
JAVA项目中邮件发送功能
java·开发语言·python·邮件
咪咪渝粮1 小时前
JavaScript 中constructor 属性的指向异常问题
开发语言·javascript
最初的↘那颗心1 小时前
Java HashMap深度解析:原理、实现与最佳实践
java·开发语言·面试·hashmap·八股文
后台开发者Ethan2 小时前
Python需要了解的一些知识
开发语言·人工智能·python
盼小辉丶2 小时前
PyTorch生成式人工智能——使用MusicGen生成音乐
pytorch·python·深度学习·生成模型
常利兵3 小时前
Kotlin作用域函数全解:run/with/apply/let/also与this/it的魔法对决
android·开发语言·kotlin