Mean-Teacher 均值教师自训练框架详解
一、Mean-Teacher 核心原理详细描述
1. 设计背景
半监督学习场景:少量标注数据 + 大量无标注数据。传统伪标签方法单模型预测噪声大、易过拟合无标注样本;Mean-Teacher 提出双模型架构(学生模型 + 教师模型),用教师平滑预测作为软伪标签监督学生,大幅降低伪标签噪声。
2. 两大核心模型
- Student 学生网络
- 可训练、带梯度更新;
- 输入加随机数据增强(强增广);
- 同时接收标注/无标注样本,损失包含监督损失 + 一致性正则损失。
- Teacher 教师网络
- 无梯度、不直接训练 ;权重由学生权重指数移动平均EMA 缓慢更新:
θt=α⋅θt+(1−α)⋅θs\theta_t = \alpha \cdot \theta_t + (1-\alpha) \cdot \theta_sθt=α⋅θt+(1−α)⋅θs
α\alphaα:EMA衰减系数,常用0.99~0.999;训练前期可退火逐步提升。 - 输入仅弱数据增强;
- 输出作为稳定、平滑的目标标签,用来约束学生模型输出分布一致。
- 无梯度、不直接训练 ;权重由学生权重指数移动平均EMA 缓慢更新:
3. 完整损失函数
总损失 = 标注监督分类损失 + 无标注一致性损失
Ltotal=Lsup+λ(t)⋅Lconsist \mathcal{L}{total} = \mathcal{L}{sup} + \lambda(t) \cdot \mathcal{L}_{consist} Ltotal=Lsup+λ(t)⋅Lconsist
- 监督损失 Lsup\mathcal{L}_{sup}Lsup:仅作用有标签样本,交叉熵分类损失;
- 一致性损失 Lconsist\mathcal{L}_{consist}Lconsist:MSE/KL散度,约束学生强增广输出分布 ≈ 教师弱增广输出分布;
- λ(t)\lambda(t)λ(t):一致性权重退火系数:训练前期权重小,后期放大正则约束,避免前期噪声干扰。
4. 训练流程步骤
- 初始化学生网络、教师网络(复制学生初始权重,关闭教师梯度);
- 每迭代取一批混合数据:有标签+无标签;
- 样本分支处理:
- 标注样本:弱增强送入学生,计算分类交叉熵;
- 无标注样本:强增强输入学生 ,弱增强输入教师;
- 前向传播:
- 学生输出:带噪声强增广预测;
- 教师输出:停止梯度,平滑稳定预测(伪标签目标);
- 计算一致性MSE损失,叠加监督损失;
- 反向传播只更新学生网络;
- EMA更新教师权重(无梯度);
- 迭代至收敛,推理只用教师模型(泛化更好)。
5. 关键创新点总结
- EMA教师权重:缓慢平滑权重,预测更稳定,缓解伪标签漂移;
- 强弱双增强分离:学生强增广提升鲁棒性,教师弱增广保证目标可靠;
- 一致性正则化:利用无标注数据约束模型输出不变性;
- 损失权重退火:解决训练初期模型预测不可靠的问题。
二、PyTorch 代码
完整代码(MNIST半监督任务,少量标注+大量无标注)
python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import datasets, transforms
from tqdm import tqdm
# ===================== 超参数配置 =====================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 50
BATCH_SIZE = 128
LR = 0.002
EMA_ALPHA = 0.99 # EMA衰减系数
LAMBDA_MAX = 10.0 # 一致性损失最大权重
ANNEAL_STEPS = EPOCHS * 500 # 退火总步数
NUM_LABELED = 1000 # 仅使用1000张标注MNIST,其余为无标注
# ===================== 1. 简单CNN backbone(学生/教师共用) =====================
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
logits = self.fc2(x)
return logits
# ===================== 2. EMA 教师更新工具函数 =====================
def update_teacher(student_model, teacher_model, alpha):
for s_param, t_param in zip(student_model.parameters(), teacher_model.parameters()):
t_param.data = alpha * t_param.data + (1.0 - alpha) * s_param.data
# 计算一致性损失权重退火
def get_consistency_weight(current_step):
# sigmoid退火,0~LAMBDA_MAX
rampup = np.exp(-5.0 * (1.0 - current_step / ANNEAL_STEPS) ** 2)
return LAMBDA_MAX * rampup
# ===================== 3. 数据增强:弱增强 + 强增强 =====================
weak_aug = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
strong_aug = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 混合数据集:区分标注/无标注样本
class SemiMNIST(Dataset):
def __init__(self, full_dataset, labeled_mask):
self.data = full_dataset
self.labeled_mask = labeled_mask # True=有标签, False=无标签
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img, label = self.data[idx]
is_labeled = self.labeled_mask[idx]
if is_labeled:
img_weak = weak_aug(img)
img_strong = strong_aug(img)
return img_weak, img_strong, label, is_labeled
else:
img_weak_t = weak_aug(img) # 教师输入弱增强
img_strong_s = strong_aug(img)# 学生输入强增强
return img_strong_s, img_weak_t, -1, is_labeled
# ===================== 4. 构建半监督MNIST数据集 =====================
def build_semi_mnist():
train_full = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
total_train = len(train_full)
# 随机划分标注/无标注
all_indices = np.arange(total_train)
np.random.shuffle(all_indices)
labeled_idx = all_indices[:NUM_LABELED]
labeled_mask = np.zeros(total_train, dtype=bool)
labeled_mask[labeled_idx] = True
semi_train = SemiMNIST(train_full, labeled_mask)
test_set = datasets.MNIST(root="./data", train=False, transform=weak_aug)
train_loader = DataLoader(semi_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
return train_loader, test_loader
# ===================== 5. 训练主逻辑 =====================
def train_mean_teacher():
# 初始化双网络
student = ConvNet().to(DEVICE)
teacher = ConvNet().to(DEVICE)
# 教师初始权重复制学生,冻结梯度
for t_param in teacher.parameters():
t_param.requires_grad = False
# 优化器仅更新学生
opt = torch.optim.Adam(student.parameters(), lr=LR)
train_loader, test_loader = build_semi_mnist()
global_step = 0
for epoch in range(EPOCHS):
student.train()
total_loss_epoch = 0.0
sup_loss_epoch = 0.0
cons_loss_epoch = 0.0
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
for batch in pbar:
opt.zero_grad()
# 拆分批次数据
x1, x2, labels, is_labeled = batch
x1, x2, labels = x1.to(DEVICE), x2.to(DEVICE), labels.to(DEVICE)
batch_size = x1.shape[0]
# 前向传播学生网络(全部样本强增广输入学生)
student_logits = student(x1)
with torch.no_grad():
# 教师仅弱增广输入,停止梯度
teacher_logits = teacher(x2)
teacher_probs = F.softmax(teacher_logits, dim=1)
# 1. 监督损失:仅标注样本
sup_mask = is_labeled.to(DEVICE)
sup_loss = 0.0
if torch.sum(sup_mask) > 0:
sup_logits = student_logits[sup_mask]
sup_labels = labels[sup_mask]
sup_loss = F.cross_entropy(sup_logits, sup_labels)
# 2. 一致性损失:全部样本(标注+无标注都约束分布一致)
student_probs = F.softmax(student_logits, dim=1)
cons_loss = F.mse_loss(student_probs, teacher_probs.detach())
# 3. 加权总损失
cons_weight = get_consistency_weight(global_step)
total_loss = sup_loss + cons_weight * cons_loss
# 反向传播更新学生
total_loss.backward()
opt.step()
# EMA平滑更新教师权重
update_teacher(student, teacher, EMA_ALPHA)
# 统计损失
total_loss_epoch += total_loss.item()
sup_loss_epoch += sup_loss.item()
cons_loss_epoch += cons_loss.item() * cons_weight
global_step += 1
pbar.set_postfix({
"total_loss": f"{total_loss.item():.4f}",
"sup_loss": f"{sup_loss.item():.4f}",
"cons_loss": f"{cons_weight * cons_loss.item():.4f}",
"cons_w": f"{cons_weight:.2f}"
})
# 打印epoch平均损失
avg_total = total_loss_epoch / len(train_loader)
avg_sup = sup_loss_epoch / len(train_loader)
avg_cons = cons_loss_epoch / len(train_loader)
print(f"\n[Epoch {epoch+1}] Avg Loss: total={avg_total:.4f}, sup={avg_sup:.4f}, cons={avg_cons:.4f}")
# 测试阶段:使用教师模型评估(泛化更强)
teacher.eval()
correct = 0
total = 0
with torch.no_grad():
for img, lab in test_loader:
img, lab = img.to(DEVICE), lab.to(DEVICE)
logits = teacher(img)
pred = torch.argmax(logits, dim=1)
correct += (pred == lab).sum().item()
total += lab.size(0)
acc = 100 * correct / total
print(f"Test Acc (Teacher Model): {acc:.2f}%\n")
if __name__ == "__main__":
train_mean_teacher()
三、代码模块逐段解释
1. 网络模块 ConvNet
轻量2层CNN分类网络,学生、教师完全同结构,权重独立初始化后复制。
2. EMA教师更新 update_teacher
仅在学生参数更新后执行,无梯度参与,教师权重缓慢跟随学生平滑移动,避免单步梯度剧烈波动带来的预测噪声。
3. 一致性权重退火 get_consistency_weight
训练前期一致性权重趋近0,优先学习标注数据;训练中后期权重上升,利用大量无标注数据做一致性正则,防止前期模型预测不可靠带来错误监督。
4. 双增强策略
- 弱增强:轻微归一化,给教师输入,保证目标标签稳定;
- 强增强:翻转、仿射变换,给学生输入,迫使模型对图像扰动保持输出一致。
5. 数据集 SemiMNIST
自动区分标注/无标注样本,对两类样本分别返回对应强弱增强图像,统一送入训练循环。
6. 损失计算细节
- 监督损失仅在有标签样本上计算交叉熵;
- 一致性MSE损失对全部样本生效,有标签样本也会额外加一致性约束,进一步提升鲁棒性;
- 教师输出全程
torch.no_grad(),不产生梯度,仅作为固定目标。
7. 推理规则
测试时只用教师模型:EMA平滑后的权重泛化性能显著优于实时更新的学生模型,是Mean-Teacher标准推理方案。
四、调优关键技巧
- EMA α:分类任务0.99~0.999,数值越大教师更新越慢、预测越平滑;
- 一致性权重LAMBDA_MAX:图像分类常用5~20,过小无正则效果,过大训练震荡;
- 数据增强差距:学生增强越强,一致性约束收益越高;
- 训练策略:前期降低学习率、缓慢提升一致性权重,防止模型崩溃;
- 损失替换:可将MSE替换KL散度,对分类概率分布约束效果接近。
五、运行效果说明
- 数据集:MNIST仅1000张标注,其余59000张无标注;
- 基线(只用1000标注无自训练)测试精度约75%~82%;
- Mean-Teacher训练后教师模型测试精度可达93%+,充分验证半监督一致性正则收益。