TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(二)

tf.keras.losses.SparseCategoricalCrossentropy,核心是记住它的「作用」和「使用场景」,不用纠结复杂推导~

一、先明确:这个损失函数是用来干嘛的?

它的核心使命是------给模型的"分类答案"打分,告诉模型"猜得对不对、准不准",分数(损失值)越低,说明模型猜得越准。

适用场景:「单标签多分类任务」(每个样本只有一个正确答案,比如:

  • 识别图片是猫/狗/鸟(3分类);
  • 识别数字是0-9(10分类)。

二、关键特点:"稀疏标签"是什么意思?(为什么叫"稀疏")

"稀疏"是相对于"密集"(one-hot编码)来说的,核心是「标签的写法不同」:

  • 比如做"猫(0)、狗(1)、鸟(2)"3分类:
    • 「密集标签(one-hot)」:正确答案是狗,标签要写成 [0,1,0](像选择题的"答题卡",只有正确选项打勾);
    • 「稀疏标签」:正确答案是狗,标签直接写成 1(像填空题的"答案编号",直接写正确选项的序号)。

这个损失函数的第一个核心优势:不用手动把标签改成one-hot格式,直接用整数序号(0、1、2...)就行,省事儿还省内存(比如1000分类时,稀疏标签只存1个整数,one-hot要存1000个0和1)。

三、核心逻辑:它是怎么"打分"的?(不用公式!)

模型分类时,最终会输出「每个类别的"置信度"」(比如猜猫的置信度0.1、狗0.8、鸟0.1),损失函数的打分规则很简单:

规则:「正确类别的置信度越高,损失越低;置信度越低,损失越高」

举3个直观例子(3分类,正确答案是狗,标签=1):

模型输出(每个类别的置信度) 正确类别的置信度 损失值(打分结果) 模型表现
0.1, 0.8, 0.1 0.8(很高) 0.22(很低) 猜得准,加分!
0.3, 0.5, 0.2 0.5(中等) 0.69(中等) 猜得一般
0.9, 0.05, 0.05 0.05(很低) 2.99(很高) 猜反了,扣分!

简单说:损失函数就像一个"评委",只盯着「正确答案对应的置信度」------你越确定正确答案,得分(损失)越好;越不确定甚至猜反,得分越差。

四、关键参数:2个必须搞懂的设置(实际用的时候用得到)

1. from_logits=True/False(最关键,默认False)
  • 先搞懂「logits」:模型最后一层没经过任何处理的"原始得分"(比如 [1.0, 3.0, 0.5]),不是0-1之间的置信度;
  • 「置信度」:把logits通过「Softmax函数」转换后得到的结果(比如上面的 [0.1, 0.8, 0.1]),总和是1,符合"概率"的逻辑。

参数选择:

  • 推荐用 from_logits=True:直接把模型的原始得分(logits)传给损失函数,它内部会自己转换置信度,还能避免计算出错(比如原始得分太大时,直接算置信度会溢出);
  • from_logits=False(默认):必须确保模型输出是0-1之间的置信度(比如最后一层加了Softmax),否则会报错或计算不准。
2. reduction(损失的"汇总方式",默认不用改)

实际训练时,一次会喂给模型一批数据(比如32个样本),这个参数控制"怎么把32个样本的损失汇总成一个数":

  • 默认是 SUM_OVER_BATCH_SIZE:求所有样本损失的「平均值」(比如32个样本的损失加起来除以32),方便模型调整参数;
  • 简单理解:不用管它,默认设置就够用。

五、和常见的 CategoricalCrossentropy 怎么选?(避免用错)

两个都是多分类损失函数,核心区别就是「标签格式」,用表格一眼看明白:

损失函数 标签格式要求 适用场景 举个例子(3分类,正确答案是狗)
SparseCategoricalCrossentropy 整数稀疏标签(0/1/2...) 标签是类别序号,不想手动转one-hot 标签直接写 1
CategoricalCrossentropy one-hot密集标签(0,1,0 标签已经是one-hot格式 标签必须写 [0,1,0]

总结:如果你的标签是"0、1、2"这种整数,直接用 SparseCategoricalCrossentropy;如果是"0,1,0"这种向量,用 CategoricalCrossentropy

六、实际用的时候要注意的2个坑(避坑指南)

  1. 标签必须是「0到类别数-1」的整数:比如3分类,标签只能是0、1、2,不能是3或-1,否则会报错;
  2. 只适用于「单标签多分类」:如果每个样本有多个正确答案(比如一张图里既有猫又有狗),不能用这个,要换 BinaryCrossentropy

最后:简单代码示例(直观感受)

用最朴素的代码,看它怎么工作:

python 复制代码
import tensorflow as tf

# 1. 定义损失函数(推荐from_logits=True)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# 2. 模拟数据:模型预测的原始得分(logits)、真实标签(稀疏标签)
y_true = tf.constant([1, 0])  # 2个样本的真实标签:第1个是1(狗),第2个是0(猫)
y_pred_logits = tf.constant([[1.0, 3.0, 0.5], [5.0, 1.0, 0.1]])  # 模型输出的原始得分

# 3. 计算损失
loss = loss_fn(y_true, y_pred_logits)
print("批量损失值:", loss.numpy())  # 输出约0.15(两个样本损失的平均值,数值越小越好)

运行结果说明:模型对这两个样本的预测整体不错,损失值很低~

核心总结(记3句话就行)

  1. 用途:给「单标签多分类」模型打分,判断预测准不准;
  2. 特点:直接用整数标签(0/1/2...),不用转one-hot,省事儿;
  3. 逻辑:正确类别的置信度越高,损失越低,模型越棒。
相关推荐
战族狼魂13 小时前
AI巨头IPO热潮引爆资本市场
人工智能·chatgpt·大模型·大语言模型·ai工程化
编程令我快乐13 小时前
基于AI工具的高效文档撰写方法
人工智能
Techblog of HaoWANG13 小时前
智巡守卫:多模态巡检智能体算法服务端设计与实现——基于Ollama+Qwen3.5的自动化巡检报告生成系统
运维·人工智能·算法·目标检测·自动化·边缘计算
hsg7713 小时前
简述:读《置身钉内》后读后感
人工智能
小白不白11113 小时前
Invoke的用法
开发语言·人工智能·数码相机·计算机视觉·c#
有什么事13 小时前
AI革命:云手机从脚本到智能体的跨越
人工智能·智能手机·自动化
o561-6o623o7鹿13 小时前
路,新生鼠适配器
人工智能
2601_9594779113 小时前
Vatee:外汇行情信息呈现与技术架构如何影响体验,给出一套细节
大数据·人工智能·安全·ux
KaMeidebaby13 小时前
卡梅德生物技术快报|重组蛋白的表达和纯化:工艺调试全记录:大肠杆菌体系重组蛋白的表达和纯化参数标定(肠激酶轻链案例)
前端·人工智能·算法·数据挖掘·数据分析
kishu_iOS&AI13 小时前
LLM —— LangChain
人工智能·langchain