Mean-Teacher 均值教师自训练框架详解

Mean-Teacher 均值教师自训练框架详解

一、Mean-Teacher 核心原理详细描述

1. 设计背景

半监督学习场景:少量标注数据 + 大量无标注数据。传统伪标签方法单模型预测噪声大、易过拟合无标注样本;Mean-Teacher 提出双模型架构(学生模型 + 教师模型),用教师平滑预测作为软伪标签监督学生,大幅降低伪标签噪声。

2. 两大核心模型

  1. Student 学生网络
    • 可训练、带梯度更新;
    • 输入加随机数据增强(强增广);
    • 同时接收标注/无标注样本,损失包含监督损失 + 一致性正则损失。
  2. Teacher 教师网络
    • 无梯度、不直接训练 ;权重由学生权重指数移动平均EMA 缓慢更新:
      θt=α⋅θt+(1−α)⋅θs\theta_t = \alpha \cdot \theta_t + (1-\alpha) \cdot \theta_sθt=α⋅θt+(1−α)⋅θs
      α\alphaα:EMA衰减系数,常用0.99~0.999;训练前期可退火逐步提升。
    • 输入仅弱数据增强;
    • 输出作为稳定、平滑的目标标签,用来约束学生模型输出分布一致。

3. 完整损失函数

总损失 = 标注监督分类损失 + 无标注一致性损失

Ltotal=Lsup+λ(t)⋅Lconsist \mathcal{L}{total} = \mathcal{L}{sup} + \lambda(t) \cdot \mathcal{L}_{consist} Ltotal=Lsup+λ(t)⋅Lconsist

  1. 监督损失 Lsup\mathcal{L}_{sup}Lsup:仅作用有标签样本,交叉熵分类损失;
  2. 一致性损失 Lconsist\mathcal{L}_{consist}Lconsist:MSE/KL散度,约束学生强增广输出分布 ≈ 教师弱增广输出分布;
  3. λ(t)\lambda(t)λ(t):一致性权重退火系数:训练前期权重小,后期放大正则约束,避免前期噪声干扰。

4. 训练流程步骤

  1. 初始化学生网络、教师网络(复制学生初始权重,关闭教师梯度);
  2. 每迭代取一批混合数据:有标签+无标签;
  3. 样本分支处理:
    • 标注样本:弱增强送入学生,计算分类交叉熵;
    • 无标注样本:强增强输入学生弱增强输入教师
  4. 前向传播:
    • 学生输出:带噪声强增广预测;
    • 教师输出:停止梯度,平滑稳定预测(伪标签目标);
  5. 计算一致性MSE损失,叠加监督损失;
  6. 反向传播只更新学生网络;
  7. EMA更新教师权重(无梯度);
  8. 迭代至收敛,推理只用教师模型(泛化更好)。

5. 关键创新点总结

  • EMA教师权重:缓慢平滑权重,预测更稳定,缓解伪标签漂移;
  • 强弱双增强分离:学生强增广提升鲁棒性,教师弱增广保证目标可靠;
  • 一致性正则化:利用无标注数据约束模型输出不变性;
  • 损失权重退火:解决训练初期模型预测不可靠的问题。

二、PyTorch 代码

完整代码(MNIST半监督任务,少量标注+大量无标注)

python 复制代码
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm

# ===================== 超参数配置 =====================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 50
BATCH_SIZE = 128
LR = 0.002
EMA_ALPHA = 0.99  # EMA衰减系数
LAMBDA_MAX = 10.0 # 一致性损失最大权重
ANNEAL_STEPS = EPOCHS * 500  # 退火总步数
NUM_LABELED = 1000  # 仅使用1000张标注MNIST,其余为无标注

# ===================== 1. 简单CNN backbone(学生/教师共用) =====================
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits

# ===================== 2. EMA 教师更新工具函数 =====================
def update_teacher(student_model, teacher_model, alpha):
    for s_param, t_param in zip(student_model.parameters(), teacher_model.parameters()):
        t_param.data = alpha * t_param.data + (1.0 - alpha) * s_param.data

# 计算一致性损失权重退火
def get_consistency_weight(current_step):
    # sigmoid退火,0~LAMBDA_MAX
    rampup = np.exp(-5.0 * (1.0 - current_step / ANNEAL_STEPS) ** 2)
    return LAMBDA_MAX * rampup

# ===================== 3. 数据增强:弱增强 + 强增强 =====================
weak_aug = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

strong_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 混合数据集:区分标注/无标注样本
class SemiMNIST(Dataset):
    def __init__(self, full_dataset, labeled_mask):
        self.data = full_dataset
        self.labeled_mask = labeled_mask  # True=有标签, False=无标签

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        is_labeled = self.labeled_mask[idx]
        if is_labeled:
            img_weak = weak_aug(img)
            img_strong = strong_aug(img)
            return img_weak, img_strong, label, is_labeled
        else:
            img_weak_t = weak_aug(img)   # 教师输入弱增强
            img_strong_s = strong_aug(img)# 学生输入强增强
            return img_strong_s, img_weak_t, -1, is_labeled

# ===================== 4. 构建半监督MNIST数据集 =====================
def build_semi_mnist():
    train_full = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
    total_train = len(train_full)
    # 随机划分标注/无标注
    all_indices = np.arange(total_train)
    np.random.shuffle(all_indices)
    labeled_idx = all_indices[:NUM_LABELED]
    labeled_mask = np.zeros(total_train, dtype=bool)
    labeled_mask[labeled_idx] = True

    semi_train = SemiMNIST(train_full, labeled_mask)
    test_set = datasets.MNIST(root="./data", train=False, transform=weak_aug)
    train_loader = DataLoader(semi_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    return train_loader, test_loader

# ===================== 5. 训练主逻辑 =====================
def train_mean_teacher():
    # 初始化双网络
    student = ConvNet().to(DEVICE)
    teacher = ConvNet().to(DEVICE)
    # 教师初始权重复制学生,冻结梯度
    for t_param in teacher.parameters():
        t_param.requires_grad = False
    # 优化器仅更新学生
    opt = torch.optim.Adam(student.parameters(), lr=LR)
    train_loader, test_loader = build_semi_mnist()
    global_step = 0

    for epoch in range(EPOCHS):
        student.train()
        total_loss_epoch = 0.0
        sup_loss_epoch = 0.0
        cons_loss_epoch = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for batch in pbar:
            opt.zero_grad()
            # 拆分批次数据
            x1, x2, labels, is_labeled = batch
            x1, x2, labels = x1.to(DEVICE), x2.to(DEVICE), labels.to(DEVICE)
            batch_size = x1.shape[0]

            # 前向传播学生网络(全部样本强增广输入学生)
            student_logits = student(x1)
            with torch.no_grad():
                # 教师仅弱增广输入,停止梯度
                teacher_logits = teacher(x2)
                teacher_probs = F.softmax(teacher_logits, dim=1)

            # 1. 监督损失:仅标注样本
            sup_mask = is_labeled.to(DEVICE)
            sup_loss = 0.0
            if torch.sum(sup_mask) > 0:
                sup_logits = student_logits[sup_mask]
                sup_labels = labels[sup_mask]
                sup_loss = F.cross_entropy(sup_logits, sup_labels)

            # 2. 一致性损失:全部样本(标注+无标注都约束分布一致)
            student_probs = F.softmax(student_logits, dim=1)
            cons_loss = F.mse_loss(student_probs, teacher_probs.detach())

            # 3. 加权总损失
            cons_weight = get_consistency_weight(global_step)
            total_loss = sup_loss + cons_weight * cons_loss

            # 反向传播更新学生
            total_loss.backward()
            opt.step()
            # EMA平滑更新教师权重
            update_teacher(student, teacher, EMA_ALPHA)

            # 统计损失
            total_loss_epoch += total_loss.item()
            sup_loss_epoch += sup_loss.item()
            cons_loss_epoch += cons_loss.item() * cons_weight
            global_step += 1

            pbar.set_postfix({
                "total_loss": f"{total_loss.item():.4f}",
                "sup_loss": f"{sup_loss.item():.4f}",
                "cons_loss": f"{cons_weight * cons_loss.item():.4f}",
                "cons_w": f"{cons_weight:.2f}"
            })

        # 打印epoch平均损失
        avg_total = total_loss_epoch / len(train_loader)
        avg_sup = sup_loss_epoch / len(train_loader)
        avg_cons = cons_loss_epoch / len(train_loader)
        print(f"\n[Epoch {epoch+1}] Avg Loss: total={avg_total:.4f}, sup={avg_sup:.4f}, cons={avg_cons:.4f}")

        # 测试阶段:使用教师模型评估(泛化更强)
        teacher.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for img, lab in test_loader:
                img, lab = img.to(DEVICE), lab.to(DEVICE)
                logits = teacher(img)
                pred = torch.argmax(logits, dim=1)
                correct += (pred == lab).sum().item()
                total += lab.size(0)
        acc = 100 * correct / total
        print(f"Test Acc (Teacher Model): {acc:.2f}%\n")

if __name__ == "__main__":
    train_mean_teacher()

三、代码模块逐段解释

1. 网络模块 ConvNet

轻量2层CNN分类网络,学生、教师完全同结构,权重独立初始化后复制。

2. EMA教师更新 update_teacher

仅在学生参数更新后执行,无梯度参与,教师权重缓慢跟随学生平滑移动,避免单步梯度剧烈波动带来的预测噪声。

3. 一致性权重退火 get_consistency_weight

训练前期一致性权重趋近0,优先学习标注数据;训练中后期权重上升,利用大量无标注数据做一致性正则,防止前期模型预测不可靠带来错误监督。

4. 双增强策略

  • 弱增强:轻微归一化,给教师输入,保证目标标签稳定;
  • 强增强:翻转、仿射变换,给学生输入,迫使模型对图像扰动保持输出一致。

5. 数据集 SemiMNIST

自动区分标注/无标注样本,对两类样本分别返回对应强弱增强图像,统一送入训练循环。

6. 损失计算细节

  1. 监督损失仅在有标签样本上计算交叉熵;
  2. 一致性MSE损失对全部样本生效,有标签样本也会额外加一致性约束,进一步提升鲁棒性;
  3. 教师输出全程 torch.no_grad(),不产生梯度,仅作为固定目标。

7. 推理规则

测试时只用教师模型:EMA平滑后的权重泛化性能显著优于实时更新的学生模型,是Mean-Teacher标准推理方案。

四、调优关键技巧

  1. EMA α:分类任务0.99~0.999,数值越大教师更新越慢、预测越平滑;
  2. 一致性权重LAMBDA_MAX:图像分类常用5~20,过小无正则效果,过大训练震荡;
  3. 数据增强差距:学生增强越强,一致性约束收益越高;
  4. 训练策略:前期降低学习率、缓慢提升一致性权重,防止模型崩溃;
  5. 损失替换:可将MSE替换KL散度,对分类概率分布约束效果接近。

五、运行效果说明

  • 数据集:MNIST仅1000张标注,其余59000张无标注;
  • 基线(只用1000标注无自训练)测试精度约75%~82%;
  • Mean-Teacher训练后教师模型测试精度可达93%+,充分验证半监督一致性正则收益。
相关推荐
星空露珠2 小时前
迷你世界UGc3.0脚本Wiki[剧情动画模块管理接口 Timeline]
开发语言·数据结构·算法·游戏·lua
笨笨没好名字2 小时前
Leetcode刷题python3版第一周(下)
linux·算法·leetcode
手写码匠2 小时前
手写 LLM 安全护栏:从内容审核到越狱防御的完整实现
人工智能·深度学习·算法·aigc
luj_17682 小时前
草酸与烟酸对消化及糖代谢的影响解析
服务器·c语言·开发语言·经验分享·算法
青风972 小时前
16-ADAPTRACK:基于自适应阈值的多目标跟踪匹配算法
人工智能·算法·目标跟踪
汤姆yu3 小时前
macOS系统下Aider完整安装、配置与实战使用教程
大数据·人工智能·算法·macos·github·copilot
Sam09273 小时前
【AI 算法精讲 14】TF-IDF:词频与逆文档频率
人工智能·python·算法·ai
AI科技星3 小时前
拓扑生命系统确定性理论:基于32维流形的遗传密码起源与衰老动力学( 中英双语顶刊终稿·标准数学符号)
开发语言·网络·人工智能·算法·机器学习·乖乖数学·全域数学
编程圈子3 小时前
电机驱动开发学习18. SVPWM空间矢量调制算法详解与实现
驱动开发·学习·算法