半监督

实际上就是在加载dataloader那里做了调整,采样器

这段代码定义了一个名为create_data_loaders的函数,用于创建训练集和验证集的数据加载器。

复制代码
def create_data_loaders(train_transform, eval_transform, datadir, config):
    traindir = os.path.join(datadir, config.train_subdir)
    trainset = torchvision.datasets.ImageFolder(traindir, train_transform)

首先,将训练集的路径拼接起来,然后使用torchvision.datasets.ImageFolder函数加载训练集。ImageFolder是一个用于处理图像文件夹数据集的类,它假设图像文件夹的结构是按照类别分组的,每个类别的图像放在对应的子文件夹中。

复制代码
    if config.labels:
        with open(config.labels) as f:
            labels = dict(line.split(' ') for line in f.read().splitlines())
        labeled_idxs, unlabeled_idxs = datasets.relabel_dataset(trainset, labels)

如果配置中提供了标签文件的路径config.labels,则打开标签文件并将其读取为一个字典。标签文件中的每一行包含图像文件名和对应的标签,通过空格分隔。relabel_dataset函数根据标签文件将训练集中的样本分为有标签和无标签样本,并返回有标签样本的索引和无标签样本的索引。

复制代码
    assert len(trainset.imgs) == len(labeled_idxs) + len(unlabeled_idxs)

确保有标签样本和无标签样本的数量与训练集中的总样本数量相等。

复制代码
    if config.labeled_batch_size < config.batch_size:
        assert len(unlabeled_idxs) > 0
        batch_sampler = datasets.TwoStreamBatchSampler(unlabeled_idxs, labeled_idxs, config.batch_size, config.labeled_batch_size)
    else:
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, config.batch_size, drop_last=True)

根据配置中的有标签批次大小config.labeled_batch_size和总批次大小config.batch_size,决定使用哪种批次采样方式。如果有标签批次大小小于总批次大小,将使用datasets.TwoStreamBatchSampler创建一个两流批次采样器,该采样器在每个批次中同时包含有标签和无标签样本。否则,将使用SubsetRandomSampler创建一个只包含有标签样本的采样器。

复制代码
    train_loader = torch.utils.data.DataLoader(trainset, batch_sampler=batch_sampler, num_workers=config.workers, pin_memory=True)

使用torch.utils.data.DataLoader创建训练集的数据加载器,其中采用了上面创建的批次采样器。num_workers参数指定了用于数据加载的子进程数量,pin_memory=True表示将数据加载到固定的内存区域,可以加速数据传输。

复制代码
    evaldir = os.path.join(datadir, config.eval_subdir)
    evalset = torchvision.datasets.ImageFolder(evaldir, eval_transform)
    eval_loader = torch.utils.data.DataLoader(evalset, batch_size=config.batch_size, shuffle=False, num_workers=2*config.workers, pin_memory=True, drop_last=False)

接下来,将验证集的路径拼接起来,然后使用torchvision.datasets.ImageFolder加载验证集。与训练集类似,也使用torch.utils.data.DataLoader创建验证集的数据加载器。

最后,将训练集和验证集的数据加载器作为结果返回。

这段代码的作用是根据配置中的设置,创建训练集和验证集的数据加载器。在半监督学习中,训练集中的样本被分为有标签和无标签样本,并使用不同的批次采样方式对它们进行训练。

相关推荐
上弦月-编程3 分钟前
指针编程:高效内存管理核心
java·数据结构·算法
贾红平3 分钟前
Python装饰器实战指南
python
罗超驿4 分钟前
双指针算法经典案例:LeetCode 283. 移动零(Java详解)
java·算法·leetcode
qcx235 分钟前
拆解 Warp AI Agent(五):跨生态联邦——10 种 Skill + MCP + 多 Harness 互操作设计
人工智能·rust·ai agent·skill·warp·mcp·harness
清水白石0086 分钟前
深入 Python 循环引用与垃圾回收:如何应对内存管理的挑战
java·jvm·python
生成论实验室6 分钟前
《事件关系阴阳博弈动力学:识势应势之道》第五篇:安全关键关系——故障、障碍与冲突
运维·服务器·人工智能·安全·架构
weixin_446260857 分钟前
应用实战篇:利用 DeepSeek V4 构建生产级 AI 应用的全流程与最佳实践
大数据·linux·人工智能
AI科技星9 分钟前
全域数学视角下N维广义数系的推广与本源恒等式构建【乖乖数学】
人工智能·机器学习·数学建模·数据挖掘
qcx2310 分钟前
拆解 Warp AI Agent(二):风险分级执行——Agent 如何做到安全并行、危险排队
人工智能·安全·ai·agent·源码解析·warp
小白蒋博客11 分钟前
【ai开发段永平投资理财的知识图谱网站】第一天:搭 Vite + Vue 项目,跑通 Hello World
vue.js·人工智能·trae