偏标记学习+图像分类(论文复现)

偏标记学习+图像分类(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • 偏标记学习+图像分类(论文复现)
        • 概述
        • 算法原理
        • 核心逻辑
        • 效果演示
        • 使用方式
概述

本文复现论文提出的偏标记学习方法,随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题------偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的

该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关

算法原理

传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下

核心逻辑

具体的核心逻辑如下所示:

bash 复制代码
import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm

def CE_loss(probs, targets):
    """交叉熵损失函数"""
    loss = -torch.sum(targets * torch.log(probs), dim = -1)
    loss_avg = torch.sum(loss)/probs.shape[0]
    return loss_avg

class Proden:
    def __init__(self, configs):
        self.configs = configs
    
    def train(self, save = False):
        configs = self.configs
        # 读取数据集
        dataset_path = configs['dataset path']
        if configs['dataset'] == 'CIFAR-10':
            train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 10
        elif configs['dataset'] == 'CIFAR-100':
            train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 100
        # 生成偏标记
        partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])
        train_dataset.load_partial_labels(partial_labels)
        # 计算数据的均值和方差,用于模型输入的标准化
        mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]
        std = [np.std(train_data[:, i, :, :]) for i in range(3)]
        normalize = transforms.Normalize(mean, std)
        # 设备:GPU或CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 加载模型
        if configs['model'] == 'ResNet18':
            model = models.ResNet18(output_dimension = output_dimension).to(device)
        elif configs['model'] == 'ConvNet':
            model = models.ConvNet(output_dimension = output_dimension).to(device)
        # 设置学习率等超参数
        lr = configs['learning rate']
        weight_decay = configs['weight decay']
        momentum = configs['momentum']
        optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)
        lr_step = configs['learning rate decay step']
        lr_decay = configs['learning rate decay rate']
        lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)
        for epoch_id in range(configs['epoch count']):
            # 训练模型
            train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)
            model.train()
            for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):
                ids = batch['ids']
                # 标准化输入
                data = normalize(batch['data'].to(device))
                partial_labels = batch['partial_labels'].to(device)
                targets = batch['targets'].to(device)
                optimizer.zero_grad()
                # 计算预测概率
                logits = model(data)
                probs = F.softmax(logits, dim=-1)
                # 更新软标签
                with torch.no_grad():
                    new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)
                    train_dataset.targets[ids] = new_targets.cpu().numpy()
                # 计算交叉熵损失
                loss = CE_loss(probs, targets)
                loss.backward()
                # 更新模型参数
                optimizer.step()
            # 调整学习率
            lr_scheduler.step()
效果演示

我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下

由图可见,该算法在测试集上获得了 89.8% 的准确率。

进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 "Ship":

使用方式

解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:

bash 复制代码
unzip Proden-implemention.zip
cd Proden-implemention

代码的运行环境可通过如下命令进行配置

bash 复制代码
pip install -r requirements.txt

运行如下命令以下载并解压数据集

bash 复制代码
bash download.sh

如果希望在本地训练模型,请运行如下命令

bash 复制代码
python main.py -c [你的配置文件路径] -r [选择下者之一:"train"、"test"、"infer"]

如果希望在线部署,请运行如下命令

bash 复制代码
python main-flask.py

文章代码资源点击附件获取

相关推荐
AAA阿giao8 分钟前
从零开始学 React:用搭积木的方式构建你的第一个网页!
前端·javascript·学习·react.js·前端框架·vite·jsx
Arciab18 分钟前
C++ 学习_流程控制
c++·学习
HyperAI超神经18 分钟前
【vLLM 学习】vLLM TPU 分析
开发语言·人工智能·python·学习·大语言模型·vllm·gpu编程
xiaoxiaoxiaolll41 分钟前
前沿速递 | Adv. Eng. Mater.:基于LPBF与压力渗透的FeSi2.9-Bakelite多功能复合材料设计与性能调控
学习
Freshman小白1 小时前
《人工智能与创新》网课答案2025
人工智能·学习·答案·网课答案
Y_fulture1 小时前
datawhale组队学习:第一章习题
学习·机器学习·概率论
阿蒙Amon1 小时前
JavaScript学习笔记:15.迭代器与生成器
javascript·笔记·学习
来两个炸鸡腿1 小时前
DW动手学大模型应用全栈开发 - (1)大模型应用开发应知必会
python·深度学习·学习·nlp
小徐不会敲代码~1 小时前
Vue3 学习2
前端·javascript·学习
我命由我123451 小时前
Python Flask 开发 - Flask 快速上手(Flask 最简单的案例、Flask 处理跨域、Flask 基础接口)
服务器·开发语言·后端·python·学习·flask·学习方法