偏标记学习+图像分类(论文复现)

偏标记学习+图像分类(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • 偏标记学习+图像分类(论文复现)
        • 概述
        • 算法原理
        • 核心逻辑
        • 效果演示
        • 使用方式
概述

本文复现论文提出的偏标记学习方法,随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题------偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的

该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关

算法原理

传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下

核心逻辑

具体的核心逻辑如下所示:

bash 复制代码
import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm

def CE_loss(probs, targets):
    """交叉熵损失函数"""
    loss = -torch.sum(targets * torch.log(probs), dim = -1)
    loss_avg = torch.sum(loss)/probs.shape[0]
    return loss_avg

class Proden:
    def __init__(self, configs):
        self.configs = configs
    
    def train(self, save = False):
        configs = self.configs
        # 读取数据集
        dataset_path = configs['dataset path']
        if configs['dataset'] == 'CIFAR-10':
            train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 10
        elif configs['dataset'] == 'CIFAR-100':
            train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 100
        # 生成偏标记
        partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])
        train_dataset.load_partial_labels(partial_labels)
        # 计算数据的均值和方差,用于模型输入的标准化
        mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]
        std = [np.std(train_data[:, i, :, :]) for i in range(3)]
        normalize = transforms.Normalize(mean, std)
        # 设备:GPU或CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 加载模型
        if configs['model'] == 'ResNet18':
            model = models.ResNet18(output_dimension = output_dimension).to(device)
        elif configs['model'] == 'ConvNet':
            model = models.ConvNet(output_dimension = output_dimension).to(device)
        # 设置学习率等超参数
        lr = configs['learning rate']
        weight_decay = configs['weight decay']
        momentum = configs['momentum']
        optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)
        lr_step = configs['learning rate decay step']
        lr_decay = configs['learning rate decay rate']
        lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)
        for epoch_id in range(configs['epoch count']):
            # 训练模型
            train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)
            model.train()
            for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):
                ids = batch['ids']
                # 标准化输入
                data = normalize(batch['data'].to(device))
                partial_labels = batch['partial_labels'].to(device)
                targets = batch['targets'].to(device)
                optimizer.zero_grad()
                # 计算预测概率
                logits = model(data)
                probs = F.softmax(logits, dim=-1)
                # 更新软标签
                with torch.no_grad():
                    new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)
                    train_dataset.targets[ids] = new_targets.cpu().numpy()
                # 计算交叉熵损失
                loss = CE_loss(probs, targets)
                loss.backward()
                # 更新模型参数
                optimizer.step()
            # 调整学习率
            lr_scheduler.step()
效果演示

我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下

由图可见,该算法在测试集上获得了 89.8% 的准确率。

进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 "Ship":

使用方式

解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:

bash 复制代码
unzip Proden-implemention.zip
cd Proden-implemention

代码的运行环境可通过如下命令进行配置

bash 复制代码
pip install -r requirements.txt

运行如下命令以下载并解压数据集

bash 复制代码
bash download.sh

如果希望在本地训练模型,请运行如下命令

bash 复制代码
python main.py -c [你的配置文件路径] -r [选择下者之一:"train"、"test"、"infer"]

如果希望在线部署,请运行如下命令

bash 复制代码
python main-flask.py

文章代码资源点击附件获取

相关推荐
吴法刚3 小时前
14-Hugging Face 模型微调训练(基于 BERT 的中文评价情感分析(二分类))
人工智能·深度学习·自然语言处理·分类·langchain·bert·langgraph
viperrrrrrrrrr73 小时前
大数据学习(105)-Hbase
大数据·学习·hbase
随风飘摇的土木狗5 小时前
【MATLAB第114期】基于MATLAB的SHAP可解释神经网络分类模型(敏感性分析方法)
神经网络·matlab·分类·全局敏感性分析·gsa·敏感性分析·shap
行思理5 小时前
go语言应该如何学习
开发语言·学习·golang
拓端研究室TRL6 小时前
Python贝叶斯回归、强化学习分析医疗健康数据拟合截断删失数据与参数估计3实例
开发语言·人工智能·python·数据挖掘·回归
oceanweave7 小时前
【k8s学习之CSI】理解 LVM 存储概念和相关操作
学习·容器·kubernetes
吴梓穆8 小时前
UE5学习笔记 FPS游戏制作43 UI材质
笔记·学习·ue5
学会870上岸华师9 小时前
c语言学习16——内存函数
c语言·开发语言·学习
XYN619 小时前
【嵌入式面试】
笔记·python·单片机·嵌入式硬件·学习
啊哈哈哈哈哈啊哈哈9 小时前
R3打卡——tensorflow实现RNN心脏病预测
人工智能·深度学习·学习