pytorch实现半监督学习

人工智能例子汇总:AI常见的算法和例子-CSDN博客

半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下:

1. 数据准备

  • 有标签数据(Labeled Data):数据集的一部分带有真实的类别标签。
  • 无标签数据(Unlabeled Data):数据集的另一部分没有标签,仅有特征信息。
  • 数据预处理:对数据进行清理、标准化、特征工程等处理,以保证数据质量。

2. 选择半监督学习方法

常见的半监督学习方法包括:

  • 基于生成模型(Generative Models):如高斯混合模型(GMM)、变分自编码器(VAE)。
  • 基于一致性正则化(Consistency Regularization):如 MixMatch、FixMatch,利用数据增强来约束模型预测一致性。
  • 基于伪标签(Pseudo-Labeling):先用模型预测无标签数据的类别,然后将高置信度的预测作为新标签加入训练。
  • 图神经网络(Graph-Based Methods):如 Label Propagation,通过构造数据之间的图结构传播标签信息。

3. 训练初始模型

  • 仅使用有标签数据训练一个初始模型。
  • 选择合适的损失函数,如交叉熵损失(Cross-Entropy Loss)或均方误差(MSE Loss)。
  • 训练过程中可以使用数据增强、正则化等优化策略。

4. 利用无标签数据增强训练

  • 伪标签方法:用初始模型对无标签数据进行预测,筛选高置信度样本,加入有标签数据训练。
  • 一致性正则化:对无标签数据进行不同变换,要求模型的预测结果一致。
  • 联合训练:构造有监督损失(Supervised Loss)和无监督损失(Unsupervised Loss),综合优化。

5. 模型迭代更新

  • 重新利用训练后的模型预测无标签数据,产生新的伪标签或调整模型参数。
  • 通过半监督策略不断优化模型,使其对无标签数据的预测更加稳定。

6. 评估和测试

  • 使用测试集(通常是有标签的数据)评估模型性能。
  • 选择合适的评估指标,如准确率(Accuracy)、F1-score、AUC-ROC 等。

7. 调优和部署

  • 根据实验结果调整超参数,如伪标签置信度阈值、学习率等。
  • 结合业务需求,将最终模型部署到实际应用中。

关键步骤:

  1. 初始化模型:首先使用有标签数据训练模型。

  2. 生成伪标签:用训练好的模型对无标签数据进行预测,生成伪标签。

  3. 结合有标签和伪标签数据进行训练:用带有标签和无标签(伪标签)数据一起训练模型。

  4. 迭代训练:不断迭代,使用更新的模型生成新的伪标签,进一步优化模型。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, Dataset
    import matplotlib.pyplot as plt

    简化的神经网络模型

    class SimpleCNN(nn.Module):
    def init(self):
    super(SimpleCNN, self).init()
    self.conv1 = nn.Conv2d(1, 8, kernel_size=3) # 缩小卷积层的输出通道
    self.fc1 = nn.Linear(8 * 26 * 26, 10) # 调整全连接层的输入和输出尺寸

     def forward(self, x):
         x = F.relu(self.conv1(x))
         x = x.view(x.size(0), -1)  # 展平
         x = self.fc1(x)
         return x
    

    自定义数据集

    class CustomDataset(Dataset):
    def init(self, data, labels=None):
    self.data = data
    self.labels = labels

     def __len__(self):
         return len(self.data)
    
     def __getitem__(self, idx):
         if self.labels is not None:
             return self.data[idx], self.labels[idx]
         else:
             return self.data[idx], -1  # 无标签数据
    

    半监督训练函数

    def pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device, threshold=0.95):
    model.train()
    labeled_loss_value = 0
    pseudo_loss_value = 0

     for (labeled_data, labeled_labels), (unlabeled_data, _) in zip(labeled_loader, unlabeled_loader):
         labeled_data, labeled_labels = labeled_data.to(device), labeled_labels.to(device)
         unlabeled_data = unlabeled_data.to(device)
    
         # 1. 有标签数据训练
         optimizer.zero_grad()
         labeled_output = model(labeled_data)
         labeled_loss = F.cross_entropy(labeled_output, labeled_labels)
         labeled_loss.backward()
    
         # 2. 无标签数据伪标签生成
         unlabeled_output = model(unlabeled_data)
         probs = F.softmax(unlabeled_output, dim=1)
         max_probs, pseudo_labels = torch.max(probs, dim=1)
    
         # 伪标签置信度筛选
         pseudo_mask = max_probs > threshold  # 置信度大于阈值的数据作为伪标签
         if pseudo_mask.sum() > 0:
             pseudo_labels = pseudo_labels[pseudo_mask]
             unlabeled_data_pseudo = unlabeled_data[pseudo_mask]
    
             # 3. 使用伪标签数据进行训练(确保无标签数据参与反向传播)
             optimizer.zero_grad()  # 清除之前的梯度
             pseudo_output = model(unlabeled_data_pseudo)
             pseudo_loss = F.cross_entropy(pseudo_output, pseudo_labels)
             pseudo_loss.backward()  # 计算反向梯度
    
         optimizer.step()  # 更新模型参数
    
         # 累加损失用于展示
         labeled_loss_value += labeled_loss.item()
         if pseudo_mask.sum() > 0:
             pseudo_loss_value += pseudo_loss.item()
    
     return labeled_loss_value / len(labeled_loader), pseudo_loss_value / len(unlabeled_loader)
    

    模拟数据

    num_labeled = 1000
    num_unlabeled = 5000
    data_dim = (1, 28, 28) # 28x28 灰度图像
    num_classes = 10

    labeled_data = torch.randn(num_labeled, *data_dim)
    labeled_labels = torch.randint(0, num_classes, (num_labeled,))
    unlabeled_data = torch.randn(num_unlabeled, *data_dim)

    labeled_dataset = CustomDataset(labeled_data, labeled_labels)
    unlabeled_dataset = CustomDataset(unlabeled_data)

    labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True) # 缩小批量大小
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=True) # 缩小批量大小

    模型、优化器和设备设置

    device = torch.device("cpu") # 临时使用 CPU
    model = SimpleCNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    训练过程并记录损失

    num_epochs = 10
    labeled_losses = []
    pseudo_losses = []

    for epoch in range(num_epochs):
    labeled_loss, pseudo_loss = pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device)
    labeled_losses.append(labeled_loss)
    pseudo_losses.append(pseudo_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}] | Labeled Loss: {labeled_loss:.4f} | Pseudo Loss: {pseudo_loss:.4f}")

    绘制损失曲线

    plt.plot(range(num_epochs), labeled_losses, label='Labeled Loss')
    plt.plot(range(num_epochs), pseudo_losses, label='Pseudo Label Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Losses Over Epochs')
    plt.show()

    展示伪标签生成效果(可视化一些样本的伪标签预测结果)

    model.eval()
    with torch.no_grad():
    sample_unlabeled_data = unlabeled_data[:10].to(device)
    output = model(sample_unlabeled_data)
    probs = F.softmax(output, dim=1)
    _, predicted_labels = torch.max(probs, dim=1)

     # 展示预测的标签
     print("Generated Pseudo Labels for Samples:")
     print(predicted_labels)
    
     # 假设这些是伪标签预测的图片
     fig, axes = plt.subplots(2, 5, figsize=(12, 5))
     for i, ax in enumerate(axes.flat):
         # 将tensor转换为NumPy数组
         img = sample_unlabeled_data[i].cpu().numpy().squeeze()  # 转为NumPy数组
         ax.imshow(img, cmap='gray')  # 使用灰度显示图像
         ax.set_title(f"Pred: {predicted_labels[i].item()}")
         ax.axis('off')
     plt.show()
    
相关推荐
weixin_3077791319 分钟前
Apache Iceberg数据湖技术在海量实时数据处理、实时特征工程和模型训练的应用技术方案和具体实施步骤及代码
大数据·人工智能·语言模型·音视频
知识鱼丸1 小时前
自定义数据集 使用scikit-learn中svm的包实现svm分类
人工智能
说私域2 小时前
基于开源AI智能名片2 + 1链动模式S2B2C商城小程序视角下的个人IP人设构建研究
人工智能·小程序·开源
山海青风2 小时前
OpenAI 实战进阶教程 - 第七节: 与数据库集成 - 生成 SQL 查询与优化
数据库·人工智能·python·sql
Chatopera 研发团队3 小时前
计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战
人工智能·pytorch·深度学习
YuLiu123213 小时前
Vue3学习笔记-Vue开发前准备-1
vue.js·笔记·学习
Bluesonli3 小时前
UE5 蓝图学习计划 - Day 11:材质与特效
学习·ue5·虚幻·材质·虚幻引擎·unreal engine
白白糖3 小时前
深度学习 Pytorch 基础网络手动搭建与快速实现
人工智能·pytorch·深度学习
AI浩3 小时前
【Block总结】HWD,小波下采样,适用分类、分割、目标检测等任务|即插即用
人工智能·目标检测·分类