TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(二)

tf.keras.losses.SparseCategoricalCrossentropy,核心是记住它的「作用」和「使用场景」,不用纠结复杂推导~

一、先明确:这个损失函数是用来干嘛的?

它的核心使命是------给模型的"分类答案"打分,告诉模型"猜得对不对、准不准",分数(损失值)越低,说明模型猜得越准。

适用场景:「单标签多分类任务」(每个样本只有一个正确答案,比如:

  • 识别图片是猫/狗/鸟(3分类);
  • 识别数字是0-9(10分类)。

二、关键特点:"稀疏标签"是什么意思?(为什么叫"稀疏")

"稀疏"是相对于"密集"(one-hot编码)来说的,核心是「标签的写法不同」:

  • 比如做"猫(0)、狗(1)、鸟(2)"3分类:
    • 「密集标签(one-hot)」:正确答案是狗,标签要写成 [0,1,0](像选择题的"答题卡",只有正确选项打勾);
    • 「稀疏标签」:正确答案是狗,标签直接写成 1(像填空题的"答案编号",直接写正确选项的序号)。

这个损失函数的第一个核心优势:不用手动把标签改成one-hot格式,直接用整数序号(0、1、2...)就行,省事儿还省内存(比如1000分类时,稀疏标签只存1个整数,one-hot要存1000个0和1)。

三、核心逻辑:它是怎么"打分"的?(不用公式!)

模型分类时,最终会输出「每个类别的"置信度"」(比如猜猫的置信度0.1、狗0.8、鸟0.1),损失函数的打分规则很简单:

规则:「正确类别的置信度越高,损失越低;置信度越低,损失越高」

举3个直观例子(3分类,正确答案是狗,标签=1):

模型输出(每个类别的置信度) 正确类别的置信度 损失值(打分结果) 模型表现
[0.1, 0.8, 0.1] 0.8(很高) 0.22(很低) 猜得准,加分!
[0.3, 0.5, 0.2] 0.5(中等) 0.69(中等) 猜得一般
[0.9, 0.05, 0.05] 0.05(很低) 2.99(很高) 猜反了,扣分!

简单说:损失函数就像一个"评委",只盯着「正确答案对应的置信度」------你越确定正确答案,得分(损失)越好;越不确定甚至猜反,得分越差。

四、关键参数:2个必须搞懂的设置(实际用的时候用得到)

1. from_logits=True/False(最关键,默认False)
  • 先搞懂「logits」:模型最后一层没经过任何处理的"原始得分"(比如 [1.0, 3.0, 0.5]),不是0-1之间的置信度;
  • 「置信度」:把logits通过「Softmax函数」转换后得到的结果(比如上面的 [0.1, 0.8, 0.1]),总和是1,符合"概率"的逻辑。

参数选择:

  • 推荐用 from_logits=True:直接把模型的原始得分(logits)传给损失函数,它内部会自己转换置信度,还能避免计算出错(比如原始得分太大时,直接算置信度会溢出);
  • from_logits=False(默认):必须确保模型输出是0-1之间的置信度(比如最后一层加了Softmax),否则会报错或计算不准。
2. reduction(损失的"汇总方式",默认不用改)

实际训练时,一次会喂给模型一批数据(比如32个样本),这个参数控制"怎么把32个样本的损失汇总成一个数":

  • 默认是 SUM_OVER_BATCH_SIZE:求所有样本损失的「平均值」(比如32个样本的损失加起来除以32),方便模型调整参数;
  • 简单理解:不用管它,默认设置就够用。

五、和常见的 CategoricalCrossentropy 怎么选?(避免用错)

两个都是多分类损失函数,核心区别就是「标签格式」,用表格一眼看明白:

损失函数 标签格式要求 适用场景 举个例子(3分类,正确答案是狗)
SparseCategoricalCrossentropy 整数稀疏标签(0/1/2...) 标签是类别序号,不想手动转one-hot 标签直接写 1
CategoricalCrossentropy one-hot密集标签([0,1,0]) 标签已经是one-hot格式 标签必须写 [0,1,0]

总结:如果你的标签是"0、1、2"这种整数,直接用 SparseCategoricalCrossentropy;如果是"[0,1,0]"这种向量,用 CategoricalCrossentropy

六、实际用的时候要注意的2个坑(避坑指南)

  1. 标签必须是「0到类别数-1」的整数:比如3分类,标签只能是0、1、2,不能是3或-1,否则会报错;
  2. 只适用于「单标签多分类」:如果每个样本有多个正确答案(比如一张图里既有猫又有狗),不能用这个,要换 BinaryCrossentropy

最后:简单代码示例(直观感受)

用最朴素的代码,看它怎么工作:

python 复制代码
import tensorflow as tf

# 1. 定义损失函数(推荐from_logits=True)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# 2. 模拟数据:模型预测的原始得分(logits)、真实标签(稀疏标签)
y_true = tf.constant([1, 0])  # 2个样本的真实标签:第1个是1(狗),第2个是0(猫)
y_pred_logits = tf.constant([[1.0, 3.0, 0.5], [5.0, 1.0, 0.1]])  # 模型输出的原始得分

# 3. 计算损失
loss = loss_fn(y_true, y_pred_logits)
print("批量损失值:", loss.numpy())  # 输出约0.15(两个样本损失的平均值,数值越小越好)

运行结果说明:模型对这两个样本的预测整体不错,损失值很低~

核心总结(记3句话就行)

  1. 用途:给「单标签多分类」模型打分,判断预测准不准;
  2. 特点:直接用整数标签(0/1/2...),不用转one-hot,省事儿;
  3. 逻辑:正确类别的置信度越高,损失越低,模型越棒。
相关推荐
十铭忘2 小时前
SAM2跟踪的理解12——mask decoder
人工智能·计算机视觉
PS1232322 小时前
隔爆型防爆压力变送器的多信号输出优势
大数据·人工智能
人工智能培训2 小时前
国内外知名大模型及应用
人工智能·深度学习·神经网络·大模型·dnn·ai大模型·具身智能
bryant_meng2 小时前
【GA-Net】《GA-Net: Guided Aggregation Net for End-to-end Stereo Matching》
人工智能·深度学习·计算机视觉·立体匹配·ganet
爱学习的张大2 小时前
如何选择正确版本的CUDA和PyTorch安装
人工智能·pytorch·python
大千AI助手2 小时前
DeepSeek V3.2 技术解读:一次不靠“堆参数”的模型升级
人工智能·机器学习·agent·dsa·deepseek·deepseek-v3.2·大千ai助手
十铭忘2 小时前
SAM2跟踪的理解13——mask decoder
人工智能·深度学习
大、男人2 小时前
FastMCP 高级特性之Background Tasks
人工智能·python·mcp·fastmcp
rayufo2 小时前
arXiv论文《Content-Aware Transformer for All-in-one Image Restoration》解读与代码实现
人工智能·深度学习·transformer