Tensorflow 中的损失函数 —— loss 专题汇总

回归和分类是监督学习中的两个大类。自学过程中,阅读别人代码时经常看到不同种类的损失函数,到底 Tensorflow 中有多少自带的损失函数呢,什么情况下使用什么样的损失函数?这次就来汇总介绍一下。

一、处理回归问题

1. tf.losses.mean_squared_error:均方根误差(MSE) ------ 回归问题中最常用的损失函数

优点是便于梯度下降,误差大时下降快,误差小时下降慢,有利于函数收敛。

缺点是受明显偏离正常范围的离群样本的影响较大

复制代码
# Tensorflow中集成的函数
mse = tf.losses.mean_squared_error(y_true, y_pred)
# 利用Tensorflow基础函数手工实现
mse = tf.reduce_mean(tf.square(y_true -  y_pred))

2. tf.losses.absolute_difference:平均绝对误差(MAE) ------ 想格外增强对离群样本的健壮性时使用

优点是其克服了 MSE 的缺点,受偏离正常范围的离群样本影响较小。

缺点是收敛速度比 MSE 慢,因为当误差大或小时其都保持同等速度下降,而且在某一点处还不可导,计算机求导比较困难。

复制代码
maes = tf.losses.absolute_difference(y_true, y_pred)
maes_loss = tf.reduce_sum(maes)

3. tf.losses.huber_loss:Huber loss ------ 集合 MSE 和 MAE 的优点,但是需要手动调超参数

核心思想是,检测真实值(y_true)和预测值(y_pred)之差的绝对值在超参数 δ 内时,使用 MSE 来计算 loss, 在 δ 外时使用类 MAE 计算 loss。sklearn 关于 huber 回归的文档中建议将 δ=1.35 以达到 95% 的有效性。

复制代码
hubers = tf.losses.huber_loss(y_true, y_pred)
hubers_loss = tf.reduce_sum(hubers)

二、处理分类问题

1. tf.nn.sigmoid_cross_entropy_with_logits:先 sigmoid 再求交叉熵 ------ 二分类问题首选

使用时,一定不要将预测值(y_pred)进行 sigmoid 处理,否则会影响训练的准确性,因为函数内部已经包含了 sigmoid 激活(若已先行 sigmoid 处理过了,则 tensorflow 提供了另外的函数) 。真实值(y_true)则要求是 One-hot 编码形式。

函数求得的结果是一组向量,是每个维度单独的交叉熵,如果想求总的交叉熵,使用 tf.reduce_sum() 相加即可;如果想求 loss ,则使用 tf.reduce_mean() 进行平均。

复制代码
# Tensorflow中集成的函数
sigmoids = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=y_pred)
sigmoids_loss = tf.reduce_mean(sigmoids)

# 利用Tensorflow基础函数手工实现
y_pred_si = 1.0/(1+tf.exp(-y_pred))
sigmoids = -y_true*tf.log(y_pred_si) - (1-y_true)*tf.log(1-y_pred_si)
sigmoids_loss = tf.reduce_mean(sigmoids)

2. tf.losses.log_loss:交叉熵 ------ 效果同上,预测值格式略有不同

预测值(y_pred)计算完成后,若已先行进行了 sigmoid 处理,则使用此函数求 loss ,若还没经过 sigmoid 处理,可直接使用 sigmoid_cross_entropy_with_logits。

复制代码
# Tensorflow中集成的函数
logs = tf.losses.log_loss(labels=y, logits=y_pred)
logs_loss = tf.reduce_mean(logs)

# 利用Tensorflow基础函数手工实现
logs = -y_true*tf.log(y_pred) - (1-y_true)*tf.log(1-y_pred)
logs_loss = tf.reduce_mean(logs)

3. tf.nn.softmax_cross_entropy_with_logits_v2:先 softmax 再求交叉熵 ------ 多分类问题首选

使用时,预测值(y_pred)同样是没有经过 softmax 处理过的值,真实值(y_true)要求是 One-hot 编码形式。

复制代码
softmaxs = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=y_pred)
softmaxs_loss = tf.reduce_mean(softmaxs)
v1.8之前为 tf.nn.softmax_cross_entropy_with_logits(),新函数修补了旧函数的不足,两者在使用方法上是一样的。

4. tf.nn.sparse_softmax_cross_entropy_with_logits:效果同上,真实值格式略有不同

若真实值(y_true)不是 One-hot 格式的,可以使用此函数,可省略一步转换

复制代码
softmaxs_sparse = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=y_pred)
softmaxs_sparse_loss = tf.reduce_mean(softmaxs_sparse)

5. tf.nn.weighted_cross_entropy_with_logits:带权重的 sigmoid 交叉熵 ------ 适用于正、负样本数量差距过大时

增加了一个权重的系数,用来平衡正、负样本差距,可在一定程度上解决差距过大时训练结果严重偏向大样本的情况。

复制代码
# Tensorflow中集成的函数
sigmoids_weighted = tf.nn.weighted_cross_entropy_with_logits(targets=y, logits=y_pred, pos_weight)
sigmoids_weighted_loss = tf.reduce_mean(sigmoids_weighted)

# 利用Tensorflow基础函数手工实现
sigmoids_weighted = -y_true*tf.log(y_pred) * weight - (1-y_true)*tf.log(1-y_pred)
sigmoids_loss = tf.reduce_mean(sigmoids)

6. tf.losses.hinge_loss:铰链损失函数 ------ SVM 中使用

hing_loss 是为了求出不同类别间的"最大间隔",此特性尤其适用于 SVM(支持向量机)。使用 SVM 做分类,与 LR(Logistic Regression 对数几率回归)相比,其优点是小样本量便有不错效果、对噪点包容性强,缺点是样本量大时效率低、有时很难找到合适的区分方法。

复制代码
hings = tf.losses.hinge_loss(labels=y, logits=y_pred, weights)
hings_loss = tf.reduce_mean(hings)

三、自定义损失函数

标准的损失函数并不合适所有场景,有些实际的背景需要采用自己构造的损失函数,

复制代码
Tensorflow 也提供了丰富的基础函数供自行构建。
例如下面的例子:当预测值(y_pred)比真实值(y_true)大时,使用 (y_pred-y_true)*loss_more 作为 loss,反之,使用 (y_true-y_pred)*loss_less

loss = tf.reduce_sum(tf.where(tf.greater(y_pred, y_true), (y_pred-y_true)*loss_more,(y_true-y_pred)*loss_less))
tf.greater(x, y):判断 x 是否大于 y,当维度不一致时广播后比较
tf.where(condition, x, y):当 condition 为 true 时返回 x,否则返回 y
tf.reduce_mean():沿维度求平均
tf.reduce_sum():沿维度相加
tf.reduce_prod():沿维度相乘
tf.reduce_min():沿维度找最小
tf.reduce_max():沿维度找最大
使用 Tensorflow 提供的方法可自行构造想要的损失函数。
相关推荐
程序员水自流1 分钟前
【AI大模型第13集】Transformer底层架构原理详细介绍(核心组件拆解分析)
java·人工智能·架构·llm·transformer
code_pgf2 分钟前
openclaw配置高德导航、京东商品搜索、QQ 音乐播放控制
人工智能·gateway·边缘计算
IT观测3 分钟前
品牌在AI中的影响力如何评估?2026年AI营销工具实战选型指南
大数据·人工智能
ai_xiaogui4 分钟前
PanelAI前端全面升级!私有化部署AI面板控制台+生态市场一键管理详解
前端·人工智能·comfyui一键部署·生态市场算力共享·ai面板控制台·panelai私有化部署·大模型前端管理
海水冷却7 分钟前
RTC成语音AI基础设施:AWS和ElevenLabs相继跟进,ZEGO已跑三年
人工智能·实时音视频·aws
QC·Rex8 分钟前
国产大模型应用实践:从 0 到 1 搭建企业级 AI 助手
人工智能·langchain·大语言模型·rag·企业应用·ai 助手
墨染天姬9 分钟前
【AI】ollama和vLLM怎么选
人工智能
源码学社11 分钟前
DeerFlow 2.0:字节跳动开源的超级智能体框架,让AI真正“干活”
人工智能·开源
xingyuzhisuan15 分钟前
租用GPU服务器后,快速搭建Stable Diffusion WebUI并实现公网访问全指南
服务器·人工智能·云计算·gpu算力
fengfuyao98519 分钟前
STM32智能桌面宠物-AI机器狗设计与实现
人工智能·stm32·宠物