文章目录
-
- 一、标注成本的现实困境
- 二、半监督学习三大范式
-
- [2.1 自训练(Self-Training)](#2.1 自训练(Self-Training))
- [2.2 协同训练(Co-Training)](#2.2 协同训练(Co-Training))
- [2.3 标签传播(Label Propagation)](#2.3 标签传播(Label Propagation))
- 三、自监督学习:零标注的预训练革命
-
- [3.1 对比学习(Contrastive Learning)](#3.1 对比学习(Contrastive Learning))
- [3.2 掩码自编码(Masked Autoencoding)](#3.2 掩码自编码(Masked Autoencoding))
- [四、sklearn 半监督工具速查](#四、sklearn 半监督工具速查)
- [五、半监督 + 自监督的组合策略](#五、半监督 + 自监督的组合策略)
- 六、适用场景决策框架
- 七、实战:文本分类(少量标注场景)
- 八、工程实践中的实用原则
一、标注成本的现实困境
机器学习工程师都知道一个事实:模型性能的上限由数据质量决定,而数据质量的瓶颈往往不是数据量,而是标注量。
标注一条数据需要多少钱?在不同领域,答案差距悬殊:
- 图像分类(ImageNet 类):0.1-1 元 / 条
- 医疗影像标注(需放射科医生确认):100-500 元 / 条
- 法律文本分类(需律师审阅):50-200 元 / 条
- 工业缺陷标注(需领域专家):10-100 元 / 条
- 自动驾驶场景标注(3D 边框 + 轨迹):1000+ 元 / 条
10 万条医疗影像标注 = 1000 万-5000 万元的成本。这不是一个数学问题,而是一个现实障碍。
但无标注数据呢?互联网每天产生数十亿张图片、数万亿条文本、数百亿条行为日志------这些数据是免费的。
这就是半监督学习和自监督学习存在的根本动机:
- 半监督学习:少量标注数据 + 大量无标注数据 → 性能显著优于纯监督学习
- 自监督学习:零标注 → 从数据自身结构创造训练信号 → 学到可迁移的通用表示
两者解决的是同一个问题,但从不同角度切入。
#mermaid-svg-kda9i9lVlzA4VKon{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-kda9i9lVlzA4VKon .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-kda9i9lVlzA4VKon .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-kda9i9lVlzA4VKon .error-icon{fill:#552222;}#mermaid-svg-kda9i9lVlzA4VKon .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-kda9i9lVlzA4VKon .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-kda9i9lVlzA4VKon .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-kda9i9lVlzA4VKon .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-kda9i9lVlzA4VKon .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-kda9i9lVlzA4VKon .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-kda9i9lVlzA4VKon .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-kda9i9lVlzA4VKon .marker{fill:#333333;stroke:#333333;}#mermaid-svg-kda9i9lVlzA4VKon .marker.cross{stroke:#333333;}#mermaid-svg-kda9i9lVlzA4VKon svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-kda9i9lVlzA4VKon p{margin:0;}#mermaid-svg-kda9i9lVlzA4VKon .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-kda9i9lVlzA4VKon .cluster-label text{fill:#333;}#mermaid-svg-kda9i9lVlzA4VKon .cluster-label span{color:#333;}#mermaid-svg-kda9i9lVlzA4VKon .cluster-label span p{background-color:transparent;}#mermaid-svg-kda9i9lVlzA4VKon .label text,#mermaid-svg-kda9i9lVlzA4VKon span{fill:#333;color:#333;}#mermaid-svg-kda9i9lVlzA4VKon .node rect,#mermaid-svg-kda9i9lVlzA4VKon .node circle,#mermaid-svg-kda9i9lVlzA4VKon .node ellipse,#mermaid-svg-kda9i9lVlzA4VKon .node polygon,#mermaid-svg-kda9i9lVlzA4VKon .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-kda9i9lVlzA4VKon .rough-node .label text,#mermaid-svg-kda9i9lVlzA4VKon .node .label text,#mermaid-svg-kda9i9lVlzA4VKon .image-shape .label,#mermaid-svg-kda9i9lVlzA4VKon .icon-shape .label{text-anchor:middle;}#mermaid-svg-kda9i9lVlzA4VKon .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-kda9i9lVlzA4VKon .rough-node .label,#mermaid-svg-kda9i9lVlzA4VKon .node .label,#mermaid-svg-kda9i9lVlzA4VKon .image-shape .label,#mermaid-svg-kda9i9lVlzA4VKon .icon-shape .label{text-align:center;}#mermaid-svg-kda9i9lVlzA4VKon .node.clickable{cursor:pointer;}#mermaid-svg-kda9i9lVlzA4VKon .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-kda9i9lVlzA4VKon .arrowheadPath{fill:#333333;}#mermaid-svg-kda9i9lVlzA4VKon .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-kda9i9lVlzA4VKon .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-kda9i9lVlzA4VKon .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-kda9i9lVlzA4VKon .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-kda9i9lVlzA4VKon .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-kda9i9lVlzA4VKon .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-kda9i9lVlzA4VKon .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-kda9i9lVlzA4VKon .cluster text{fill:#333;}#mermaid-svg-kda9i9lVlzA4VKon .cluster span{color:#333;}#mermaid-svg-kda9i9lVlzA4VKon div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-kda9i9lVlzA4VKon .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-kda9i9lVlzA4VKon rect.text{fill:none;stroke-width:0;}#mermaid-svg-kda9i9lVlzA4VKon .icon-shape,#mermaid-svg-kda9i9lVlzA4VKon .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-kda9i9lVlzA4VKon .icon-shape p,#mermaid-svg-kda9i9lVlzA4VKon .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-kda9i9lVlzA4VKon .icon-shape .label rect,#mermaid-svg-kda9i9lVlzA4VKon .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-kda9i9lVlzA4VKon .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-kda9i9lVlzA4VKon .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-kda9i9lVlzA4VKon :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 数据现实
大量无标注数据
成本几乎为零
少量标注数据
成本极高
自监督学习
从无标注数据中挖掘
训练信号
纯监督学习
只用标注数据
性能受限
半监督学习
少量标注 + 大量无标注
协同提升
通用表示预训练
下游任务微调
少量标注即可达高性能
直接提升目标任务性能
二、半监督学习三大范式
半监督学习依赖两个核心假设,这两个假设决定了方法的适用边界:
平滑假设(Smoothness Assumption):特征空间中距离相近的点,标签也应相近。如果两个数据点通过高密度区域相连,它们的标签应当一致。
聚类假设(Cluster Assumption):数据分布有天然的聚类结构,同一簇内的点倾向于有相同标签。决策边界应该穿过低密度区域,而不是穿过密集的数据区域。
基于这两个假设,形成了三种主流范式:
2.1 自训练(Self-Training)
核心思路:"自信的学生教自己"
初始模型(用100条标注数据训练)
↓
对所有无标注数据做预测
↓
选高置信预测(confidence > 阈值)作为"伪标签"
↓
将伪标签数据加入训练集,重新训练
↓
重复迭代,直到收敛
自训练的关键工程细节:
置信度阈值如何设置? 不同类别可以有不同阈值------多数类别天然置信度高,少数类别阈值可以适当降低,否则会加剧类别不平衡。
迭代多少轮? 通常 5-10 轮即趋于收敛。过多轮次会导致"错误传播"------错误的伪标签被当作真实标签反复强化。
为什么有效? 模型对无标注数据的高置信预测,往往对应了数据分布中的密集区域(符合聚类假设)。将这些点纳入训练集,相当于帮助模型更好地理解数据流形。
python
import numpy as np
from sklearn.base import clone
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
# 构造半监督场景:10000 条数据,只有 200 条有标签
X, y = make_classification(
n_samples=10000, n_features=20, n_informative=10,
n_classes=3, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 模拟标注场景:只保留 200 条标注,其余设为 -1(sklearn 的无标注标记)
rng = np.random.RandomState(42)
labeled_idx = rng.choice(len(X_train), size=200, replace=False)
y_semi = np.full_like(y_train, fill_value=-1)
y_semi[labeled_idx] = y_train[labeled_idx]
# 基线:只用 200 条标注数据的纯监督学习
base_clf = LogisticRegression(max_iter=1000, random_state=42)
base_clf.fit(X_train[labeled_idx], y_train[labeled_idx])
baseline_f1 = f1_score(y_test, base_clf.predict(X_test), average='macro')
print(f"纯监督(200条标注):F1 = {baseline_f1:.4f}")
# 自训练:200 条标注 + 7800 条无标注
self_training_clf = SelfTrainingClassifier(
base_estimator=LogisticRegression(max_iter=1000, random_state=42),
threshold=0.8, # 置信度阈值:只有 >80% 置信度才作为伪标签
max_iter=10, # 最多迭代 10 轮
verbose=True
)
self_training_clf.fit(X_train, y_semi)
st_f1 = f1_score(y_test, self_training_clf.predict(X_test), average='macro')
print(f"自训练(200标注+7800无标注):F1 = {st_f1:.4f}")
print(f"性能提升:{(st_f1 - baseline_f1) / baseline_f1 * 100:.1f}%")
2.2 协同训练(Co-Training)
协同训练需要一个独特条件:数据有两个"充分但冗余"的特征视图。
经典例子是网页分类:
- 视图一:网页正文的词袋特征
- 视图二:指向该网页的超链接锚文本
这两个视图是"充分的"(每个单独都足以完成分类),也是"冗余的"(包含同样的语义信息,只是来自不同角度)。
python
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MinMaxScaler
import numpy as np
class CoTrainingClassifier:
"""
协同训练分类器
要求:特征必须可以分成两个视图(view1_features 和 view2_features)
"""
def __init__(self, clf1=None, clf2=None, p=1, n=3, k=30, u=75):
"""
p: 每轮每个分类器选取的正样本数
n: 每轮每个分类器选取的负样本数(每个类别)
k: 迭代轮数
u: 未标注池大小
"""
self.clf1 = clf1 or LogisticRegression(max_iter=500)
self.clf2 = clf2 or LogisticRegression(max_iter=500)
self.p = p
self.n = n
self.k = k
self.u = u
def fit(self, X1_labeled, X2_labeled, y_labeled,
X1_unlabeled, X2_unlabeled):
"""
X1_labeled, X2_labeled: 标注数据的两个视图
X1_unlabeled, X2_unlabeled: 无标注数据的两个视图
"""
# 初始训练
self.clf1.fit(X1_labeled, y_labeled)
self.clf2.fit(X2_labeled, y_labeled)
# 复制无标注数据(避免修改原始数据)
U1 = X1_unlabeled.copy()
U2 = X2_unlabeled.copy()
unlabeled_indices = list(range(len(U1)))
X1_train = X1_labeled.copy()
X2_train = X2_labeled.copy()
y_train = y_labeled.copy()
classes = np.unique(y_labeled)
for iteration in range(self.k):
# 从无标注池中随机采样
if len(unlabeled_indices) < self.u:
pool_indices = unlabeled_indices
else:
pool_indices = np.random.choice(
unlabeled_indices, size=self.u, replace=False
).tolist()
pool1 = U1[pool_indices]
pool2 = U2[pool_indices]
# clf1 对视图1的高置信预测 → 传给 clf2
proba1 = self.clf1.predict_proba(pool1)
new_from_clf1 = self._select_high_confidence(
pool_indices, proba1, classes
)
# clf2 对视图2的高置信预测 → 传给 clf1
proba2 = self.clf2.predict_proba(pool2)
new_from_clf2 = self._select_high_confidence(
pool_indices, proba2, classes
)
# 更新训练集
for idx, label in new_from_clf1 + new_from_clf2:
if idx in unlabeled_indices:
X1_train = np.vstack([X1_train, U1[idx]])
X2_train = np.vstack([X2_train, U2[idx]])
y_train = np.append(y_train, label)
unlabeled_indices.remove(idx)
# 用扩充后的数据重新训练
if len(new_from_clf1) + len(new_from_clf2) > 0:
self.clf1.fit(X1_train, y_train)
self.clf2.fit(X2_train, y_train)
return self
def _select_high_confidence(self, indices, proba, classes):
"""选取每个类别中置信度最高的样本"""
selected = []
for c_idx, c in enumerate(classes):
class_proba = proba[:, c_idx]
sorted_idx = np.argsort(class_proba)[::-1]
for i in sorted_idx[:self.p]:
if class_proba[i] > 0.9: # 置信度阈值
selected.append((indices[i], c))
return selected
def predict(self, X1, X2):
"""两个分类器投票"""
pred1 = self.clf1.predict(X1)
pred2 = self.clf2.predict(X2)
# 简单多数投票(相同时取 clf1 结果)
return np.where(pred1 == pred2, pred1, pred1)
2.3 标签传播(Label Propagation)
标签传播把数据看作一个图:每个样本是节点,样本间的相似度是边权。标签从已知节点沿边传播到未知节点,直到全图收敛。
python
from sklearn.semi_supervised import LabelPropagation, LabelSpreading
import numpy as np
class LabelPropagationAnalysis:
"""
标签传播方法的封装与分析
LabelPropagation vs LabelSpreading 的关键区别:
- LabelPropagation:硬标签(已标注样本的标签固定不变)
- LabelSpreading:软标签(已标注样本的标签也可以轻微修改,更鲁棒)
"""
def compare_methods(self, X, y_partial, X_test, y_test):
"""对比 LP、LS、基线三种方法"""
results = {}
# 1. 纯监督基线
labeled_mask = y_partial != -1
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
base = LogisticRegression(max_iter=1000)
base.fit(X[labeled_mask], y_partial[labeled_mask])
results['supervised_baseline'] = f1_score(
y_test, base.predict(X_test), average='macro'
)
# 2. LabelPropagation(硬标签固定)
lp = LabelPropagation(
kernel='rbf', # RBF 核构建相似度图
gamma=0.25, # gamma 控制相似度衰减速度
max_iter=1000
)
lp.fit(X, y_partial)
results['label_propagation'] = f1_score(
y_test, lp.predict(X_test), average='macro'
)
# 3. LabelSpreading(软标签,alpha 控制正则化强度)
ls = LabelSpreading(
kernel='rbf',
alpha=0.2, # alpha=0:完全信任标注;alpha=1:完全忽略标注
max_iter=1000
)
ls.fit(X, y_partial)
results['label_spreading'] = f1_score(
y_test, ls.predict(X_test), average='macro'
)
return results
def analyze_label_spread_quality(self, label_spreading_model, X, y_true):
"""
分析标签传播质量:检查无标注样本的预测置信度分布
置信度低的样本说明图结构无法良好传播标签
"""
proba = label_spreading_model.label_distributions_ # (n_samples, n_classes)
max_proba = proba.max(axis=1)
print("无标注样本置信度分布:")
print(f" 高置信(>0.9): {(max_proba > 0.9).sum()} 条")
print(f" 中置信(0.7-0.9): {((max_proba > 0.7) & (max_proba <= 0.9)).sum()} 条")
print(f" 低置信(<0.7): {(max_proba <= 0.7).sum()} 条")
print("⚠️ 低置信样本是标签传播的薄弱点------图连接质量不足")
三、自监督学习:零标注的预训练革命
自监督学习的核心思路是:从数据自身的结构中构造监督信号------不需要人工标注,数据本身就是标签的来源。
这个思路听起来玄妙,但实例非常直观:
| 预训练任务 | 输入 | 自动生成的"标签" | 学到的能力 |
|---|---|---|---|
| BERT 掩码语言模型 | 完整句子 | 被掩盖的词 | 上下文语义理解 |
| 图像旋转预测 | 旋转后的图像 | 旋转角度(0/90/180/270°) | 图像内容识别 |
| 对比学习 | 同一图像的两次增强 | 同一图像的增强是正样本 | 视觉语义相似性 |
| 自回归预测 | 前缀序列 | 下一个 token | 序列结构建模 |
这些任务的共同点:解决这些任务需要"理解"数据内容,因此学到的表示是有语义价值的。
3.1 对比学习(Contrastive Learning)
对比学习是 2020-2022 年视觉自监督的核心突破。核心思路:
同一图像的不同增强 → 正样本对(应该相似)
不同图像的增强 → 负样本对(应该不同)
训练目标:拉近正样本表示,推开负样本表示
#mermaid-svg-rfuG5hhV8UfobeCb{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-rfuG5hhV8UfobeCb .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-rfuG5hhV8UfobeCb .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-rfuG5hhV8UfobeCb .error-icon{fill:#552222;}#mermaid-svg-rfuG5hhV8UfobeCb .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-rfuG5hhV8UfobeCb .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-rfuG5hhV8UfobeCb .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-rfuG5hhV8UfobeCb .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-rfuG5hhV8UfobeCb .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-rfuG5hhV8UfobeCb .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-rfuG5hhV8UfobeCb .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-rfuG5hhV8UfobeCb .marker{fill:#333333;stroke:#333333;}#mermaid-svg-rfuG5hhV8UfobeCb .marker.cross{stroke:#333333;}#mermaid-svg-rfuG5hhV8UfobeCb svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-rfuG5hhV8UfobeCb p{margin:0;}#mermaid-svg-rfuG5hhV8UfobeCb .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-rfuG5hhV8UfobeCb .cluster-label text{fill:#333;}#mermaid-svg-rfuG5hhV8UfobeCb .cluster-label span{color:#333;}#mermaid-svg-rfuG5hhV8UfobeCb .cluster-label span p{background-color:transparent;}#mermaid-svg-rfuG5hhV8UfobeCb .label text,#mermaid-svg-rfuG5hhV8UfobeCb span{fill:#333;color:#333;}#mermaid-svg-rfuG5hhV8UfobeCb .node rect,#mermaid-svg-rfuG5hhV8UfobeCb .node circle,#mermaid-svg-rfuG5hhV8UfobeCb .node ellipse,#mermaid-svg-rfuG5hhV8UfobeCb .node polygon,#mermaid-svg-rfuG5hhV8UfobeCb .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-rfuG5hhV8UfobeCb .rough-node .label text,#mermaid-svg-rfuG5hhV8UfobeCb .node .label text,#mermaid-svg-rfuG5hhV8UfobeCb .image-shape .label,#mermaid-svg-rfuG5hhV8UfobeCb .icon-shape .label{text-anchor:middle;}#mermaid-svg-rfuG5hhV8UfobeCb .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-rfuG5hhV8UfobeCb .rough-node .label,#mermaid-svg-rfuG5hhV8UfobeCb .node .label,#mermaid-svg-rfuG5hhV8UfobeCb .image-shape .label,#mermaid-svg-rfuG5hhV8UfobeCb .icon-shape .label{text-align:center;}#mermaid-svg-rfuG5hhV8UfobeCb .node.clickable{cursor:pointer;}#mermaid-svg-rfuG5hhV8UfobeCb .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-rfuG5hhV8UfobeCb .arrowheadPath{fill:#333333;}#mermaid-svg-rfuG5hhV8UfobeCb .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-rfuG5hhV8UfobeCb .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-rfuG5hhV8UfobeCb .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-rfuG5hhV8UfobeCb .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-rfuG5hhV8UfobeCb .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-rfuG5hhV8UfobeCb .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-rfuG5hhV8UfobeCb .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-rfuG5hhV8UfobeCb .cluster text{fill:#333;}#mermaid-svg-rfuG5hhV8UfobeCb .cluster span{color:#333;}#mermaid-svg-rfuG5hhV8UfobeCb div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-rfuG5hhV8UfobeCb .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-rfuG5hhV8UfobeCb rect.text{fill:none;stroke-width:0;}#mermaid-svg-rfuG5hhV8UfobeCb .icon-shape,#mermaid-svg-rfuG5hhV8UfobeCb .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-rfuG5hhV8UfobeCb .icon-shape p,#mermaid-svg-rfuG5hhV8UfobeCb .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-rfuG5hhV8UfobeCb .icon-shape .label rect,#mermaid-svg-rfuG5hhV8UfobeCb .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-rfuG5hhV8UfobeCb .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-rfuG5hhV8UfobeCb .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-rfuG5hhV8UfobeCb :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 原始图片 x
数据增强 t₁
(裁剪/旋转/颜色抖动)
数据增强 t₂
(裁剪/旋转/颜色抖动)
编码器 f(·)
ResNet/ViT
编码器 f(·)
共享权重
投影头 g(·)
MLP
投影头 g(·)
MLP
z₁
z₂
NT-Xent Loss
相同图像的增强:拉近
不同图像的增强:推开
SimCLR 的损失函数(NT-Xent,归一化温度缩放交叉熵):
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class NTXentLoss(nn.Module):
"""
SimCLR 的对比损失函数
参考:Chen et al. 2020 "A Simple Framework for Contrastive Learning"
"""
def __init__(self, temperature=0.5, batch_size=256):
super().__init__()
self.temperature = temperature
self.batch_size = batch_size
def forward(self, z1, z2):
"""
z1, z2: (batch_size, embedding_dim) --- 同一批图片的两次增强的嵌入
"""
batch_size = z1.shape[0]
# L2 归一化(确保只比较方向,不比较大小)
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# 拼接 → (2*batch_size, embedding_dim)
z = torch.cat([z1, z2], dim=0)
# 计算所有对之间的余弦相似度矩阵
sim_matrix = torch.mm(z, z.t()) / self.temperature # (2B, 2B)
# 去除对角线(自相似度)
mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
sim_matrix = sim_matrix.masked_fill(mask, float('-inf'))
# 正样本对的索引:z1[i] 的正样本是 z2[i],即索引 i+batch_size
labels = torch.arange(batch_size, device=z.device)
labels = torch.cat([labels + batch_size, labels], dim=0) # (2B,)
# 交叉熵损失
loss = F.cross_entropy(sim_matrix, labels)
return loss
class SimpleContrastiveEncoder(nn.Module):
"""极简对比学习编码器(演示用)"""
def __init__(self, input_dim, hidden_dim=256, proj_dim=128):
super().__init__()
# 骨干网络(实际用 ResNet/ViT,这里简化为 MLP)
self.backbone = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU()
)
# 投影头(只在预训练时用,下游任务丢弃)
self.projector = nn.Sequential(
nn.Linear(hidden_dim, proj_dim),
nn.BatchNorm1d(proj_dim),
nn.ReLU(),
nn.Linear(proj_dim, proj_dim)
)
def forward(self, x):
h = self.backbone(x)
z = self.projector(h)
return z
def get_representation(self, x):
"""下游任务使用骨干网络的输出(不经过投影头)"""
return self.backbone(x)
为什么要丢弃投影头? 这是 SimCLR 中的重要发现:投影头(projector)学到的表示对下游任务反而更差,骨干网络的输出(representation)才是真正有用的特征。原因是投影头被训练去"忘记"增强不变的信息,而骨干网络保留了更多语义信息。
3.2 掩码自编码(Masked Autoencoding)
python
import torch
import torch.nn as nn
import numpy as np
class MaskedTabularAutoencoder(nn.Module):
"""
面向表格数据的掩码自编码器
思路来自 BERT 的掩码语言模型,适配到结构化特征
"""
def __init__(self, input_dim, hidden_dim=256, mask_ratio=0.3):
super().__init__()
self.mask_ratio = mask_ratio
self.input_dim = input_dim
# 掩码 token 的可学习嵌入
self.mask_token = nn.Parameter(torch.zeros(1, input_dim))
# 编码器:将(部分)可见特征编码为隐向量
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU()
)
# 解码器:从隐向量重建所有特征
self.decoder = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, input_dim)
)
def forward(self, x):
batch_size = x.shape[0]
# 随机生成掩码(1=被掩码,0=可见)
mask = torch.rand(batch_size, self.input_dim) < self.mask_ratio
mask = mask.to(x.device)
# 用掩码 token 替换被掩码的特征
x_masked = x.clone()
x_masked[mask] = self.mask_token.expand_as(x)[mask]
# 编码 → 解码
latent = self.encoder(x_masked)
reconstruction = self.decoder(latent)
# 只在被掩码的位置计算重建损失
loss = ((reconstruction - x) ** 2 * mask.float()).sum() / mask.float().sum()
return loss, latent
def get_representation(self, x):
"""提取表示(推理时不掩码)"""
with torch.no_grad():
latent = self.encoder(x)
return latent
四、sklearn 半监督工具速查
sklearn 提供了开箱即用的半监督工具,适合快速验证方案可行性:
python
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.semi_supervised import (
SelfTrainingClassifier,
LabelPropagation,
LabelSpreading
)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, classification_report
from sklearn.pipeline import Pipeline
import warnings
warnings.filterwarnings('ignore')
def run_semi_supervised_comparison(
X_train, y_train_full, X_test, y_test,
labeled_count=200
):
"""
对比三种半监督方法 vs 纯监督基线
参数:
labeled_count: 模拟的标注数量
"""
rng = np.random.RandomState(42)
# 构造半标注场景
labeled_idx = rng.choice(len(X_train), size=labeled_count, replace=False)
y_partial = np.full(len(y_train_full), -1)
y_partial[labeled_idx] = y_train_full[labeled_idx]
results = {}
# ① 纯监督基线(只用 labeled_count 条标注)
lr_base = LogisticRegression(max_iter=1000, C=1.0)
lr_base.fit(X_train[labeled_idx], y_train_full[labeled_idx])
results['supervised_only'] = f1_score(
y_test, lr_base.predict(X_test), average='macro'
)
# ② 全量监督上界(告诉我们半监督最多能接近的水平)
lr_full = LogisticRegression(max_iter=1000, C=1.0)
lr_full.fit(X_train, y_train_full)
results['full_supervised_upper_bound'] = f1_score(
y_test, lr_full.predict(X_test), average='macro'
)
# ③ 自训练(SelfTrainingClassifier)
# 关键参数:threshold 越高越保守,每次只用最确定的伪标签
st_clf = SelfTrainingClassifier(
base_estimator=LogisticRegression(max_iter=500),
threshold=0.85,
max_iter=10,
verbose=False
)
st_clf.fit(X_train, y_partial)
results['self_training'] = f1_score(
y_test, st_clf.predict(X_test), average='macro'
)
# ④ LabelSpreading(图传播,alpha 控制平滑程度)
# 注意:LabelSpreading 比 LabelPropagation 更鲁棒(软标签,容忍标注噪声)
ls_clf = LabelSpreading(
kernel='knn', # knn 核:每个点连接 k 个最近邻
k=7,
alpha=0.2, # alpha 接近 0:紧扣标注;alpha 接近 1:更依赖图结构
max_iter=1000
)
ls_clf.fit(X_train, y_partial)
results['label_spreading'] = f1_score(
y_test, ls_clf.predict(X_test), average='macro'
)
# 打印对比报告
print(f"\n{'='*60}")
print(f"半监督 vs 纯监督对比({labeled_count} 条标注 / {len(X_train)} 条总数据)")
print(f"{'='*60}")
print(f"{'方法':<30} {'Macro F1':>10}")
print(f"{'-'*40}")
for method, score in results.items():
marker = " ← 当前方法" if method not in [
'supervised_only', 'full_supervised_upper_bound'
] else ""
print(f"{method:<30} {score:>10.4f}{marker}")
gap_closed = (
(results.get('self_training', 0) - results['supervised_only']) /
(results['full_supervised_upper_bound'] - results['supervised_only'] + 1e-8)
) * 100
print(f"\n自训练弥补了监督学习上限差距的 {gap_closed:.1f}%")
return results
五、半监督 + 自监督的组合策略
单独使用半监督或自监督都有局限:
| 方法 | 优势 | 局限 |
|---|---|---|
| 半监督(自训练) | 直接优化目标任务 | 依赖初始模型质量 |
| 半监督(图传播) | 利用数据流形结构 | 高维数据图构建困难 |
| 自监督(对比学习) | 不需要任何标注 | 预训练-微调有任务鸿沟 |
组合策略(当前最佳实践):
Step 1:自监督预训练(使用所有无标注数据)
→ 目标:学到好的通用特征表示
→ 方法:对比学习 / 掩码自编码
Step 2:半监督微调(使用少量标注 + 大量无标注)
→ 目标:将通用表示适配到目标任务
→ 方法:自训练 / 标签传播
→ 初始化:使用 Step 1 的表示而非随机初始化
关键优势:
Step 1 提供了好的"起点"------表示已经捕捉了数据的语义结构
Step 2 的半监督学习在好的表示基础上更容易传播标签
性能远超任一单独方法
python
import torch
import torch.nn as nn
from sklearn.semi_supervised import SelfTrainingClassifier
from sklearn.linear_model import LogisticRegression
import numpy as np
class SelfSupervisedPretrain:
"""
表格数据的自监督预训练 → 半监督微调流程
"""
def __init__(self, input_dim, hidden_dim=256, mask_ratio=0.3,
pretrain_epochs=100, lr=1e-3):
self.autoencoder = MaskedTabularAutoencoder(
input_dim, hidden_dim, mask_ratio
)
self.pretrain_epochs = pretrain_epochs
self.lr = lr
self.hidden_dim = hidden_dim
def pretrain(self, X_all):
"""
用所有无标注数据做自监督预训练
X_all: 全部数据(含无标注),shape (n_samples, n_features)
"""
X_tensor = torch.FloatTensor(X_all)
optimizer = torch.optim.Adam(self.autoencoder.parameters(), lr=self.lr)
self.autoencoder.train()
for epoch in range(self.pretrain_epochs):
# 打乱顺序
perm = torch.randperm(len(X_tensor))
total_loss = 0
n_batches = 0
for i in range(0, len(X_tensor), 256):
batch = X_tensor[perm[i:i+256]]
optimizer.zero_grad()
loss, _ = self.autoencoder(batch)
loss.backward()
optimizer.step()
total_loss += loss.item()
n_batches += 1
if (epoch + 1) % 20 == 0:
print(f"预训练 Epoch {epoch+1}/{self.pretrain_epochs}, "
f"重建损失: {total_loss/n_batches:.4f}")
self.autoencoder.eval()
return self
def extract_features(self, X):
"""提取预训练表示"""
X_tensor = torch.FloatTensor(X)
with torch.no_grad():
features = self.autoencoder.get_representation(X_tensor)
return features.numpy()
def semi_supervised_finetune(self, X_train, y_partial, X_test, y_test):
"""
用预训练表示做半监督微调
y_partial: 含 -1 的标签数组(-1 表示无标注)
"""
# 用预训练表示替换原始特征
X_train_repr = self.extract_features(X_train)
X_test_repr = self.extract_features(X_test)
# 在好的表示上做自训练
st_clf = SelfTrainingClassifier(
base_estimator=LogisticRegression(max_iter=1000),
threshold=0.8,
max_iter=15
)
st_clf.fit(X_train_repr, y_partial)
from sklearn.metrics import f1_score
return f1_score(y_test, st_clf.predict(X_test_repr), average='macro')
六、适用场景决策框架
在决定使用哪种半监督方法之前,需要回答几个关键问题:
#mermaid-svg-gRGmuNSjLcvOtYIv{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}@keyframes edge-animation-frame{from{stroke-dashoffset:0;}}@keyframes dash{to{stroke-dashoffset:0;}}#mermaid-svg-gRGmuNSjLcvOtYIv .edge-animation-slow{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 50s linear infinite;stroke-linecap:round;}#mermaid-svg-gRGmuNSjLcvOtYIv .edge-animation-fast{stroke-dasharray:9,5!important;stroke-dashoffset:900;animation:dash 20s linear infinite;stroke-linecap:round;}#mermaid-svg-gRGmuNSjLcvOtYIv .error-icon{fill:#552222;}#mermaid-svg-gRGmuNSjLcvOtYIv .error-text{fill:#552222;stroke:#552222;}#mermaid-svg-gRGmuNSjLcvOtYIv .edge-thickness-normal{stroke-width:1px;}#mermaid-svg-gRGmuNSjLcvOtYIv .edge-thickness-thick{stroke-width:3.5px;}#mermaid-svg-gRGmuNSjLcvOtYIv .edge-pattern-solid{stroke-dasharray:0;}#mermaid-svg-gRGmuNSjLcvOtYIv .edge-thickness-invisible{stroke-width:0;fill:none;}#mermaid-svg-gRGmuNSjLcvOtYIv .edge-pattern-dashed{stroke-dasharray:3;}#mermaid-svg-gRGmuNSjLcvOtYIv .edge-pattern-dotted{stroke-dasharray:2;}#mermaid-svg-gRGmuNSjLcvOtYIv .marker{fill:#333333;stroke:#333333;}#mermaid-svg-gRGmuNSjLcvOtYIv .marker.cross{stroke:#333333;}#mermaid-svg-gRGmuNSjLcvOtYIv svg{font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;}#mermaid-svg-gRGmuNSjLcvOtYIv p{margin:0;}#mermaid-svg-gRGmuNSjLcvOtYIv .label{font-family:"trebuchet ms",verdana,arial,sans-serif;color:#333;}#mermaid-svg-gRGmuNSjLcvOtYIv .cluster-label text{fill:#333;}#mermaid-svg-gRGmuNSjLcvOtYIv .cluster-label span{color:#333;}#mermaid-svg-gRGmuNSjLcvOtYIv .cluster-label span p{background-color:transparent;}#mermaid-svg-gRGmuNSjLcvOtYIv .label text,#mermaid-svg-gRGmuNSjLcvOtYIv span{fill:#333;color:#333;}#mermaid-svg-gRGmuNSjLcvOtYIv .node rect,#mermaid-svg-gRGmuNSjLcvOtYIv .node circle,#mermaid-svg-gRGmuNSjLcvOtYIv .node ellipse,#mermaid-svg-gRGmuNSjLcvOtYIv .node polygon,#mermaid-svg-gRGmuNSjLcvOtYIv .node path{fill:#ECECFF;stroke:#9370DB;stroke-width:1px;}#mermaid-svg-gRGmuNSjLcvOtYIv .rough-node .label text,#mermaid-svg-gRGmuNSjLcvOtYIv .node .label text,#mermaid-svg-gRGmuNSjLcvOtYIv .image-shape .label,#mermaid-svg-gRGmuNSjLcvOtYIv .icon-shape .label{text-anchor:middle;}#mermaid-svg-gRGmuNSjLcvOtYIv .node .katex path{fill:#000;stroke:#000;stroke-width:1px;}#mermaid-svg-gRGmuNSjLcvOtYIv .rough-node .label,#mermaid-svg-gRGmuNSjLcvOtYIv .node .label,#mermaid-svg-gRGmuNSjLcvOtYIv .image-shape .label,#mermaid-svg-gRGmuNSjLcvOtYIv .icon-shape .label{text-align:center;}#mermaid-svg-gRGmuNSjLcvOtYIv .node.clickable{cursor:pointer;}#mermaid-svg-gRGmuNSjLcvOtYIv .root .anchor path{fill:#333333!important;stroke-width:0;stroke:#333333;}#mermaid-svg-gRGmuNSjLcvOtYIv .arrowheadPath{fill:#333333;}#mermaid-svg-gRGmuNSjLcvOtYIv .edgePath .path{stroke:#333333;stroke-width:2.0px;}#mermaid-svg-gRGmuNSjLcvOtYIv .flowchart-link{stroke:#333333;fill:none;}#mermaid-svg-gRGmuNSjLcvOtYIv .edgeLabel{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-gRGmuNSjLcvOtYIv .edgeLabel p{background-color:rgba(232,232,232, 0.8);}#mermaid-svg-gRGmuNSjLcvOtYIv .edgeLabel rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-gRGmuNSjLcvOtYIv .labelBkg{background-color:rgba(232, 232, 232, 0.5);}#mermaid-svg-gRGmuNSjLcvOtYIv .cluster rect{fill:#ffffde;stroke:#aaaa33;stroke-width:1px;}#mermaid-svg-gRGmuNSjLcvOtYIv .cluster text{fill:#333;}#mermaid-svg-gRGmuNSjLcvOtYIv .cluster span{color:#333;}#mermaid-svg-gRGmuNSjLcvOtYIv div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:12px;background:hsl(80, 100%, 96.2745098039%);border:1px solid #aaaa33;border-radius:2px;pointer-events:none;z-index:100;}#mermaid-svg-gRGmuNSjLcvOtYIv .flowchartTitleText{text-anchor:middle;font-size:18px;fill:#333;}#mermaid-svg-gRGmuNSjLcvOtYIv rect.text{fill:none;stroke-width:0;}#mermaid-svg-gRGmuNSjLcvOtYIv .icon-shape,#mermaid-svg-gRGmuNSjLcvOtYIv .image-shape{background-color:rgba(232,232,232, 0.8);text-align:center;}#mermaid-svg-gRGmuNSjLcvOtYIv .icon-shape p,#mermaid-svg-gRGmuNSjLcvOtYIv .image-shape p{background-color:rgba(232,232,232, 0.8);padding:2px;}#mermaid-svg-gRGmuNSjLcvOtYIv .icon-shape .label rect,#mermaid-svg-gRGmuNSjLcvOtYIv .image-shape .label rect{opacity:0.5;background-color:rgba(232,232,232, 0.8);fill:rgba(232,232,232, 0.8);}#mermaid-svg-gRGmuNSjLcvOtYIv .label-icon{display:inline-block;height:1em;overflow:visible;vertical-align:-0.125em;}#mermaid-svg-gRGmuNSjLcvOtYIv .node .label-icon path{fill:currentColor;stroke:revert;stroke-width:revert;}#mermaid-svg-gRGmuNSjLcvOtYIv :root{--mermaid-font-family:"trebuchet ms",verdana,arial,sans-serif;} 完全没有标注
有少量标注(<1%)
有一定标注(1%-10%)
是
否
高(>70%准确率)
低
低维(<100维)
高维
标注数据稀缺场景
有无标注数据?
自监督学习
对比学习/掩码自编码
数据有两个独立视图?
数据维度?
协同训练
Co-Training
初始模型置信度?
自训练
Self-Training
自监督预训练
- 半监督微调
标签传播/扩散
Label Spreading
自训练
或降维后做标签传播
下游任务微调
直接用于目标任务
常见陷阱与对策:
陷阱一:伪标签错误传播
自训练的核心风险是:初始模型的错误预测被当作伪标签,导致错误在迭代中放大。
对策:使用较高的置信度阈值(0.85-0.95);监控每轮迭代后验证集性能,如果下降则停止。
陷阱二:标签传播的稀疏图问题
当无标注样本和标注样本之间相似度很低时(分布偏移),标签传播效果极差。
对策:先用自监督预训练得到好的表示,再在表示空间构建图------相似度图质量会大幅提升。
陷阱三:协同训练的视图独立性假设
协同训练假设两个视图是"条件独立"的------给定标签后,两个视图没有相关性。这个假设在实践中很难严格满足。
对策:不要死扣理论,尝试多种视图划分方案,用验证集 F1 选择最优划分。
七、实战:文本分类(少量标注场景)
模拟一个真实场景:客服工单分类,只有 300 条有标注,30000 条无标注。
python
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.semi_supervised import SelfTrainingClassifier, LabelSpreading
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, classification_report
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import make_pipeline
def text_classification_semi_supervised_demo():
"""
20 Newsgroups 子集模拟少量标注文本分类
选4个类别,每类只有 75 条标注(共 300 条),测试半监督效果
"""
categories = [
'sci.space', 'comp.graphics',
'rec.sport.hockey', 'talk.politics.misc'
]
# 训练集(大部分无标注)
train_data = fetch_20newsgroups(
subset='train', categories=categories,
remove=('headers', 'footers', 'quotes')
)
test_data = fetch_20newsgroups(
subset='test', categories=categories,
remove=('headers', 'footers', 'quotes')
)
# TF-IDF 特征
vectorizer = TfidfVectorizer(
max_features=10000, ngram_range=(1, 2),
min_df=3, sublinear_tf=True
)
X_train = vectorizer.fit_transform(train_data.data)
X_test = vectorizer.transform(test_data.data)
y_train = train_data.target
y_test = test_data.target
# 降维:LabelSpreading 在高维稀疏矩阵上效果差,先用 SVD 降维
svd = TruncatedSVD(n_components=300, random_state=42)
X_train_dense = svd.fit_transform(X_train)
X_test_dense = svd.transform(X_test)
# 模拟标注场景:每类只保留 75 条标注
rng = np.random.RandomState(42)
n_labeled_per_class = 75
labeled_idx = []
for c in range(len(categories)):
class_idx = np.where(y_train == c)[0]
selected = rng.choice(class_idx, size=n_labeled_per_class, replace=False)
labeled_idx.extend(selected.tolist())
y_partial = np.full(len(y_train), -1)
y_partial[labeled_idx] = y_train[labeled_idx]
results = {}
# 方法一:纯监督基线
lr = LogisticRegression(C=1.0, max_iter=1000)
lr.fit(X_train_dense[labeled_idx], y_train[labeled_idx])
results['纯监督(300条)'] = f1_score(y_test, lr.predict(X_test_dense), average='macro')
# 方法二:全量监督上界
lr_full = LogisticRegression(C=1.0, max_iter=1000)
lr_full.fit(X_train_dense, y_train)
results['全量监督上界'] = f1_score(y_test, lr_full.predict(X_test_dense), average='macro')
# 方法三:自训练
st = SelfTrainingClassifier(
LogisticRegression(C=1.0, max_iter=500),
threshold=0.85, max_iter=10, verbose=False
)
st.fit(X_train_dense, y_partial)
results['自训练(300+无标注)'] = f1_score(y_test, st.predict(X_test_dense), average='macro')
# 方法四:LabelSpreading
ls = LabelSpreading(kernel='knn', k=10, alpha=0.2, max_iter=200)
# 注意:LabelSpreading 只能处理密集矩阵
ls.fit(X_train_dense, y_partial)
results['标签扩散(300+无标注)'] = f1_score(y_test, ls.predict(X_test_dense), average='macro')
# 打印结果
print("\n文本分类半监督对比(20 Newsgroups 4类,300条标注)")
print("=" * 55)
for method, score in results.items():
print(f"{method:<30} Macro F1 = {score:.4f}")
# 分析自训练的迭代过程
print(f"\n自训练迭代轮数: {st.n_iter_}")
labeled_added = (st.transduction_ != y_partial) & (y_partial == -1)
print(f"被加入的伪标签数量: {labeled_added.sum()} / {(y_partial == -1).sum()}")
return results
# 运行演示
# results = text_classification_semi_supervised_demo()
典型实验结果(供参考,实际会因随机种子略有浮动):
| 方法 | Macro F1 |
|---|---|
| 纯监督(300条标注) | 0.71 |
| 自训练(300+无标注) | 0.78 |
| 标签扩散(300+无标注) | 0.76 |
| 全量监督上界 | 0.88 |
自训练弥补了约 50% 的上限差距,标签扩散弥补了约 29%------这个差异来自高维稀疏空间中图连接质量较差。
八、工程实践中的实用原则
标注预算分配建议:
| 总样本量 | 建议最少标注量 | 可期待的半监督效果 |
|---|---|---|
| 1 万条 | 200-500 条 | 接近 5000 条纯监督的水平 |
| 10 万条 | 500-2000 条 | 接近 1-2 万条纯监督的水平 |
| 100 万条 | 2000-5000 条 | 结合自监督预训练效果显著 |
选方法的三个判断维度:
- 标注比例:<0.5% → 优先考虑自监督预训练;0.5%-5% → 自训练或标签传播;5%-10% → 纯监督可能已经够用
- 数据维度:高维稀疏(文本词袋)→ 先降维再做标签传播;低维密集 → 标签传播直接有效
- 错误成本:错误标注代价高(医疗诊断)→ 使用更高的置信度阈值;错误代价低 → 可以降低阈值换取更多伪标签
与模块一的知识衔接:
前文介绍的不平衡数据处理、聚类算法、降维方法都可以与半监督学习形成协同:
- 聚类:先聚类发现自然分组,再用少量标注标记每个簇的代表样本,通过标签传播完成全量标注------"主动学习 + 半监督"的组合
- 降维:高维数据先用 PCA/UMAP 降维,再在低维空间做标签传播,图的连接质量大幅提升
- 不平衡:半监督在少数类标注样本极少时作用更大,但伪标签选取需要对每个类别单独设置置信度阈值
如果这篇文章对理解半监督学习有所帮助,欢迎点赞收藏。关注账号,更新不会错过。
前文推荐: