【分类】【损失函数】处理类别不平衡:CEFL 和 CEFL2 损失函数的实现与应用

引言

在深度学习中的分类问题中,类别不平衡问题是常见的挑战之一。尤其在面部表情分类任务中,不同表情类别的样本数量可能差异较大,比如"开心"表情的样本远远多于"生气"表情。面对这种情况,普通的交叉熵损失函数容易导致模型过拟合到大类样本,忽略少数类样本。为了有效解决类别不平衡问题,Class-balanced Exponential Focal Loss (CEFL)Class-balanced Exponential Focal Loss 2 (CEFL2) 损失函数应运而生。

本文将详细介绍CEFLCEFL2损失函数,阐述它们在面部表情分类任务中的应用,并提供PyTorch实现代码,带有详细注释,适合开发者在实际项目中使用。

目录

  • 引言
  • [一、CEFL 和 CEFL2 损失函数概述](#一、CEFL 和 CEFL2 损失函数概述)
    • [1.1 Focal Loss 的背景](#1.1 Focal Loss 的背景)
    • [1.2 CEFL 的定义](#1.2 CEFL 的定义)
    • [1.3 CEFL2 的扩展与改进](#1.3 CEFL2 的扩展与改进)
    • [1.4 对比 CEFL 和 CEFL2](#1.4 对比 CEFL 和 CEFL2)
  • 二、面部表情分类中的类别不平衡问题
    • [2.1 类别不平衡对模型训练的影响](#2.1 类别不平衡对模型训练的影响)
    • [2.2 解决策略](#2.2 解决策略)
  • [三、如何使用 CEFL 和 CEFL2 损失函数](#三、如何使用 CEFL 和 CEFL2 损失函数)
    • [3.1 CEFL 和 CEFL2 损失函数的核心公式](#3.1 CEFL 和 CEFL2 损失函数的核心公式)
    • [3.2 类别频率的计算与应用](#3.2 类别频率的计算与应用)
    • [3.3 计算类别频率的代码示例](#3.3 计算类别频率的代码示例)
  • [四、PyTorch 实现](#四、PyTorch 实现)
    • [4.1 CEFL 实现](#4.1 CEFL 实现)
    • [4.2 CEFL2 实现](#4.2 CEFL2 实现)
    • [4.3 训练过程](#4.3 训练过程)
  • 总结
  • 参考文献

一、CEFL 和 CEFL2 损失函数概述

1.1 Focal Loss 的背景

在传统的分类任务中,交叉熵损失(Cross-Entropy Loss)常常用作优化目标。然而,交叉熵损失函数并没有很好地解决类别不平衡问题,特别是在少数类样本较少时。Focal Loss (焦点损失)由 Lin et al. (2017) 提出,主要用于解决 类别不平衡 问题,旨在通过减小容易分类样本的损失权重,增强模型对困难样本的关注。

Focal Loss 引入了一个调节因子 ( 1 − p t ) γ (1 - p_t)^\gamma (1−pt)γ,通过减小容易分类样本的损失,聚焦模型训练中的难分类样本,从而引导模型更加关注难以分类的样本,尤其在类别不平衡的情形下,避免多数类样本主导训练。其公式如下:

F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t) FL(pt)=−αt(1−pt)γlog(pt)

其中:

  • p t p_t pt 是预测类别的概率。
  • γ \gamma γ 调节因子,是一个常量超参数,通常设为大于 0 的值(例如 2),用于控制易分类样本的惩罚程度。较大的 γ \gamma γ 会增加对难分类样本的关注。当 γ \gamma γ=0时 ,焦点损失在形式上等价于交叉熵损失。
  • α t \alpha_t αt 是一个用于平衡类别不平衡的权重因子,可以根据每个类别的频率进行调整。

1.2 CEFL 的定义

Class-balanced Exponential Focal Loss (CEFL) 是在 Focal Loss 基础上的进一步改进。它通过在焦点损失中引入类别平衡策略,赋予每个类别不同的权重,从而有效地应对类别不平衡问题。

通过引入类别平衡策略来处理类别不平衡问题。与 Focal Loss 相比,CEFL 会根据每个类别的频率赋予不同的权重,从而调整损失函数,特别是在类别不平衡的情况下更加有效。

CEFL 的公式如下:

CEFL ( p t ) = − ( 1 − p t ) log ⁡ ( p t ) − p t ( 1 − p t ) γ log ⁡ ( p t ) \text{CEFL}(p_t) = -(1 - p_t) \log(p_t) - p_t (1 - p_t)^\gamma \log(p_t) CEFL(pt)=−(1−pt)log(pt)−pt(1−pt)γlog(pt)

其中:

  • p t p_t pt:表示样本属于正确类别的预测概率。
  • γ \gamma γ:( γ \gamma γ>0)焦点损失的调节因子,通常设置为 2,用于放大难以分类的样本的损失,使得模型更加关注困难的样本。注意当 γ \gamma γ=0时 ,CEFL损失在形式上是交叉熵损失。

公式的第一项是传统的交叉熵损失,第二项则是引入焦点损失后的部分,用来减小易分类样本的影响权重,使得困难样本对总损失的贡献更大,从而模型更加专注于难分类的样本。特别地,第二项通过 ( (1 - p_t)\^\\gamma ) 调节了模型对不同难度样本的关注程度。

1.3 CEFL2 的扩展与改进

CEFL2 是对 CEFL 损失函数的扩展,它进一步考虑了类别的频率信息,通过精细的调整每个类别的损失权重,使得模型在极度不平衡的数据集上表现更好。CEFL2 引入了类别频率(class frequency)作为权重,使用每个类别在数据集中出现的频率来调整每个类别的影响。

CEFL2 的公式为:
CEFL2 ( p t ) = − ( 1 − p t ) 2 ( 1 − p t ) 2 + p t 2 log ⁡ ( p t ) − p t 2 ( 1 − p t ) 2 + p t 2 ( 1 − p t ) γ log ⁡ ( p t ) \text{CEFL2}(p_t) = -\frac{(1 - p_t)^2}{(1 - p_t)^2 + p_t^2} \log(p_t) - \frac{p_t^2}{(1 - p_t)^2 + p_t^2} (1 - p_t)^\gamma \log(p_t) CEFL2(pt)=−(1−pt)2+pt2(1−pt)2log(pt)−(1−pt)2+pt2pt2(1−pt)γlog(pt)

其中:

  • 第一个项和第二个项分别对应于不同类别的损失权重和焦点损失的加权贡献。
  • ( 1 − p t ) 2 ( 1 − p t ) 2 + p t 2 \frac{(1 - p_t)^2}{(1 - p_t)^2 + p_t^2} (1−pt)2+pt2(1−pt)2 和 p t 2 ( 1 − p t ) 2 + p t 2 \frac{p_t^2}{(1 - p_t)^2 + p_t^2} (1−pt)2+pt2pt2是根据类别的频率对损失进行调整的权重项。具体来说,它们的比例反映了每个类别相对于整个数据集的频率。

该损失函数通过动态调整类别的权重,使得模型对少数类样本的损失更加敏感,从而提升对少数类的识别能力。

1.4 对比 CEFL 和 CEFL2

特性 CEFL CEFL2
核心思想 结合焦点损失和类别平衡 引入类别频率,进一步优化类别平衡
类别权重 通过 α t \alpha_t αt 设置权重 通过类别频率动态调整权重
适用场景 通用的类别不平衡问题 极度不平衡的类别问题
主要优点 简单有效,适合一般类别不平衡问题 更适用于处理极端类别不平衡的数据

二、面部表情分类中的类别不平衡问题

2.1 类别不平衡对模型训练的影响

在面部表情分类任务中,可能会出现不同表情类别样本不平衡的情况。例如,常见表情如"开心"或"惊讶"在数据集中占有大量样本,而"生气"或"害怕"等情绪类别可能样本较少。这种类别不平衡将导致模型偏向于大类表情,忽视少数类表情,从而影响分类性能,尤其是对少数类样本的识别。

影响

  • 模型可能会对大类表情有较高的分类准确率,而忽视少数类表情。
  • 少数类表情样本的训练效果较差,难以学到有效的特征表示。

2.2 解决策略

使用 CEFLCEFL2 损失函数可以有效缓解类别不平衡问题,在训练过程中让模型更多关注少数类样本,从而提升少数类样本的分类效果。

三、如何使用 CEFL 和 CEFL2 损失函数

3.1 CEFL 和 CEFL2 损失函数的核心公式

损失函数 公式 说明
CEFL − ( 1 − p t ) log ⁡ ( p t ) − p t ( 1 − p t ) γ log ⁡ ( p t ) -(1 - p_t) \log(p_t) - p_t (1 - p_t)^\gamma \log(p_t) −(1−pt)log(pt)−pt(1−pt)γlog(pt) 基于 Focal Loss,加入类别权重调整
CEFL2 − ( 1 − p t ) 2 ( 1 − p t ) 2 + p t 2 log ⁡ ( p t ) − p t 2 ( 1 − p t ) 2 + p t 2 ( 1 − p t ) γ log ⁡ ( p t ) -\frac{(1 - p_t)^2}{(1 - p_t)^2 + p_t^2} \log(p_t) - \frac{p_t^2}{(1 - p_t)^2 + p_t^2} (1 - p_t)^\gamma \log(p_t) −(1−pt)2+pt2(1−pt)2log(pt)−(1−pt)2+pt2pt2(1−pt)γlog(pt) 引入类别频率,进一步调整损失权重

3.2 类别频率的计算与应用

CEFL2 中,需要根据训练集中的类别分布计算每个类别的频率。这些频率作为权重在损失函数中进行调整。类别频率的计算公式如下:

class_freq t = 1 num_samples_in_class t \text{class\_freq}_t = \frac{1}{\text{num\_samples\_in\_class}_t} class_freqt=num_samples_in_classt1

随后,将类别频率归一化,使其和为 1:

normalized_class_freq t = class_freq t ∑ class_freq \text{normalized\_class\_freq}_t = \frac{\text{class\_freq}_t}{\sum \text{class\_freq}} normalized_class_freqt=∑class_freqclass_freqt

3.3 计算类别频率的代码示例

python 复制代码
import numpy as np

def compute_class_frequencies(targets, num_classes):
    # 计算每个类别的样本数量
    class_counts = np.bincount(targets.numpy(), minlength=num_classes)
    
    # 防止除零错误,计算每个类别的频率
    class_freq = 1.0 / (class_counts + 1e-6)
    
    # 归一化类别频率
    class_freq = class_freq / np.sum(class_freq)
    
    return torch.tensor(class_freq, dtype=torch.float32)

四、PyTorch 实现

4.1 CEFL 实现

python 复制代码
import torch	
import torch.nn as nn
import torch.nn.functional as F

class CEFL(nn.Module):
    def __init__(self, alpha, gamma=2.0):
        super(CEFL, self).__init__()
        self.alpha = alpha  # 类别的权重
        self.gamma = gamma  # 焦点损失的调节参数

    def forward(self, inputs, targets):
        # 使用softmax计算类别概率
        p = F.softmax(inputs, dim=1)
        
        # 选择正确类别的预测概率
        p_t = p.gather(1, targets.view(-1, 1))
        
        # 计算损失
        loss = -self.alpha * (1 - p_t) ** self.gamma * torch.log(p_t)
        
        return loss.mean()

代码解释

  1. 类的构造函数 (__init__):

    • alpha: 这是一个超参数,用于对各类别的损失加权。它在训练过程中控制类别的重要性。一般来说,alpha 用来增加或减少某些类别的损失权重(通常在类别不平衡时使用)。
    • gamma: 这是焦点损失的调节参数。焦点损失(Focal Loss)是一种为了解决类别不平衡问题而提出的损失函数,gamma 控制模型对易分类样本和难分类样本的关注程度。较大的 gamma 会增加对难分类样本的关注。
  2. forward 方法:

    • inputs: 网络的输出(通常是 logits),大小为 (batch_size, num_classes),表示每个样本对于每个类别的预测得分。
    • targets: 真实标签,大小为 (batch_size,),是样本的正确类别标签。
  3. F.softmax(inputs, dim=1):

    • 对模型的输出 logits 进行 softmax 计算,将其转化为概率分布。softmax 的作用是将每个样本的所有类别得分转化为一个概率分布,概率值的总和为 1。
    • dim=1 表示在类别维度上进行归一化,即每个样本的类别概率和为 1。
  4. p.gather(1, targets.view(-1, 1)):

    • p 是通过 softmax 得到的类别概率矩阵,p.gather(1, targets.view(-1, 1)) 选择每个样本的正确类别的概率。
    • gather(1, targets.view(-1, 1)) 会根据 targets 中给出的标签索引,从 p 中提取每个样本对应类别的概率。view(-1, 1)targets 转换为列向量,确保正确地索引每个样本的类别。
  5. 焦点损失部分:

    • loss = -self.alpha * (1 - p_t) ** self.gamma * torch.log(p_t):
      • p_t: 每个样本在正确类别上的预测概率。
      • (1 - p_t) ** self.gamma: 这是焦点损失的核心部分。它会放大模型对难分类样本的关注。对于那些预测较为确定的样本(即 p_t 接近 1),(1 - p_t) 会较小,损失减少;对于难分类样本(即 p_t 接近 0),(1 - p_t) 会较大,损失增加。
      • self.alpha: 用于控制类别的重要性。如果某些类别较为不平衡,alpha 可以增加这些类别的损失权重。
      • torch.log(p_t): 计算类别概率的对数值,通常是交叉熵的一部分。
  6. 返回平均损失:

    • loss.mean(): 返回所有样本的平均损失。

4.2 CEFL2 实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class CEFL2(nn.Module):
    def __init__(self, class_frequencies, gamma=2.0):
        super(CEFL2, self).__init__()
        self.class_frequencies = class_frequencies  # 类别频率
        self.gamma = gamma  # 焦点损失的调节参数

    def forward(self, inputs, targets):
        # 使用softmax计算类别概率
        p = F.softmax(inputs, dim=1)
        
        # 选择正确类别的预测概率
        p_t = p.gather(1, targets.view(-1, 1))

        # 计算每个类别的加权损失
        loss_term_1 = (1 - p_t)**2 / ((1 - p_t)**2 + p_t**2) * torch.log(p_t)
        loss_term_2 = p_t**2 / ((1 - p_t)**2 + p_t**2) * (1 - p_t)**self.gamma * torch.log(p_t)
        
        # 将每个类别的频率作为加权项
        loss = -self.class_frequencies[targets] * (loss_term_1 + loss_term_2)
        
        return loss.mean()

代码解释

  1. 类的构造函数 (__init__):

    • class_frequencies: 这是每个类别的频率。通常,频率是类别样本的出现概率或样本的加权值。该参数在处理类别不平衡时尤其重要。较少出现的类别会赋予较高的权重,以便模型对这些类别更敏感。
    • gamma: 和 CEFL 中的 gamma 相同,用于调节焦点损失的程度,控制对难分类样本的关注。
  2. forward 方法:

    • inputs: 与 CEFL 相同,是模型的输出(即 logits)。
    • targets: 与 CEFL 相同,是真实标签。
  3. F.softmax(inputs, dim=1):

    • 对模型的输出 inputs 进行 softmax 计算,得到每个样本在各类别上的概率。
  4. p.gather(1, targets.view(-1, 1)):

    • gather 方法用来根据 targets 中的标签提取每个样本的正确类别的预测概率。
  5. 加权损失部分

    • loss_term_1 = (1 - p_t)**2 / ((1 - p_t)**2 + p_t**2) * torch.log(p_t):
      • 这是针对正确类别概率 p_t 的一个加权项。这个项的目的是将模型的关注点放在难分类的样本上。计算时考虑了正确类和错误类之间的比例,进而调整损失值。
    • loss_term_2 = p_t**2 / ((1 - p_t)**2 + p_t**2) * (1 - p_t)**self.gamma * torch.log(p_t):
      • 另一个加权项,考虑了模型对难分类样本的关注(即当 p_t 小,样本难分类时),通过增加 gamma 使得模型对难分类样本的权重更加突出。
    • 这两个损失项的组合有助于在类别不平衡问题中进行加权,增强对少数类的学习。
  6. 加权频率项

    • loss = -self.class_frequencies[targets] * (loss_term_1 + loss_term_2):
      • 将每个类别的频率(class_frequencies)引入损失计算中。这使得类别频率较低的类别(通常是少数类)在计算损失时有更高的权重,从而让模型更加关注少数类。
  7. 返回平均损失

    • loss.mean(): 返回所有样本的加权平均损失。

4.3 训练过程

为了更好地展示如何在训练过程中使用 CEFLCEFL2 损失函数,,并将其应用于一个简单的神经网络模型。以下是更新后的代码示例:

python 复制代码
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np

# 计算类别频率的函数
def compute_class_frequencies(targets, num_classes):
    # 计算每个类别的样本数量
    class_counts = torch.bincount(targets, minlength=num_classes)
    
    # 防止除零错误,计算每个类别的频率
    class_freq = 1.0 / (class_counts.float() + 1e-6)
    
    # 归一化类别频率
    class_freq = class_freq / class_freq.sum()
    
    return class_freq

# 自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

# 创建一个虚拟数据集
num_samples = 1000
num_classes = 7
input_dim = 128
data = torch.randn(num_samples, input_dim)
targets = torch.randint(0, num_classes, (num_samples,))

# 计算每个类别的频率
class_frequencies = compute_class_frequencies(targets, num_classes)

# 创建数据加载器
dataset = CustomDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 定义模型(例如一个简单的全连接网络)
class SimpleModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

# 初始化模型和损失函数
model = SimpleModel(input_dim, num_classes)
criterion = CEFL2(class_frequencies)  # 使用 CEFL2 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
epochs = 10
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

代码解释

  1. 类别频率计算 (compute_class_frequencies):

    • 这个函数计算了每个类别在数据中的出现频率。我们通过计算每个类别的出现次数并进行归一化,得到类别频率。
    • torch.bincount(targets) 用于计算每个类别出现的次数,随后通过频率逆转的方式进行归一化。
  2. 自定义数据集 (CustomDataset):

    • 这个自定义数据集类返回每个样本的输入和标签对,适用于 PyTorch 的 DataLoader
  3. 模型定义 (SimpleModel):

    • 定义了一个简单的全连接层的神经网络,用于演示如何应用损失函数。
    • 模型输入为 input_dim,输出为 num_classes
  4. 训练循环:

    • 在每个 epoch 中,模型通过前向传播获得预测结果,并计算损失。
    • 使用 CEFL2 损失函数,基于每个类别的频率进行加权损失计算。
    • optimizer.zero_grad() 清空之前的梯度,loss.backward() 计算梯度,optimizer.step() 更新模型权重。

总结

在本文中,我们深入探讨了 Class-balanced Exponential Focal Loss (CEFL)Class-balanced Exponential Focal Loss 2 (CEFL2) 损失函数的定义、原理及其应用,重点介绍了它们如何有效解决类别不平衡问题。通过引入类别权重和类别频率,这些损失函数能够帮助模型在训练过程中更好地关注少数类样本,避免对多数类样本的过拟合,从而提升少数类的分类性能。

本文还提供了 PyTorch 实现的详细代码,包括如何计算类别频率、定义损失函数,并在训练过程中应用它们。

为帮助理解类别频率的影响,以下图示展示了不同类别在训练过程中损失调整的效果:
CSDN @ 2136 原始训练集 计算类别频率 计算类别频率加权后的损失 优化模型 训练结果 类别频率 损失调整 模型优化 CSDN @ 2136

图中展示了训练过程中如何计算类别频率,并利用这些频率对损失进行加权,从而优化模型训练效果。

通过本文的讲解,您应该对 CEFLCEFL2 损失函数的定义、实现和应用有了更深刻的理解。如果您正在处理类别不平衡的分类任务,不妨尝试使用这些损失函数,它们能有效提升模型的性能,特别是在少数类样本的分类效果上。

参考文献

  • T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar, ``Focal loss for dense object detection,'' in Proc. IEEE Int. Conf. Comput. Vis., Oct. 2017, pp. 2980-2988.doi:10.48550/arXiv.1708.02002.
  • L. Wang, C. Wang, Z. Sun, S. Cheng and L. Guo, "Class Balanced Loss for Image Classification," in IEEE Access, vol. 8, pp. 81142-81153, 2020, doi: 10.1109/ACCESS.2020.2991237.

相关推荐
guoji77882 分钟前
安全与对齐的深层博弈:Gemini 3.1 Pro 安全护栏与对抗测试深度拆解
人工智能·安全
实在智能RPA10 分钟前
实在 Agent 和通用大模型有什么不一样?深度拆解 AI Agent 的感知、决策与执行逻辑
人工智能·ai
独隅15 分钟前
PyTorch 模型部署的 Docker 配置与性能调优深入指南
人工智能·pytorch·docker
lihuayong22 分钟前
OpenClaw 系统提示词
人工智能·prompt·提示词·openclaw
黑客说36 分钟前
AI驱动剧情,解锁无限可能——AI游戏发展解析
人工智能·游戏
踩着两条虫41 分钟前
AI驱动的Vue3应用开发平台深入探究(十):物料系统之内置组件库
android·前端·vue.js·人工智能·低代码·系统架构·rxjava
小仙女的小稀罕1 小时前
听不清重要会议录音急疯?这款常见AI工具听脑AI精准转译
开发语言·人工智能·python
reesn1 小时前
qwen3.5 0.8B纠正任务实践
人工智能·语言模型
实在智能RPA1 小时前
实在Agent 制造业落地案例:探寻工业大模型从实验室走向车间的实战路径
人工智能·ai