吴恩达机器学习课程(PyTorch适配)学习笔记:2.4 激活函数与多类别处理

2.4 激活函数与多类别处理

在深度学习中,激活函数为网络引入非线性能力,是实现复杂模式建模的核心;而多类别处理则是解决实际分类任务(如图像识别、文本分类)的关键技术。本章将系统讲解激活函数的类型、选择依据,以及多类别分类的实现方案(含Softmax原理与PyTorch适配),并扩展至多输出分类场景。

2.4.1 激活函数(类型 + 选择依据 + 作用)

激活函数(Activation Function)是神经网络中连接"线性变换"与"非线性建模"的桥梁。没有激活函数,无论多少层的神经网络都等价于单层线性模型,无法拟合复杂数据分布。

一、激活函数的核心作用

  1. 引入非线性 :将线性变换(z=Wx+bz=Wx+bz=Wx+b)的结果映射到非线性空间,使网络能学习复杂特征(如图像边缘、文本语义)。
  2. 控制输出范围:将输出值约束在特定区间(如Sigmoid输出[0,1]),适配不同任务需求(如概率预测)。
  3. 梯度传播调节:通过合理的导数特性,缓解梯度消失/爆炸问题,保障深层网络的训练稳定性。

二、常见激活函数类型与特性

按函数形态和应用场景,激活函数可分为以下几类:

1. 饱和型激活函数(传统类型)

特点:输入值过大/过小时,函数导数趋近于0(梯度消失风险高),计算依赖指数/对数运算。

函数名称 公式 输出范围 优点 缺点 适用场景
Sigmoid σ(z)=11+e−z\sigma(z)=\frac{1}{1+e^{-z}}σ(z)=1+e−z1 (0,1) 输出可解释为概率,二分类输出层常用 梯度消失严重( z
Tanh tanh⁡(z)=ez−e−zez+e−z\tanh(z)=\frac{e^z-e^{-z}}{e^z+e^{-z}}tanh(z)=ez+e−zez−e−z (-1,1) 零中心输出(缓解梯度更新方向问题)、比Sigmoid收敛快 仍存在梯度消失( z
2. 非饱和型激活函数(主流类型)

特点:输入为正时导数恒定或随输入变化,梯度消失风险低,计算效率高(无指数/对数运算)。

(1)ReLU系列
  • 基础ReLU

    公式:ReLU(z)=max⁡(0,z)\text{ReLU}(z)=\max(0,z)ReLU(z)=max(0,z)

    输出范围:[0,+∞)

    优点:计算极快(仅比较操作)、梯度不消失(z>0时导数=1)、缓解过拟合(随机"关闭"部分神经元);

    缺点:死亡ReLU问题 (z≤0时导数=0,神经元永久失活)、输出非零中心;

    适用场景:CNN隐藏层、Transformer前馈网络(最常用激活函数)。

  • Leaky ReLU

    公式:Leaky ReLU(z)=max⁡(αz,z)\text{Leaky ReLU}(z)=\max(\alpha z,z)Leaky ReLU(z)=max(αz,z)(α\alphaα为小常数,通常取0.01)

    改进点:z<0时保留小梯度(α\alphaα),解决死亡ReLU问题;

    缺点:α\alphaα为超参数,需调优;

    适用场景:ReLU效果差时的替代方案(如深层CNN)。

  • Parametric ReLU(PReLU)

    公式:PReLU(z)=max⁡(αz,z)\text{PReLU}(z)=\max(\alpha z,z)PReLU(z)=max(αz,z)(α\alphaα为可学习参数,而非固定值)

    改进点:α\alphaα通过训练自适应调整,灵活性更高;

    缺点:增加模型参数量,可能过拟合;

    适用场景:数据量充足的复杂任务(如ImageNet分类)。

  • Exponential Linear Unit(ELU)

    公式:ELU(z)={zz>0α(ez−1)z≤0\text{ELU}(z)=\begin{cases} z & z>0 \\ \alpha(e^z-1) & z≤0 \end{cases}ELU(z)={zα(ez−1)z>0z≤0(α\alphaα通常取1)

    优点:零均值输出(缓解梯度更新问题)、抗噪声能力强(z<0时平滑衰减);

    缺点:计算需指数运算(比ReLU慢);

    适用场景:对噪声敏感的任务(如语音识别、小样本图像分类)。

(2)Swish与GELU(Transformer常用)
  • Swish

    公式:Swish(z)=z⋅σ(z)\text{Swish}(z)=z\cdot\sigma(z)Swish(z)=z⋅σ(z)(σ\sigmaσ为Sigmoid)

    特点:平滑非线性、无明显饱和区、自归一化(输出均值接近0);

    适用场景:Transformer、CNN(在某些任务上优于ReLU)。

  • GELU(Gaussian Error Linear Unit)

    公式:GELU(z)=z⋅Φ(z)\text{GELU}(z)=z\cdot\Phi(z)GELU(z)=z⋅Φ(z)(Φ\PhiΦ为标准正态分布的CDF,近似:GELU(z)≈0.5z(1+tanh⁡(2/π(z+0.044715z3)))\text{GELU}(z)\approx 0.5z(1+\tanh(\sqrt{2/\pi}(z+0.044715z^3)))GELU(z)≈0.5z(1+tanh(2/π (z+0.044715z3))))

    特点:随机激活(输出与输入的概率相关,更符合生物神经元特性)、梯度传播稳定;

    适用场景:Transformer(BERT、GPT系列默认激活函数)、预训练模型。

三、激活函数选择依据

选择激活函数需结合任务类型、网络结构、计算资源三方面因素,具体决策流程如下:

  1. 优先考虑计算效率

    • 若需快速训练(如大规模数据、实时推理):选择ReLU(最快)、Leaky ReLU;
    • 若计算资源充足(如服务器端复杂任务):可尝试GELU、Swish。
  2. 根据网络深度选择

    • 浅层网络(<5层):任意激活函数均可(ReLU、Sigmoid、Tanh);
    • 深层网络(≥10层):必须选择非饱和函数(ReLU、GELU、ELU),避免梯度消失。
  3. 根据任务类型选择

    • 二分类任务输出层:Sigmoid(输出概率);
    • 多分类任务输出层:Softmax(配合CrossEntropyLoss);
    • 回归任务输出层:无激活函数(线性输出)或ReLU(约束输出非负,如房价预测);
    • 生成模型/自编码器:Sigmoid(输出图像像素[0,1])、Tanh(输出[-1,1])。
  4. 特殊需求适配

    • 需抗噪声:ELU、GELU;
    • 需零中心输出:Tanh、ELU、GELU;
    • 小样本任务:ELU(减少过拟合风险)。

四、激活函数可视化与PyTorch实现

通过代码可视化常见激活函数的形态,并验证其导数特性:

python 复制代码
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# 设置中文字体
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei"]
plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示问题

# 1. 定义激活函数(含手动实现与PyTorch调用)
def sigmoid(z):
    return 1 / (1 + torch.exp(-z))

def tanh(z):
    return (torch.exp(z) - torch.exp(-z)) / (torch.exp(z) + torch.exp(-z))

def relu(z):
    return torch.maximum(z, torch.tensor(0.0))

def leaky_relu(z, alpha=0.01):
    return torch.where(z > 0, z, alpha * z)

def gelu(z):
    # GELU近似实现(与PyTorch的F.gelu一致)
    return 0.5 * z * (1 + torch.tanh(torch.sqrt(torch.tensor(2 / np.pi)) * (z + 0.044715 * z**3)))

# 2. 生成输入数据(覆盖常见输入范围)
z = torch.linspace(-5, 5, 1000)  # 输入从-5到5,共1000个点

# 3. 计算各激活函数输出
activations = {
    "Sigmoid": sigmoid(z),
    "Tanh": tanh(z),
    "ReLU": relu(z),
    "Leaky ReLU(α=0.01)": leaky_relu(z),
    "GELU": gelu(z),
    "Swish": z * sigmoid(z)  # Swish = z * Sigmoid(z)
}

# 4. 可视化激活函数
plt.figure(figsize=(12, 8))
colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FECA57", "#FF9FF3"]

for (name, output), color in zip(activations.items(), colors):
    plt.plot(z.numpy(), output.numpy(), label=name, linewidth=2.5, color=color)

# 添加辅助线(y=0和x=0)
plt.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
plt.axvline(x=0, color="gray", linestyle="--", alpha=0.5)

# 设置图表属性
plt.xlabel("输入z", fontsize=12)
plt.ylabel("激活函数输出a", fontsize=12)
plt.title("常见激活函数形态对比", fontsize=14, fontweight="bold")
plt.legend(fontsize=10)
plt.grid(alpha=0.3)
plt.savefig("activation_functions_comparison.png", dpi=300, bbox_inches="tight")
plt.show()

# 5. 验证PyTorch内置激活函数(确保手动实现与官方一致)
def verify_pytorch_activations():
    # 随机输入(避免特殊值)
    z_test = torch.randn(10)
    print("输入z:", z_test.numpy())
    
    # 对比手动实现与PyTorch内置函数
    print("\nSigmoid对比:")
    print("手动实现:", sigmoid(z_test).numpy())
    print("PyTorch F.sigmoid:", F.sigmoid(z_test).numpy())
    print("是否一致:", torch.allclose(sigmoid(z_test), F.sigmoid(z_test), atol=1e-6))
    
    print("\nGELU对比:")
    print("手动实现:", gelu(z_test).numpy())
    print("PyTorch F.gelu:", F.gelu(z_test).numpy())
    print("是否一致:", torch.allclose(gelu(z_test), F.gelu(z_test), atol=1e-6))

verify_pytorch_activations()


输出结果

  • 图表将显示6种激活函数的形态,可直观观察ReLU的"硬截断"、GELU的"平滑过渡"、Sigmoid的"饱和特性"。
  • 验证部分会输出"是否一致: True",说明手动实现与PyTorch官方函数精度一致。

五、注意事项与易错点

  1. 死亡ReLU问题

    • 原因:学习率过大导致部分神经元的权重更新后,输入z永久≤0(导数=0,无法再更新);
    • 解决方案:改用Leaky ReLU/PReLU、降低学习率、使用He初始化(针对ReLU的专用初始化)。
  2. Sigmoid的梯度消失

    • 避免在深层网络的隐藏层使用Sigmoid(仅用于二分类输出层);
    • 若必须使用,需配合小学习率和批量归一化(BatchNorm)缓解梯度消失。
  3. GELU的实现差异

    • PyTorch 1.10+的F.gelu默认使用精确计算(非近似),手动实现时需注意版本兼容性;
    • 预训练模型(如BERT)的GELU需与原实现一致,否则会导致性能下降。
  4. 激活函数与初始化的匹配

    • ReLU系列需用He初始化nn.init.kaiming_normal_);
    • Sigmoid/Tanh需用Xavier初始化nn.init.xavier_normal_);
    • 不匹配会导致网络训练缓慢或梯度爆炸。

2.4.2 Sigmoid 替代方案(ReLU 系列等)

Sigmoid作为最早的激活函数之一,因梯度消失、计算效率低等问题,在隐藏层中已逐渐被替代。本节将分析Sigmoid的核心缺陷,并对比主流替代方案的优势与适用场景。

一、Sigmoid的核心缺陷

  1. 梯度消失严重
    Sigmoid的导数为σ′(z)=σ(z)(1−σ(z))\sigma'(z)=\sigma(z)(1-\sigma(z))σ′(z)=σ(z)(1−σ(z)),最大值仅0.25(z=0时),且当|z|>5时导数≈0。深层网络中,梯度经过多轮乘法后会趋近于0,导致浅层权重无法更新。
  2. 输出非零中心
    Sigmoid输出始终为正((0,1)),导致神经元的梯度更新方向一致(均为正或均为负),减缓收敛速度。
  3. 计算效率低
    依赖指数运算(e−ze^{-z}e−z),比ReLU的"比较操作"慢10~100倍,不适合大规模网络。

二、主流替代方案对比

针对Sigmoid的缺陷,不同替代方案从"梯度传播""计算效率""输出分布"三个维度进行优化:

替代方案 解决的Sigmoid缺陷 核心优势 适用场景 相比Sigmoid的性能提升(ImageNet分类)
ReLU 梯度消失、计算慢 计算极快、梯度不消失 绝大多数CNN、轻量级网络 训练速度提升35倍,准确率提升2%5%
Leaky ReLU 梯度消失、死亡ReLU 无死亡神经元风险 ReLU效果差的深层网络 训练稳定性提升,准确率提升1%~2%
GELU 梯度消失、输出非零中心 随机激活、梯度平滑 Transformer、预训练模型 预训练任务准确率提升3%~8%
ELU 梯度消失、抗噪声差 零均值输出、抗噪声 小样本、高噪声数据 噪声数据任务准确率提升2%~4%

三、替代方案的PyTorch实践(以CNN为例)

通过在MNIST数据集上对比Sigmoid与ReLU系列的训练效果,验证替代方案的优势:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import time

# 1. 定义CNN模型(支持切换激活函数)
class CNN(nn.Module):
    def __init__(self, activation_fn=nn.ReLU()):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.act1 = activation_fn
        self.pool1 = nn.MaxPool2d(2)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.act2 = activation_fn
        self.pool2 = nn.MaxPool2d(2)
        
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.act3 = activation_fn
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.pool1(self.act1(self.conv1(x)))
        x = self.pool2(self.act2(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)  # 展平
        x = self.act3(self.fc1(x))
        x = self.fc2(x)
        return x

# 2. 训练函数
def train_model(activation_fn, model_name, epochs=5):
    # 加载数据
    train_dataset = MNIST(root="./data", train=True, download=True, transform=ToTensor())
    test_dataset = MNIST(root="./data", train=False, download=True, transform=ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    
    # 初始化模型、损失函数、优化器
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN(activation_fn).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    # 记录训练信息
    train_losses = []
    test_accs = []
    start_time = time.time()
    
    # 训练循环
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
        
        # 计算训练损失
        train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(train_loss)
        
        # 计算测试准确率
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        test_acc = correct / total
        test_accs.append(test_acc)
        
        # 打印日志
        print(f"[{model_name}] Epoch {epoch+1}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Test Acc: {test_acc:.4f}")
    
    # 计算总训练时间
    total_time = time.time() - start_time
    print(f"[{model_name}] 总训练时间: {total_time:.2f}秒\n")
    
    return train_losses, test_accs, total_time

# 3. 对比不同激活函数
if __name__ == "__main__":
    # 定义待对比的激活函数
    activation_configs = [
        (nn.Sigmoid(), "Sigmoid"),
        (nn.ReLU(), "ReLU"),
        (nn.LeakyReLU(0.01), "LeakyReLU"),
        (nn.GELU(), "GELU")
    ]
    
    # 存储结果
    results = {}
    
    # 训练并对比
    for act_fn, act_name in activation_configs:
        train_losses, test_accs, total_time = train_model(act_fn, act_name, epochs=5)
        results[act_name] = {
            "losses": train_losses,
            "accs": test_accs,
            "time": total_time
        }
    
    # 可视化对比结果
    plt.figure(figsize=(14, 6))
    
    # 子图1:训练损失对比
    plt.subplot(1, 2, 1)
    for act_name, data in results.items():
        plt.plot(range(1, 6), data["losses"], label=act_name, linewidth=2.5, marker="o")
    plt.xlabel("Epoch")
    plt.ylabel("Train Loss")
    plt.title("不同激活函数的训练损失对比")
    plt.legend()
    plt.grid(alpha=0.3)
    
    # 子图2:测试准确率对比
    plt.subplot(1, 2, 2)
    for act_name, data in results.items():
        plt.plot(range(1, 6), data["accs"], label=act_name, linewidth=2.5, marker="s")
    plt.xlabel("Epoch")
    plt.ylabel("Test Accuracy")
    plt.title("不同激活函数的测试准确率对比")
    plt.legend()
    plt.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig("activation_functions_comparison_mnist.png", dpi=300, bbox_inches="tight")
    plt.show()
    
    # 打印性能总结
    print("=== 性能总结 ===")
    for act_name, data in results.items():
        print(f"{act_name}: "
              f"最终准确率 {data['accs'][-1]:.4f}, "
              f"总时间 {data['time']:.2f}秒, "
              f"最终损失 {data['losses'][-1]:.4f}")

预期结果

  • ReLU/GELU/LeakyReLU的训练速度比Sigmoid快2~3倍;
  • 最终测试准确率:GELU≈ReLU>LeakyReLU>Sigmoid(MNIST任务中ReLU与GELU性能接近);
  • Sigmoid的训练损失下降缓慢,且最终准确率低于95%(其他激活函数可达98%以上)。

四、替代方案的选择建议

  1. 优先选择ReLU

    • 适用场景:大多数CNN、轻量级网络(如MobileNet)、资源受限设备(如手机端);
    • 理由:计算最快,无超参数,兼容性好。
  2. ReLU效果差时用Leaky ReLU

    • 适用场景:深层CNN(≥10层)、容易出现死亡ReLU的任务;
    • 调优建议:α\alphaα取0.01或0.1(默认0.01),无需过度调优。
  3. Transformer/预训练模型用GELU

    • 适用场景:BERT、GPT、ViT(视觉Transformer);
    • 理由:随机激活特性更适配自注意力机制,预训练任务性能更优。
  4. 高噪声数据用ELU

    • 适用场景:医疗图像(含噪声)、小样本语音识别;
    • 注意:计算比ReLU慢,需平衡性能与效率。

2.4.3 多类别分类(任务场景)

多类别分类(Multi-Class Classification)是指"每个样本仅属于一个类别,且类别数K≥3"的分类任务,是深度学习最常见的应用场景之一(如图像识别、文本分类)。

一、多类别分类的核心特征

与二分类(K=2)相比,多类别分类具有以下特点:

  1. 类别互斥:每个样本仅对应一个类别(如一张图片只能是"猫""狗""鸟"中的一种,不能同时是两种);
  2. 输出维度=类别数:网络输出层的神经元数量等于类别数K(如10分类任务输出层有10个神经元);
  3. 概率归一化:输出需满足"所有类别概率之和=1"(便于类别决策),通常通过Softmax实现;
  4. 损失函数适配 :需使用多类别交叉熵损失(而非二分类交叉熵),如PyTorch的nn.CrossEntropyLoss

二、典型任务场景与数据特点

1. 图像分类(最典型场景)
  • 任务描述:给定图像,预测其所属类别(如动物、交通工具、数字);
  • 代表数据集
    • MNIST(10类手写数字,28×28灰度图);
    • CIFAR-10(10类物体,32×32彩色图);
    • ImageNet(1000类物体,224×224彩色图);
  • 网络结构:CNN(如ResNet、ViT),输出层维度=类别数(如ImageNet用1000维输出)。
2. 文本分类
  • 任务描述:给定文本,预测其所属类别(如新闻分类、情感极性细分类);
  • 代表任务
    • 新闻分类(如AG News,4类:世界、体育、商业、科技);
    • 主题分类(如20 Newsgroups,20类主题);
  • 网络结构:RNN/LSTM、Transformer(如BERT),输出层维度=类别数(如20类用20维输出)。
3. 语音分类
  • 任务描述:给定语音片段,预测其类别(如命令词识别、语言识别);
  • 代表任务
    • 命令词识别(如Google Speech Commands,35类命令词);
    • 语言识别(如Common Voice,100+类语言);
  • 网络结构:CNN(处理语音频谱图)、RNN(处理时序特征),输出层维度=类别数。
4. 其他场景
  • 医学影像分类:如病理切片分类(良性/恶性/交界性,3类);
  • 工业质检:如产品缺陷分类(无缺陷/划痕/变形,3类);
  • 推荐系统:如用户兴趣分类(体育/娱乐/科技,3类)。

三、多类别分类的网络设计要点

以"CIFAR-10分类"为例,说明多类别分类的网络设计规范:

python 复制代码
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, RandomCrop
from torch.utils.data import DataLoader

# 1. 数据预处理(适配CIFAR-10)
def get_cifar10_dataloaders(batch_size=64):
    # 训练集增强,测试集无增强
    train_transform = Compose([
        RandomCrop(32, padding=4),  # 随机裁剪(增强数据多样性)
        RandomHorizontalFlip(p=0.5),  # 随机水平翻转
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])  # CIFAR-10标准归一化
    ])
    
    test_transform = Compose([
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    ])
    
    # 加载数据集
    train_dataset = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    test_dataset = CIFAR10(root="./data", train=False, download=True, transform=test_transform)
    
    # 数据加载器
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    return train_loader, test_loader, train_dataset.classes  # 返回类别名称

# 2. 多类别分类网络(CNN)
class CIFAR10_CNN(nn.Module):
    def __init__(self, num_classes=10):  # num_classes=10(CIFAR-10的类别数)
        super(CIFAR10_CNN, self).__init__()
        # 特征提取层(CNN)
        self.features = nn.Sequential(
            # 卷积块1:3→16
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),  # 批量归一化(加速训练)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 下采样,尺寸16×16
            
            # 卷积块2:16→32
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 尺寸8×8
            
            # 卷积块3:32→64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)  # 尺寸4×4
        )
        
        # 分类层(全连接)
        self.classifier = nn.Sequential(
            nn.Flatten(),  # 展平:64×4×4=1024
            nn.Linear(64 * 4 * 4, 256),  # 1024→256
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  #  dropout(防止过拟合)
            nn.Linear(256, num_classes)  # 256→10(输出层维度=类别数)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        # 注意:输出层无Softmax!PyTorch的CrossEntropyLoss已包含Softmax
        return x

# 3. 验证网络输出维度
if __name__ == "__main__":
    # 加载数据
    train_loader, test_loader, classes = get_cifar10_dataloaders(batch_size=64)
    print("CIFAR-10类别:", classes)
    print("类别数:", len(classes))
    
    # 初始化网络
    model = CIFAR10_CNN(num_classes=len(classes))
    print("\n网络结构:")
    print(model)
    
    # 测试输入输出维度
    test_input = torch.randn(1, 3, 32, 32)  # 模拟1张32×32彩色图
    test_output = model(test_input)
    print(f"\n输入维度: {test_input.shape}")
    print(f"输出维度: {test_output.shape}")  # 应输出(1, 10),对应10类的logits

设计要点总结

  1. 输出层维度=类别数:如CIFAR-10有10类,输出层设为10个神经元;
  2. 输出层无激活函数 :PyTorch的CrossEntropyLoss已集成Softmax,重复添加会导致损失计算错误;
  3. 批量归一化(BatchNorm):加速训练,缓解梯度消失,尤其适合深层网络;
  4. 数据增强:随机裁剪、翻转等操作,提升模型泛化能力(多类别任务易过拟合);
  5. Dropout:分类层添加Dropout(如0.5),防止过拟合。

四、注意事项与常见错误

  1. 类别数与输出层维度不匹配

    • 错误:如CIFAR-10任务输出层设为100个神经元;
    • 后果:损失计算报错(标签范围与输出维度不匹配);
    • 解决方案:确保num_classes参数等于实际类别数,可从数据集的classes属性获取。
  2. 输出层误加Softmax

    • 错误:在输出层后添加nn.Softmax(dim=1),再用CrossEntropyLoss
    • 后果:CrossEntropyLoss会先对输入做Softmax,导致"双重Softmax",损失计算错误;
    • 解决方案:输出层仅保留线性层(nn.Linear),不添加激活函数。
  3. 标签格式错误

    • 错误:将多类别标签转为one-hot编码(如1→[0,1,0,...]),再用CrossEntropyLoss
    • 后果:CrossEntropyLoss要求标签为"类别索引"(如1),而非one-hot向量;
    • 解决方案:直接使用原始类别索引标签,若标签已one-hot,需用nn.NLLLoss(配合Softmax输出)。

2.4.4 Softmax(原理 + 网络适配 + PyTorch 框架)

Softmax函数是多类别分类的核心组件,其作用是将网络输出的"原始logits"转换为"类别概率分布",满足"所有概率之和=1",便于类别决策和损失计算。

一、Softmax的数学原理

1. 定义与公式

对于网络输出的logits向量z=[z1,z2,...,zK]\mathbf{z}=[z_1, z_2, ..., z_K]z=[z1,z2,...,zK](K为类别数),Softmax函数将其映射为概率向量p=[p1,p2,...,pK]\mathbf{p}=[p_1, p_2, ..., p_K]p=[p1,p2,...,pK],公式如下:
pi=ezi∑j=1Kezj(i=1,2,...,K)p_i = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}} \quad (i=1,2,...,K)pi=∑j=1Kezjezi(i=1,2,...,K)

核心特性

  • 概率范围:0<pi<10 < p_i < 10<pi<1(每个类别概率为正);
  • 概率和为1:∑i=1Kpi=1\sum_{i=1}^K p_i = 1∑i=1Kpi=1(符合概率分布定义);
  • 相对大小保持:若zi>zjz_i > z_jzi>zj,则pi>pjp_i > p_jpi>pj(概率排序与logits排序一致)。
2. 数值稳定性优化

直接计算Softmax易出现数值溢出 (当ziz_izi较大时,ezie^{z_i}ezi会超出浮点数范围)。解决方案是"减去logits的最大值",推导如下:
pi=ezi∑j=1Kezj=ezi−max⁡(z)∑j=1Kezj−max⁡(z)p_i = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}} = \frac{e^{z_i - \max(\mathbf{z})}}{\sum_{j=1}^K e^{z_j - \max(\mathbf{z})}}pi=∑j=1Kezjezi=∑j=1Kezj−max(z)ezi−max(z)

由于max⁡(z)\max(\mathbf{z})max(z)是常数,分子分母同乘e−max⁡(z)e^{-\max(\mathbf{z})}e−max(z)不改变结果,但可将zi−max⁡(z)z_i - \max(\mathbf{z})zi−max(z)控制在≤0的范围,避免ezie^{z_i}ezi溢出。

二、Softmax与网络的适配逻辑

Softmax在多类别分类网络中的位置和作用如下:

  1. 网络输出层:输出K维logits(无激活函数);
  2. Softmax层:将logits转换为K维概率分布;
  3. 类别决策 :选择概率最大的类别作为预测结果(y^=arg⁡max⁡(p1,p2,...,pK)\hat{y} = \arg\max(p_1, p_2, ..., p_K)y^=argmax(p1,p2,...,pK));
  4. 损失计算:用交叉熵损失(Cross-Entropy Loss)衡量预测概率与真实标签的差距。

适配流程图

复制代码
输入图像 → CNN特征提取 → 全连接层(输出K维logits) → Softmax → K维概率分布 → 类别决策(argmax)
                                          ↓
                                      真实标签(类别索引) → 交叉熵损失 → 反向传播优化

三、PyTorch中的Softmax实现与应用

PyTorch提供两种使用Softmax的方式:手动调用 nn.Softmax(推理时)和CrossEntropyLoss自动集成(训练时),需根据场景选择。

1. 训练时:使用CrossEntropyLoss(推荐)

nn.CrossEntropyLoss = Softmax + 负对数似然损失(NLLLoss),直接接收logits作为输入,避免手动计算Softmax的麻烦和数值错误。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# 1. 加载数据(MNIST,10类)
train_dataset = MNIST(root="./data", train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 2. 定义简单网络(输出10维logits)
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 10)  # 输出层:10维logits(无Softmax)
    
    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)  # 输出logits
        return x

# 3. 初始化模型、损失函数(CrossEntropyLoss)、优化器
model = SimpleNet()
criterion = nn.CrossEntropyLoss()  # 已集成Softmax
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# 4. 单批次训练演示
model.train()
for inputs, labels in train_loader:
    # 输入:(64, 1, 28, 28),标签:(64,)(类别索引,如0-9)
    print(f"输入形状: {inputs.shape}, 标签形状: {labels.shape}")
    
    # 前向传播:输出logits(64, 10)
    logits = model(inputs)
    print(f"Logits形状: {logits.shape}")
    
    # 计算损失(CrossEntropyLoss自动对logits做Softmax)
    loss = criterion(logits, labels)
    print(f"损失值: {loss.item():.4f}")
    
    # 反向传播与优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 计算概率与预测类别(推理时用)
    with torch.no_grad():
        # 手动计算Softmax(推理时获取概率)
        probs = torch.softmax(logits, dim=1)  # dim=1:按样本维度计算Softmax
        print(f"概率形状: {probs.shape}, 概率和: {probs.sum(dim=1).numpy()[:5]}")  # 验证概率和为1
        
        # 预测类别(argmax取概率最大的索引)
        preds = torch.argmax(probs, dim=1)
        print(f"预测类别形状: {preds.shape}, 前5个预测: {preds.numpy()[:5]}")
        print(f"前5个真实标签: {labels.numpy()[:5]}")
    
    break  # 仅演示单批次
2. 推理时:手动调用torch.softmax

推理阶段需要获取类别概率(如计算置信度),需手动对logits应用Softmax,注意指定dim=1(按样本维度计算,确保每个样本的概率和为1)。

python 复制代码
def infer_with_softmax(model, test_loader, device):
    """推理时用Softmax计算概率"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # 前向传播:获取logits
            logits = model(inputs)
            
            # 计算概率(推理时手动加Softmax)
            probs = torch.softmax(logits, dim=1)
            
            # 预测类别(概率最大的类别)
            preds = torch.argmax(probs, dim=1)
            
            # 计算准确率
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            
            # 输出部分结果(概率、置信度、预测类别)
            if total <= 5:  # 仅展示前5个样本
                sample_idx = total - len(labels)  # 当前批次的起始索引
                for i in range(min(len(labels), 5 - sample_idx)):
                    print(f"样本{i+sample_idx+1}: "
                          f"概率={probs[i].numpy().round(4)}, "
                          f"置信度={probs[i].max().item():.4f}, "
                          f"预测={preds[i].item()}, "
                          f"真实={labels[i].item()}")
    
    print(f"\n推理准确率: {correct/total:.4f}")
    return correct/total

# 测试推理函数
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleNet().to(device)
    test_dataset = MNIST(root="./data", train=False, download=True, transform=ToTensor())
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    
    # 加载训练好的权重(此处用随机权重演示,实际需加载训练后的权重)
    # model.load_state_dict(torch.load("mnist_simple_net.pth"))
    
    infer_with_softmax(model, test_loader, device)
3. 数值稳定性验证

通过代码验证"减去最大值"的优化效果,避免数值溢出:

python 复制代码
def verify_softmax_numerical_stability():
    # 生成极端logits(含大值,易导致溢出)
    logits = torch.tensor([[1000, 1001, 1002]], dtype=torch.float32)
    print("原始logits:", logits.numpy())
    
    # 1. 直接计算Softmax(会溢出)
    try:
        probs_direct = torch.softmax(logits, dim=1)
        print("直接计算Softmax:", probs_direct.numpy())
    except Exception as e:
        print("直接计算Softmax报错:", e)
    
    # 2. 优化计算(减去最大值)
    max_val = logits.max(dim=1, keepdim=True)[0]  # 按样本维度取最大值
    logits_optimized = logits - max_val
    probs_optimized = torch.softmax(logits_optimized, dim=1)
    print("优化后logits:", logits_optimized.numpy())
    print("优化后Softmax:", probs_optimized.numpy())
    print("优化后概率和:", probs_optimized.sum(dim=1).numpy())

# 验证数值稳定性
verify_softmax_numerical_stability()

输出结果

  • 直接计算Softmax会输出inf(无穷大)或nan(非数字),因e1002e^{1002}e1002超出浮点数范围;
  • 优化后计算会正常输出概率(如[0.0900, 0.2447, 0.6652]),概率和为1。

四、注意事项与常见错误

  1. dim参数设置错误

    • 错误:torch.softmax(logits, dim=0)(按特征维度计算,而非样本维度);
    • 后果:所有样本的概率之和为1(而非单个样本),预测结果完全错误;
    • 解决方案:始终设置dim=1(假设输入形状为(batch_size, num_classes))。
  2. 训练时手动加Softmax

    • 错误:输出层后加nn.Softmax(dim=1),再用CrossEntropyLoss
    • 后果:CrossEntropyLoss会再次对输入做Softmax,导致"双重Softmax",损失值异常(通常很小或为负);
    • 解决方案:训练时输出层仅保留线性层,推理时再手动加Softmax。
  3. 数值溢出未处理

    • 错误:对大logits直接计算Softmax,未减去最大值;
    • 后果:e^{z_i}溢出,输出infnan,训练崩溃;
    • 解决方案:使用PyTorch的torch.softmax(已内置"减最大值"优化),无需手动处理。

2.4.5 多输出分类(扩展场景)

多输出分类(Multi-Output Classification)是"一个样本同时预测多个类别标签"的场景,与传统多类别分类(单标签)的核心区别是标签不互斥(如一张图片可同时包含"猫"和"狗"两个标签)。

一、多输出分类的核心场景

1. 多标签分类(Multi-Label Classification)
  • 定义:每个样本属于多个类别(标签为二进制向量,1表示"属于该类",0表示"不属于");
  • 典型场景
    • 图像标注:一张图片标注"猫""草地""晴天"(3个标签);
    • 文本分类:一篇新闻同时属于"体育"和"足球"(2个标签);
    • 视频分类:一段视频包含"动作""冒险""悬疑"(3个标签);
  • 数据特点 :标签为one-hot向量的扩展(如3标签任务中,样本标签可为[1,0,1])。
2. 多任务分类(Multi-Task Classification)
  • 定义:一个模型同时完成多个分类任务(每个任务有独立的类别体系);
  • 典型场景
    • 图像多任务:同时预测"图像类别"(如猫/狗)和"图像风格"(如写实/卡通);
    • 文本多任务:同时预测"文本主题"(新闻/科技)和"情感极性"(正面/负面);
    • 语音多任务:同时预测"语音内容"(命令词)和"说话人性别"(男/女);
  • 数据特点 :每个样本有多个独立标签(如文本任务中,标签为(主题标签, 情感标签))。

二、多标签分类的PyTorch实现

多标签分类的核心是输出层用Sigmoid激活 (每个输出独立预测"是否属于该类")和损失函数用二元交叉熵BCEWithLogitsLoss)。

1. 数据准备(以多标签图像数据集为例)

使用TorchVisionCIFAR-100数据集模拟多标签场景(每个样本随机分配2~3个标签):

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR100
from torchvision.transforms import Compose, ToTensor, Normalize
import numpy as np

# 1. 自定义多标签数据集
class CIFAR100_MultiLabel(Dataset):
    def __init__(self, root, train=True, transform=None, num_labels_per_sample=2):
        """
        模拟多标签CIFAR-100数据集
        num_labels_per_sample: 每个样本的标签数(2~3)
        """
        self.base_dataset = CIFAR100(root=root, train=train, download=True, transform=None)
        self.transform = transform
        self.num_labels_per_sample = num_labels_per_sample
        self.num_classes = 100
        
        # 为每个样本生成多标签(随机选择2~3个类别)
        self.multi_labels = []
        for _ in range(len(self.base_dataset)):
            # 随机选择2~3个不同的类别索引
            num_labels = np.random.randint(2, 4)  # 2或3个标签
            label_indices = np.random.choice(self.num_classes, num_labels, replace=False)
            # 转为one-hot向量(100维,1表示属于该类)
            multi_label = np.zeros(self.num_classes, dtype=np.float32)
            multi_label[label_indices] = 1.0
            self.multi_labels.append(multi_label)
    
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        # 获取原始图像和标签
        img, _ = self.base_dataset[idx]
        multi_label = self.multi_labels[idx]
        
        # 应用变换
        if self.transform is not None:
            img = self.transform(img)
        
        # 转换为张量
        multi_label = torch.tensor(multi_label)
        return img, multi_label

# 2. 数据加载器
def get_multilabel_dataloaders(batch_size=64):
    # 数据预处理
    transform = Compose([
        ToTensor(),
        Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])  # CIFAR-100标准归一化
    ])
    
    # 加载多标签数据集
    train_dataset = CIFAR100_MultiLabel(
        root="./data", train=True, transform=transform, num_labels_per_sample=2
    )
    test_dataset = CIFAR100_MultiLabel(
        root="./data", train=False, transform=transform, num_labels_per_sample=2
    )
    
    # 数据加载器
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    return train_loader, test_loader, train_dataset.num_classes
2. 多标签分类网络设计
python 复制代码
class MultiLabelCNN(nn.Module):
    def __init__(self, num_classes=100):
        super(MultiLabelCNN, self).__init__()
        # 特征提取层(CNN)
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 16×16
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 8×8
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)  # 4×4
        )
        
        # 分类层(输出num_classes个logits,对应每个类别的预测)
        self.classifier = nn.Sequential(
            nn.Flatten(),  # 128×4×4=2048
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)  # 输出层:num_classes个logits
        )
        
        # 多标签分类的激活函数(Sigmoid,推理时用)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, return_probs=False):
        x = self.features(x)
        x = self.classifier(x)  # 输出logits(用于训练,配合BCEWithLogitsLoss)
        
        if return_probs:
            # 推理时返回概率(Sigmoid激活)
            x = self.sigmoid(x)
        return x
3. 训练与推理(多标签场景适配)
python 复制代码
def train_multilabel_model(model, train_loader, criterion, optimizer, device, epochs=3):
    model.train()
    model.to(device)
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # 前向传播(输出logits)
            outputs = model(inputs)
            
            # 计算损失(BCEWithLogitsLoss:适用于多标签分类)
            loss = criterion(outputs, labels)
            
            # 反向传播与优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
        
        # 计算epoch损失
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {epoch_loss:.4f}")
    
    return model

def infer_multilabel_model(model, test_loader, device, threshold=0.5):
    """多标签推理:用阈值(如0.5)判断是否属于该类"""
    model.eval()
    model.to(device)
    
    # 多标签分类的评估指标:精确率、召回率、F1分数
    true_positives = 0  # 预测为1且真实为1
    false_positives = 0  # 预测为1且真实为0
    false_negatives = 0  # 预测为0且真实为1
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            # 推理:获取概率(Sigmoid激活)
            probs = model(inputs, return_probs=True)
            
            # 按阈值生成预测标签(0或1)
            preds = (probs > threshold).float()
            
            # 计算评估指标
            true_positives += (preds * labels).sum().item()
            false_positives += (preds * (1 - labels)).sum().item()
            false_negatives += ((1 - preds) * labels).sum().item()
        
        # 计算精确率、召回率、F1
        precision = true_positives / (true_positives + false_positives + 1e-8)  # 加1e-8避免除零
        recall = true_positives / (true_positives + false_negatives + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        
        print(f"\n多标签推理结果(阈值={threshold}):")
        print(f"精确率(Precision): {precision:.4f}")
        print(f"召回率(Recall): {recall:.4f}")
        print(f"F1分数: {f1:.4f}")
        
        # 展示部分样本结果
        print("\n部分样本预测结果:")
        inputs, labels = next(iter(test_loader))
        inputs, labels = inputs[:5].to(device), labels[:5].to(device)
        probs = model(inputs, return_probs=True)
        preds = (probs > threshold).float()
        
        for i in range(5):
            # 提取真实标签和预测标签的类别索引
            true_labels = torch.where(labels[i] == 1)[0].numpy()
            pred_labels = torch.where(preds[i] == 1)[0].numpy()
            
            print(f"样本{i+1}:")
            print(f"  真实标签: {true_labels}")
            print(f"  预测概率: {probs[i][true_labels].numpy().round(4)}")
            print(f"  预测标签: {pred_labels}")
    
    return precision, recall, f1

# 主函数:多标签分类完整流程
if __name__ == "__main__":
    # 设备配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    # 1. 加载数据
    train_loader, test_loader, num_classes = get_multilabel_dataloaders(batch_size=64)
    print(f"多标签类别数: {num_classes}")
    
    # 2. 初始化模型、损失函数、优化器
    model = MultiLabelCNN(num_classes=num_classes)
    # 多标签损失函数:BCEWithLogitsLoss(输入为logits,自动加Sigmoid)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    # 3. 训练模型
    print("\n开始训练多标签模型...")
    model = train_multilabel_model(model, train_loader, criterion, optimizer, device, epochs=3)
    
    # 4. 推理模型
    print("\n开始多标签推理...")
    infer_multilabel_model(model, test_loader, device, threshold=0.5)

多标签分类核心适配点

  1. 输出层激活函数:推理时用Sigmoid(每个输出独立预测"是否属于该类");
  2. 损失函数 :用nn.BCEWithLogitsLoss(二元交叉熵,支持多标签场景);
  3. 标签格式 :标签为one-hot向量(如[1,0,1]),而非类别索引;
  4. 推理决策:用阈值(如0.5)判断是否属于该类,而非argmax(多标签非互斥);
  5. 评估指标:用精确率、召回率、F1分数(而非准确率,准确率不适用于多标签)。

三、多任务分类的PyTorch实现

多任务分类的核心是网络输出多个分支 (每个分支对应一个任务),并联合优化多个任务的损失(加权求和)。

1. 多任务网络设计(以"图像类别+风格"双任务为例)
python 复制代码
class MultiTaskCNN(nn.Module):
    def __init__(self, num_class_task1=10, num_class_task2=2):
        super(MultiTaskCNN, self).__init__()
        # 共享特征提取层(两个任务共用)
        self.shared_features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2)
        )
        
        # 任务1分支(图像类别分类,多类别任务)
        self.task1_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_class_task1)  # 输出logits(无Softmax)
        )
        
        # 任务2分支(图像风格分类,二分类任务)
        self.task2_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1)  # 输出logits(无Sigmoid)
        )
        
        # 激活函数(推理时用)
        self.softmax = nn.Softmax(dim=1)  # 任务1:多类别
        self.sigmoid = nn.Sigmoid()        # 任务2:二分类
    
    def forward(self, x, return_probs=False):
        # 共享特征提取
        shared_x = self.shared_features(x)
        
        # 任务1输出(图像类别)
        task1_logits = self.task1_head(shared_x)
        
        # 任务2输出(图像风格)
        task2_logits = self.task2_head(shared_x)
        
        if return_probs:
            # 推理时返回概率
            task1_probs = self.softmax(task1_logits)
            task2_probs = self.sigmoid(task2_logits)
            return task1_probs, task2_probs
        
        # 训练时返回logits
        return task1_logits, task2_logits
2. 多任务训练与推理
python 复制代码
# 模拟多任务数据集(任务1:CIFAR-10类别,任务2:随机风格标签)
class CIFAR10_MultiTask(Dataset):
    def __init__(self, root, train=True, transform=None):
        self.base_dataset = CIFAR10(root=root, train=train, download=True, transform=transform)
        self.num_class_task1 = 10  # 任务1:CIFAR-10类别
        self.num_class_task2 = 2   # 任务2:风格(0=写实,1=卡通)
        
        # 生成任务2的风格标签(随机)
        self.task2_labels = np.random.randint(0, 2, len(self.base_dataset))
    
    def __len__(self):
        return len(self.base_dataset)
    
    def __getitem__(self, idx):
        img, task1_label = self.base_dataset[idx]
        task2_label = self.task2_labels[idx]
        
        # 转换为张量
        task1_label = torch.tensor(task1_label, dtype=torch.long)
        task2_label = torch.tensor(task2_label, dtype=torch.float32)
        
        return img, (task1_label, task2_label)

# 多任务训练函数
def train_multitask_model(model, train_loader, criterions, optimizer, device, epochs=3, task_weights=[0.5, 0.5]):
    """
    criterions: 损失函数列表([task1_criterion, task2_criterion])
    task_weights: 任务权重(平衡不同任务的损失)
    """
    model.train()
    model.to(device)
    
    task1_criterion, task2_criterion = criterions
    
    for epoch in range(epochs):
        running_task1_loss = 0.0
        running_task2_loss = 0.0
        
        for inputs, (task1_labels, task2_labels) in train_loader:
            inputs = inputs.to(device)
            task1_labels = task1_labels.to(device)
            task2_labels = task2_labels.to(device)
            
            # 前向传播(输出两个任务的logits)
            task1_logits, task2_logits = model(inputs)
            
            # 计算每个任务的损失
            task1_loss = task1_criterion(task1_logits, task1_labels)
            task2_loss = task2_criterion(task2_logits, task2_labels)
            
            # 联合损失(加权求和)
            total_loss = task_weights[0] * task1_loss + task_weights[1] * task2_loss
            
            # 反向传播与优化
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            
            # 累加损失
            running_task1_loss += task1_loss.item() * inputs.size(0)
            running_task2_loss += task2_loss.item() * inputs.size(0)
        
        # 计算epoch损失
        epoch_task1_loss = running_task1_loss / len(train_loader.dataset)
        epoch_task2_loss = running_task2_loss / len(train_loader.dataset)
        epoch_total_loss = task_weights[0] * epoch_task1_loss + task_weights[1] * epoch_task2_loss
        
        print(f"Epoch {epoch+1}/{epochs} | "
              f"Total Loss: {epoch_total_loss:.4f} | "
              f"Task1 Loss: {epoch_task1_loss:.4f} | "
              f"Task2 Loss: {epoch_task2_loss:.4f}")
    
    return model

# 多任务推理函数
def infer_multitask_model(model, test_loader, device):
    model.eval()
    model.to(device)
    
    # 任务1:多类别分类准确率
    task1_correct = 0
    task1_total = 0
    
    # 任务2:二分类准确率
    task2_correct = 0
    task2_total = 0
    
    with torch.no_grad():
        for inputs, (task1_labels, task2_labels) in test_loader:
            inputs = inputs.to(device)
            task1_labels = task1_labels.to(device)
            task2_labels = task2_labels.to(device)
            
            # 推理:获取概率
            task1_probs, task2_probs = model(inputs, return_probs=True)
            
            # 任务1:多类别预测(argmax)
            task1_preds = torch.argmax(task1_probs, dim=1)
            task1_correct += (task1_preds == task1_labels).sum().item()
            task1_total += task1_labels.size(0)
            
            # 任务2:二分类预测(阈值0.5)
            task2_preds = (task2_probs > 0.5).float().squeeze()
            task2_correct += (task2_preds == task2_labels).sum().item()
            task2_total += task2_labels.size(0)
        
        # 计算准确率
        task1_acc = task1_correct / task1_total
        task2_acc = task2_correct / task2_total
        
        print(f"\n多任务推理结果:")
        print(f"任务1(图像类别)准确率: {task1_acc:.4f}")
        print(f"任务2(图像风格)准确率: {task2_acc:.4f}")
    
    return task1_acc, task2_acc

# 主函数:多任务分类完整流程

if __name__ == "__main__":
    # 设备配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    # 1. 加载多任务数据集
    transform = Compose([
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    ])
    
    train_dataset = CIFAR10_MultiTask(root="./data", train=True, transform=transform)
    test_dataset = CIFAR10_MultiTask(root="./data", train=False, transform=transform)
    
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
    
    print(f"任务1类别数: {train_dataset.num_class_task1}(图像类别)")
    print(f"任务2类别数: {train_dataset.num_class_task2}(图像风格)")
    
    # 2. 初始化模型、损失函数、优化器
    model = MultiTaskCNN(
        num_class_task1=train_dataset.num_class_task1,
        num_class_task2=train_dataset.num_class_task2
    )
    
    # 不同任务使用不同损失函数
    criterions = [
        nn.CrossEntropyLoss(),  # 任务1:多类别分类
        nn.BCEWithLogitsLoss()  # 任务2:二分类(图像风格)
    ]
    
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    # 3. 训练模型(任务权重根据重要性调整,这里设为[0.6, 0.4])
    print("\n开始训练多任务模型...")
    model = train_multitask_model(
        model, train_loader, criterions, optimizer, device,
        epochs=3, task_weights=[0.6, 0.4]
    )
    
    # 4. 推理模型
    print("\n开始多任务推理...")
    infer_multitask_model(model, test_loader, device)

多任务学习核心适配点

  1. 网络分支设计:共享特征提取层(提高参数效率)+ 任务专属头(适配不同输出维度);
  2. 损失函数组合 :根据任务类型选择损失(如多类别用CrossEntropyLoss,二分类用BCEWithLogitsLoss);
  3. 联合损失计算 :通过任务权重(task_weights)平衡不同任务的损失贡献(避免某一任务主导优化);
  4. 推理适配 :不同任务采用对应决策方式(多类别用argmax,二分类用阈值)。

三、多输出分类的注意事项与进阶技巧

1. 任务权重调优
  • 问题:不同任务的损失值范围可能差异很大(如A任务损失≈10,B任务损失≈0.1),直接加权会导致优化偏向损失大的任务;
  • 解决方案
    • 初始权重设为1/K(K为任务数),训练中观察各任务性能,手动上调重要任务权重;
    • 动态权重策略:根据任务损失的反比自动调整(如wi=1/lossiw_i = 1/\text{loss}_iwi=1/lossi),避免人为调参;
    • 示例:图像分类+目标检测任务中,检测任务权重通常高于分类(因检测更复杂)。
2. 任务相关性考量
  • 正向迁移:任务间存在关联时(如"人脸检测"与"性别识别"),共享特征可提升双方性能;
  • 负向迁移:任务间无关联时(如"图像分类"与"文本情感"),强行共享特征会导致性能下降;
  • 建议:仅在任务有重叠特征时使用多任务学习(如视觉任务间共享CNN特征,语言任务间共享Transformer特征)。
3. 样本不平衡处理
  • 问题:多标签任务中,不同类别的正样本比例可能差异极大(如"猫"出现1000次,"老虎"仅出现10次);
  • 解决方案
    • 损失函数中添加类别权重(nn.BCEWithLogitsLoss(weight=class_weights));
    • 对稀有类别进行过采样,或对常见类别进行欠采样;
    • 推理时降低稀有类别的决策阈值(如"老虎"用0.3而非0.5)。
4. 评估指标选择
任务类型 不适用指标 推荐指标
多标签分类 准确率(Accuracy) 精确率(Precision)、召回率(Recall)、F1分数、Hamming距离
多任务分类 单一指标 各任务独立指标(如任务1准确率、任务2F1)+ 平均指标

四、多输出分类与传统分类的对比总结

维度 传统多类别分类(单标签) 多输出分类(多标签/多任务)
标签特性 互斥(仅一个标签) 非互斥(多个标签/任务)
输出层激活函数 无(训练)/Softmax(推理) 无(训练)/Sigmoid(推理)
损失函数 CrossEntropyLoss BCEWithLogitsLoss(多标签)/ 多损失组合(多任务)
决策方式 argmax(取概率最大类别) 阈值判断(多标签)/ 各任务独立决策(多任务)
典型应用 ImageNet分类、手写数字识别 图像标注、多属性预测、联合任务学习
核心挑战 类别不平衡、类别混淆 任务权重平衡、负向迁移、样本稀疏

总结

激活函数与多类别处理是深度学习模型设计的核心环节:

  • 激活函数通过引入非线性赋予网络复杂建模能力,ReLU系列(含GELU)凭借高效性成为主流,需根据网络深度、任务类型选择适配函数;
  • 多类别分类通过Softmax实现概率归一化,配合CrossEntropyLoss完成训练,需注意输出层维度与类别数匹配,避免手动添加Softmax导致的损失计算错误;
  • 多输出分类(多标签/多任务)扩展了传统分类的适用范围,需通过Sigmoid激活、BCE损失(多标签)或多损失组合(多任务)实现,核心是平衡任务权重与特征共享策略。

实际应用中,需结合具体任务场景选择合适的技术方案,并通过可视化与 ablation study 验证关键组件的有效性。

相关推荐
加油20194 小时前
如何快速学习一个网络协议?
网络·网络协议·学习·方法论
A9better5 小时前
嵌入式开发学习日志36——stm32之USART串口通信前述
stm32·单片机·嵌入式硬件·学习
不太可爱的叶某人5 小时前
【学习笔记】kafka权威指南——第6章 可靠的数据传递
笔记·学习·kafka
~kiss~7 小时前
K-means损失函数-收敛证明
算法·机器学习·kmeans
2301_790994998 小时前
仿神秘海域/美末环境交互的程序化动画学习
学习·microsoft·交互
能不能别报错8 小时前
K8s学习笔记(十六) 探针(Probe)
笔记·学习·kubernetes
初圣魔门首席弟子8 小时前
C++ STL 向量(vector)学习笔记:从基础到实战
c++·笔记·学习
qiangshang9901268 小时前
WPF+MVVM入门学习
学习·wpf
iconball8 小时前
个人用云计算学习笔记 --20 (Nginx 服务器)
linux·运维·笔记·学习·云计算