【机器学习】机器学习的基本分类-半监督学习ーSemi-supervised SVM

Semi-supervised SVM(S³VM)

Semi-supervised SVM (S³VM) 是一种半监督支持向量机方法,旨在结合标注数据和未标注数据训练模型。相比传统的监督学习 SVM,它在优化过程中利用未标注数据的分布信息,使分类边界更符合未标注数据的结构。


S³VM 的核心思想

  1. 监督学习部分

    • 使用标注数据训练分类模型,找到最优分类超平面。

    • 损失函数与标准 SVM 相同:

  2. 无监督学习部分

    • 未标注数据假定类别未知,需要找到分类边界,使未标注数据远离分类边界(边界平滑假设)。

    • 给未标注数据分配伪标签,并最小化伪标签分类的损失:

      其中,伪标签

  3. 目标函数

    • 将监督损失和无监督损失结合:

    • 是权重参数,用于平衡标注数据与未标注数据的贡献。

  4. 优化目标

    • 寻找分类边界,使得标注数据点被正确分类,同时未标注数据点尽可能远离分类边界。

关键技术

  1. 伪标签分配

    • 在训练过程中为未标注数据分配伪标签
    • 初始伪标签可以基于最近邻分类器或聚类结果生成。
    • 伪标签会随模型更新而动态调整。
  2. 边界平滑

    • 假设未标注数据的分布是分段分离的。
    • 分类边界应穿过数据密度较低的区域,避免高密度区域。
  3. 优化问题的复杂性

    • S³VM 的优化问题是一个非凸问题,直接求解较困难。
    • 常用优化方法:
      • 梯度下降法。
      • 分段优化:交替优化 w, b 和伪标签

算法流程

  1. 初始化模型参数 w,bw, bw,b 和未标注数据的伪标签
  2. 优化带有伪标签的目标函数:
    • 固定伪标签 ,优化分类器 w, b。
    • 固定分类器 w, b,重新分配伪标签
  3. 迭代上述过程,直至收敛。
  4. 输出最终分类器

优势与局限性

优势
  • 在标注数据稀缺时,可以利用未标注数据提高分类性能。
  • 保留 SVM 的优点,如决策边界清晰和泛化能力强。
局限性
  1. 非凸优化问题
    • S³VM 的目标函数非凸,可能陷入局部最优。
    • 需要设计高效的优化算法。
  2. 伪标签敏感性
    • 错误的伪标签可能降低模型性能。
    • 初始伪标签和更新策略的选择十分重要。
  3. 计算复杂度高
    • 需要频繁优化伪标签和模型参数,计算成本较高。

Python 示例

以下是使用 scikit-learn 和自定义代码实现一个简单的 S³VM 示例:

复制代码
python 复制代码
import numpy as np
from sklearn.svm import SVC
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# 数据生成
X, y = make_classification(
    n_samples=100, n_features=5, n_classes=2,
    n_informative=2, n_redundant=1, n_repeated=0,
    random_state=42
)

# 数据拆分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 模拟部分未标注数据
num_unlabeled = int(0.5 * len(y_train))
unlabeled_idx = np.random.choice(len(y_train), size=num_unlabeled, replace=False)
y_train[unlabeled_idx] = -1  # -1 表示未标注数据

# 初始化分类器
model = SVC(kernel='linear', C=1.0, probability=True)

# 半监督训练
for _ in range(5):  # 迭代 5 次
    # 使用标注数据训练模型
    labeled_mask = y_train != -1
    model.fit(X_train[labeled_mask], y_train[labeled_mask])

    # 为未标注数据生成伪标签
    unlabeled_mask = y_train == -1
    pseudo_labels = model.predict(X_train[unlabeled_mask])

    # 根据置信度筛选伪标签
    pseudo_probs = model.predict_proba(X_train[unlabeled_mask]).max(axis=1)
    high_confidence_idx = pseudo_probs > 0.8  # 置信度阈值
    y_train[unlabeled_mask][high_confidence_idx] = pseudo_labels[high_confidence_idx]

# 测试模型
accuracy = model.score(X_test, y_test)
print(f"Accuracy: {accuracy:.2f}")

输出结果

python 复制代码
Accuracy: 0.97

应用场景

  1. 文本分类
    • 互联网评论或情感分析数据中,标注数据较少但未标注文本较多。
  2. 医学影像分析
    • 标注的病灶数据有限,结合未标注数据进行诊断。
  3. 网页分类
    • 大量未标注网页数据可用,通过半监督学习提高分类性能。

总结

S³VM 是半监督学习的一种重要方法,特别适用于标注数据稀缺、未标注数据丰富的场景。尽管其优化问题复杂,但通过合适的初始化和优化方法,S³VM 能有效提升分类性能。

相关推荐
martian6653 分钟前
支持向量机(SVM)深度解析:从数学根基到工程实践
算法·机器学习·支持向量机
FF-Studio31 分钟前
【硬核数学 · LLM篇】3.1 Transformer之心:自注意力机制的线性代数解构《从零构建机器学习、深度学习到LLM的数学认知》
人工智能·pytorch·深度学习·线性代数·机器学习·数学建模·transformer
贾全1 小时前
第十章:HIL-SERL 真实机器人训练实战
人工智能·深度学习·算法·机器学习·机器人
GIS小天1 小时前
AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年7月4日第128弹
人工智能·算法·机器学习·彩票
我是小哪吒2.02 小时前
书籍推荐-《对抗机器学习:攻击面、防御机制与人工智能中的学习理论》
人工智能·深度学习·学习·机器学习·ai·语言模型·大模型
慕婉03072 小时前
深度学习前置知识全面解析:从机器学习到深度学习的进阶之路
人工智能·深度学习·机器学习
蓝婷儿3 小时前
Python 机器学习核心入门与实战进阶 Day 2 - KNN(K-近邻算法)分类实战与调参
python·机器学习·近邻算法
麻雀无能为力4 小时前
CAU数据挖掘 支持向量机
人工智能·支持向量机·数据挖掘·中国农业大学计算机
IT古董4 小时前
【第二章:机器学习与神经网络概述】04.回归算法理论与实践 -(3)决策树回归模型(Decision Tree Regression)
神经网络·机器学习·回归
烟锁池塘柳06 小时前
【大模型】解码策略:Greedy Search、Beam Search、Top-k/Top-p、Temperature Sampling等
人工智能·深度学习·机器学习