TensorFlow 2.0 手写数字分类教程之SparseCategoricalCrossentropy 核心原理(一)

tf.keras.losses.SparseCategoricalCrossentropy 核心原理

SparseCategoricalCrossentropy(稀疏类别交叉熵)是 TensorFlow/Keras 中针对多分类任务 的损失函数,专为稀疏标签 (整数型标签,如 0,1,2)设计,核心作用是衡量模型输出的类别概率分布与真实稀疏标签的「差异」,本质是交叉熵(Cross-Entropy)在稀疏标签场景下的优化实现。

一、先理解核心背景:交叉熵的本质

交叉熵源于信息论,用于衡量两个概率分布的「距离」(差异程度)。对于多分类任务:

  • 真实标签的分布是「one-hot 分布」(比如 3 分类中标签为 1,对应分布是 [0,1,0]);
  • 模型输出是类别概率分布(经 Softmax 归一化后,和为 1,如 [0.1,0.8,0.1])。

交叉熵的公式为:
H(p,q)=−∑i=1Cp(i)log⁡(q(i)) H(p,q) = -\sum_{i=1}^C p(i) \log(q(i)) H(p,q)=−i=1∑Cp(i)log(q(i))

其中:

  • ppp:真实标签的概率分布(one-hot 形式,仅目标类别为 1,其余为 0);
  • qqq:模型预测的类别概率分布;
  • CCC:类别总数。

由于 ppp 是 one-hot 分布,交叉熵可简化为:仅取目标类别对应的预测概率的负对数 (因为其他项都是 0×log⁡(q(i))=00 \times \log(q(i))=00×log(q(i))=0)。

二、SparseCategoricalCrossentropy 的核心适配:稀疏标签

普通的 CategoricalCrossentropy 要求标签是one-hot 编码 (如 3 分类标签 1 对应 [0,1,0]),而 SparseCategoricalCrossentropy 直接支持整数型稀疏标签(如 1),无需手动 one-hot 编码,核心优势是节省内存(尤其是类别数多的场景,比如 1000 类时,稀疏标签仅存 1 个整数,one-hot 需存 1000 维向量)。

三、完整计算逻辑(分两种场景)

SparseCategoricalCrossentropy 的关键参数是 from_logits(默认 False),决定模型输出是否为「原始 logits(未归一化的得分)」或「Softmax 归一化后的概率」,两种场景的计算逻辑不同(TensorFlow 内部做了优化,避免数值不稳定)。

场景 1:from_logits=False(默认,模型输出是 Softmax 概率)

假设:

  • 类别数 C=3C=3C=3;
  • 真实稀疏标签 y=1y=1y=1(对应目标类别是第 2 类,索引从 0 开始);
  • 模型输出 Softmax 概率 q=[0.1,0.8,0.1]q=[0.1, 0.8, 0.1]q=[0.1,0.8,0.1]。

计算步骤:

  1. 取真实标签对应的概率:q(y)=q(1)=0.8q(y)=q(1)=0.8q(y)=q(1)=0.8;
  2. 计算负对数:−log⁡(q(y))=−log⁡(0.8)≈0.223-\log(q(y)) = -\log(0.8) ≈ 0.223−log(q(y))=−log(0.8)≈0.223;
  3. 最终损失值即为该结果(批量数据会取均值/求和,由 reduction 参数控制)。

公式简化为:
loss=−log⁡(q(y)) \text{loss} = -\log(q(y)) loss=−log(q(y))

场景 2:from_logits=True(模型输出是原始 logits,推荐!)

模型输出的是未经过 Softmax 归一化的原始得分(logits,如 z=[1.0,3.0,0.5]z=[1.0, 3.0, 0.5]z=[1.0,3.0,0.5]),此时 TensorFlow 不会先单独计算 Softmax(避免数值下溢/上溢),而是直接用 log_softmax 优化计算:

  1. 对 logits 计算 log_softmax:log⁡(Softmax(z))=z−log⁡(∑i=1Cezi)\log(\text{Softmax}(z)) = z - \log(\sum_{i=1}^C e^{z_i})log(Softmax(z))=z−log(∑i=1Cezi);
  2. 取真实标签对应的项,取负数即为损失:
    loss=−(zy−log⁡(∑i=1Cezi)) \text{loss} = - \left( z_y - \log(\sum_{i=1}^C e^{z_i}) \right) loss=−(zy−log(i=1∑Cezi))

示例计算(z=[1.0,3.0,0.5],y=1z=[1.0, 3.0, 0.5], y=1z=[1.0,3.0,0.5],y=1):

  • 先算 ∑ezi=e1.0+e3.0+e0.5≈2.718+20.085+1.648≈24.451\sum e^{z_i} = e^{1.0} + e^{3.0} + e^{0.5} ≈ 2.718 + 20.085 + 1.648 ≈ 24.451∑ezi=e1.0+e3.0+e0.5≈2.718+20.085+1.648≈24.451;
  • log⁡(24.451)≈3.200\log(24.451) ≈ 3.200log(24.451)≈3.200;
  • log⁡(Softmax(z))1=3.0−3.200=−0.200\log(\text{Softmax}(z))_1 = 3.0 - 3.200 = -0.200log(Softmax(z))1=3.0−3.200=−0.200;
  • 损失值:−(−0.200)=0.200-(-0.200) = 0.200−(−0.200)=0.200。

为什么推荐 from_logits=True

Softmax 对大 logits 会产生 e大值e^{大值}e大值(如 e100e^{100}e100 溢出),而 log_softmax 直接通过代数变换避免了单独计算 Softmax,提升数值稳定性。

四、批量数据的损失归约

实际训练中输入是批量数据(batch),损失会通过 reduction 参数归约(默认 AUTO,等价于 SUM_OVER_BATCH_SIZE):

  • 对每个样本计算损失值;
  • 求批量内所有样本损失的均值 (或求和,取决于 reduction)。

示例(batch_size=2):

样本 稀疏标签 模型概率 单样本损失
1 1 [0.1,0.8,0.1] 0.223
2 0 [0.9,0.05,0.05] 0.105
批量损失 = (0.223 + 0.105) / 2 ≈ 0.164。

五、关键参数解析

参数 作用 示例
from_logits 是否输入为原始 logits(非 Softmax 概率) from_logits=True(推荐)
reduction 损失归约方式: - NONE:返回每个样本的损失 - SUM:批量损失求和 - SUM_OVER_BATCH_SIZE:批量损失求均值 reduction="sum_over_batch_size"
ignore_index 忽略指定标签(计算损失时跳过),适用于样本标注缺失场景 ignore_index=-1
axis 类别维度(默认 -1,即最后一维是类别) 模型输出形状 (batch, 3) 时,axis=-1 对应 3 个类别

六、与 CategoricalCrossentropy 的对比

特性 SparseCategoricalCrossentropy CategoricalCrossentropy
标签格式 整数型稀疏标签(如 1,2,3) one-hot 编码标签(如 [0,1,0])
内存占用 低(仅存整数) 高(类别数维向量)
适用场景 类别数多、标签天然为整数(如图像分类的类别索引) 标签已做 one-hot 编码
核心公式 同交叉熵,但直接取整数标签对应项 交叉熵原始公式(遍历所有类别)

七、注意事项

  1. 标签范围 :稀疏标签必须是 [0,C−1][0, C-1][0,C−1] 范围内的整数(C 是类别数),否则会报错;
  2. 数值稳定性 :优先设置 from_logits=True,避免 Softmax 导致的数值溢出;
  3. 多标签任务 :该损失适用于「单标签多分类」(每个样本仅属于一个类别),多标签任务需用 BinaryCrossentropy

示例代码验证

python 复制代码
import tensorflow as tf

# 1. 定义损失函数(from_logits=True,模型输出logits)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# 2. 模拟批量数据(batch_size=2,类别数=3)
y_true = tf.constant([1, 0])  # 稀疏标签
y_pred_logits = tf.constant([[1.0, 3.0, 0.5], [5.0, 1.0, 0.1]])  # 模型输出logits

# 3. 计算损失
loss = loss_fn(y_true, y_pred_logits)
print("批量损失值:", loss.numpy())  # 输出约 0.15(手动计算验证)

综上,SparseCategoricalCrossentropy 本质是「多分类交叉熵」在稀疏标签下的高效实现,核心是通过直接索引整数标签避免 one-hot 编码,同时优化数值计算保证稳定性,是单标签多分类任务的首选损失函数之一。

相关推荐
小宇的天下5 小时前
HBM(高带宽内存)深度解析:先进封装视角的技术指南
网络·人工智能
rongcj5 小时前
2026,“硅基经济”的时代正在悄然来临
人工智能
狼叔也疯狂5 小时前
英语启蒙SSS绘本第一辑50册高清PDF可打印
人工智能·全文检索
万行6 小时前
机器学习&第四章支持向量机
人工智能·机器学习·支持向量机
幻云20106 小时前
Next.js之道:从入门到精通
人工智能·python
予枫的编程笔记6 小时前
【Java集合】深入浅出 Java HashMap:从链表到红黑树的“进化”之路
java·开发语言·数据结构·人工智能·链表·哈希算法
llddycidy6 小时前
峰值需求预测中的机器学习:基础、趋势和见解(最新文献)
网络·人工智能·深度学习
larance6 小时前
机器学习的一些基本知识
人工智能·机器学习
l1t6 小时前
利用DeepSeek辅助拉取GitHub存储库目录跳过特定文件方法
人工智能·github·deepseek
12344526 小时前
Agent入门实战-一个题目生成Agent
人工智能·后端