连续变量的 交叉熵 如何计算 python tensorflow

连续变量的交叉熵通常在机器学习中的回归问题中使用,但它也可以用于分类问题,当概率分布是连续的时。连续变量的交叉熵计算公式如下:

设 \( p(x) \) 是真实概率密度函数,\( q(x) \) 是预测概率密度函数,交叉熵 \( H(p, q) \) 定义为:

\[

H(p, q) = -\int_{-\infty}^{\infty} p(x) \log q(x) \, dx

\]

在分类问题中,如果我们有 \( K \) 个类别,并且 \( p_k \) 是第 \( k \) 个类别的真实概率,\( q_k \) 是第 \( k \) 个类别的预测概率,交叉熵可以表示为:

\[

H(p, q) = -\sum_{k=1}^{K} p_k \log q_k

\]

在实际应用中,如果 \( p_k \) 是一个独热编码的向量(即只有一个类别是 1,其余都是 0),上述公式简化为:

\[

H(p, q) = -\log q_y

\]

其中 \( y \) 是真实类别的索引。

在回归问题中,如果我们有一个连续的目标变量,我们可以使用均方误差(MSE)或均方对数误差(MSLE)等其他损失函数,而不是交叉熵。然而,如果我们想要使用交叉熵,我们通常需要将问题转换为类似于分类问题的形式,例如通过将连续变量离散化或使用概率分布来建模连续变量。

下面是用梯度惩罚来实现 K L 散度最小化的实现,和交叉熵原理差不多

# 定义计算损失的函数
def compute_loss(real_data):
    # 梯度惩罚权重
    gradient_penalty_weight = gradient_penalty_weight_lamda
    # x = tf.random.normal((batch_size, n), dtype=tf.dtypes.float32)
    # x_samp = x / tf.sqrt(2 * tf.reduce_mean(tf.square(x)))
    # x_gen = tf.concat(values=[w_generator(x_samp), x_samp], axis=1)
    # x = X_train
    # x_samp = X_train_samp
    x_samp = X_train
    # todo
    # 计算损失函数的时候 ,用z_score归一化计算?
    # todo
    # x_gen = w_generator(x_samp) + x_samp
    x_gen = w_generator(x_samp)
    logits_x = w_discriminator(tf.concat([y_train, X_train], axis=-1))
    logits_x_gen = w_discriminator(tf.concat([x_gen, X_train], axis=-1))
    d_regularizer = gradient_penalty(real_data, x_gen)
    disc_loss = (tf.reduce_mean(logits_x) - tf.reduce_mean(logits_x_gen) + d_regularizer * gradient_penalty_weight)
    gen_loss = tf.reduce_mean(logits_x_gen)
    return disc_loss, gen_loss


# 定义应用生成器梯度的函数
def apply_gen_gradients(gen_gradients):
    w_gen_optimizer.apply_gradients(zip(gen_gradients, w_generator.trainable_variables))


# 定义应用判别器梯度的函数
def apply_disc_gradients(disc_gradients):
    w_disc_optimizer.apply_gradients(zip(disc_gradients, w_discriminator.trainable_variables))


# 定义梯度惩罚函数
# def gradient_penalty(x, x_gen):
#     epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)
#     x_hat = epsilon * x + (1 - epsilon) * x_gen
#     with tf.GradientTape() as t:
#         t.watch(x_hat)
#         d_hat = w_discriminator(x_hat)
#     gradients = t.gradient(d_hat, x_hat)
#     ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2]))
#     d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
#     return d_regularizer


# 定义梯度惩罚函数
def gradient_penalty(x, x_gen):
    # 创建一个与真实样本 x 的批量大小相同的随机变量 epsilon,其值在0和1之间,用于在后续步骤中进行插值。
    # epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)
    epsilon = tf.random.uniform([x.shape[0], 1], 0.0, 1.0)
    # 计算插值样本 x_hat,它是真实样本 x 和生成样本 x_gen 的线性组合。
    # 这一步是为了在真实样本和生成样本之间创建一个连续的路径
    x_hat = epsilon * x + (1 - epsilon) * x_gen
    # print("Shape before discriminator:", x_hat.shape)

    # 创建一个 tf.GradientTape 上下文,用于记录对 x_hat 的操作,以便后续计算梯度。
    with tf.GradientTape() as t:
        # 告诉 tf.GradientTape 监控 x_hat,以便可以计算关于它的梯度
        t.watch(x_hat)
        # 使用判别器 w_discriminator 对插值样本 x_hat 进行评分,得到 d_hat。
        # print(x_hat.shape)
        d_hat = w_discriminator(tf.concat([x_hat, X_train], axis=-1))
    # 计算判别器输出 d_hat 关于插值样本 x_hat 的梯度
    gradients = t.gradient(d_hat, x_hat)
    # 计算梯度的L2范数,即对每个样本的梯度向量进行平方和,然后开方,得到每个样本的梯度范数。
    # 在你的代码中,gradients 张量的形状是 [100, 4],但你尝试在 axis=[1, 2]
    # 上进行 tf.reduce_sum 操作。由于张量只有两个维度,所以没有第三个维度可以进行求和。
    # ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2]))
    ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1]))
    # 算梯度惩罚项,它是梯度范数与1的差的平方的平均值。在WGAN中,我们希望梯度范数接近1,
    # 因此这个惩罚项会惩罚那些使梯度范数远离1的判别器。
    d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
    # print("gradient_penalty")
    return d_regularizer
相关推荐
sukalot1 分钟前
windows C#-泛型接口
开发语言·c#
weixin_749949903 分钟前
双向列表的实现(C++)
开发语言·c++·链表
猿饵块16 分钟前
python--main--入口函数
开发语言·python
xianwu54318 分钟前
反向代理模块开发,
linux·开发语言·网络·c++·git
C++小厨神25 分钟前
SQL语言的数据库交互
开发语言·后端·golang
吴冰_hogan37 分钟前
Java 线程池 ThreadPoolExecutor 底层原理与源码分析
java·开发语言
摇光931 小时前
js状态模式
开发语言·javascript·状态模式
lzb_kkk1 小时前
【C++】JsonCpp库
开发语言·c++·json·1024程序员节
码力全開1 小时前
C 语言奇幻之旅 - 第16篇:C 语言项目实战
c语言·开发语言·数据库·windows·vscode·vim·visual studio
salsm2 小时前
使用 C++ 实现神经网络:从基础到高级优化
开发语言·c++·神经网络