拒绝显存溢出!手把手教你写原生 4K 超分辨率数据集 (SurgiSR4K) 的 PyTorch DataLoader

前言

在医疗影像的超分辨率(Super-Resolution, SR)研究中,高质量的开源数据一直是稀缺资源。近期由香港中文大学、直觉外科公司等机构联合推出的 SurgiSR4K 数据集打破了这一僵局。作为首个针对机器人辅助微创手术的原生 4K (3840×2160) 视频数据集,它不仅提供了完美的亚像素级对齐,还涵盖了反光、遮挡、烟雾等极具挑战性的真实手术场景。

然而,面对 4K 分辨率的庞大张量,传统的 PyTorch 数据加载方式会瞬间导致显存溢出(OOM)。本文将带你从零构建一个专为超分辨率任务定制的、支持**联合随机裁剪(Joint Random Crop)**的高效 DataLoader,帮你平滑开启 4K 医疗影像的算法训练。


核心知识点

在编写超分辨率的 DataLoader 时,我们必须跨越以下三个技术门槛:

  1. 嵌套路径与标签匹配

    SurgiSR4K 采用了按"分辨率"和"视频ID/器械复杂度"嵌套的文件夹结构(如 480x270p/vid_001_480x270p_1tool/)。我们需要在代码中动态解析相对路径,将输入图(LR)精准映射到 4K 标签图(HR)。

  2. 联合随机裁剪 (Joint Random Crop) ------ 突破显存瓶颈

    整张 4K 图片无法直接喂入网络。我们需要在 LR 图像上随机切取小块(如 64×64),并严格根据放大倍率(Scale Factor)在 HR 图像上切取对应的大块(如 8 倍超分下的 512×512)。这不仅解决了显存问题,还是一种极佳的数据扩增手段。

  3. 联合几何变换 (Joint Geometric Transformations)

    超分任务中的数据增强必须是"绑定"的。LR 做了怎样的水平/垂直翻转或旋转,HR 必须做一模一样的动作,否则会导致模型学出重影。


步骤与核心代码

为了实现上述功能,我们不使用易造成随机状态不一致的常规 transforms,而是引入 torchvision.transforms.functional 进行底层的精细控制。

以下是完整且可直接用于工程化训练的 dataset.py 核心代码:

Python

ini 复制代码
import random
from pathlib import Path
from PIL import Image
import torchvision.transforms.functional as TF
import torch
from torch.utils.data import Dataset, DataLoader

class SurgiSR4KDataset(Dataset):
    """
    支持联合随机裁剪与同步数据增强的 SurgiSR4K Dataset
    """
    def __init__(self, data_root, lr_res="480x270p", hr_res="3840x2160p", 
                 scale_factor=8, lr_patch_size=64, is_train=True):
        self.data_root = Path(data_root)
        self.lr_res = lr_res
        self.hr_res = hr_res
        self.scale_factor = scale_factor
        self.lr_patch_size = lr_patch_size
        self.is_train = is_train
        
        self.lr_dir = self.data_root / self.lr_res
        self.hr_dir = self.data_root / self.hr_res
        
        # 使用 rglob 递归搜索所有子文件夹中的 png 图片
        self.lr_image_paths = sorted(list(self.lr_dir.rglob("*.png")))
        if not self.lr_image_paths:
            raise ValueError(f"在 {self.lr_dir} 及其子文件夹中未找到图像!")

    def __len__(self):
        return len(self.lr_image_paths)

    def __getitem__(self, idx):
        # 1. 获取 LR 路径并推导 HR 绝对路径
        lr_path = self.lr_image_paths[idx]
        rel_path = lr_path.relative_to(self.lr_dir) # 提取包含子文件夹的相对路径
        hr_rel_path_str = str(rel_path).replace(self.lr_res, self.hr_res)
        hr_path = self.hr_dir / hr_rel_path_str
            
        # 2. 读取图像
        lr_img = Image.open(lr_path).convert("RGB")
        hr_img = Image.open(hr_path).convert("RGB")
        
        if self.is_train:
            # =====================================
            # 3. 联合随机裁剪 (Joint Random Crop)
            # =====================================
            lr_w, lr_h = lr_img.size
            
            # 在 LR 图上随机生成不越界的左上角坐标
            lr_x = random.randint(0, lr_w - self.lr_patch_size)
            lr_y = random.randint(0, lr_h - self.lr_patch_size)
            
            # 严格映射到 HR 图像的坐标和尺寸
            hr_x = lr_x * self.scale_factor
            hr_y = lr_y * self.scale_factor
            hr_patch_size = self.lr_patch_size * self.scale_factor
            
            # 同步执行裁剪
            lr_img = lr_img.crop((lr_x, lr_y, lr_x + self.lr_patch_size, lr_y + self.lr_patch_size))
            hr_img = hr_img.crop((hr_x, hr_y, hr_x + hr_patch_size, hr_y + hr_patch_size))
            
            # =====================================
            # 4. 联合数据增强 (Joint Augmentation)
            # =====================================
            # 随机水平翻转
            if random.random() < 0.5:
                lr_img = TF.hflip(lr_img)
                hr_img = TF.hflip(hr_img)
                
            # 随机垂直翻转
            if random.random() < 0.5:
                lr_img = TF.vflip(lr_img)
                hr_img = TF.vflip(hr_img)
                
            # 随机旋转 (0, 90, 180, 270 度)
            angle = random.choice([0, 90, 180, 270])
            if angle != 0:
                lr_img = TF.rotate(lr_img, angle)
                hr_img = TF.rotate(hr_img, angle)

        # 5. 转换为 Tensor 并归一化到 [0, 1]
        lr_tensor = TF.to_tensor(lr_img)
        hr_tensor = TF.to_tensor(hr_img)
            
        return {"lr": lr_tensor, "hr": hr_tensor}

# 辅助包装函数
def create_surgisr4k_dataloader(data_root, batch_size=4, num_workers=4):
    dataset = SurgiSR4KDataset(
        data_root=data_root, 
        lr_res="480x270p", hr_res="3840x2160p", 
        scale_factor=8, lr_patch_size=64, is_train=True
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

注意事项与避坑指南

  1. 禁用破坏性增强 :在图像分类中常用的 Color Jitter(色彩抖动)、Random Gaussian Blur(高斯模糊)等增强手段,严禁在 SR 的 HR 标签上使用。SR 模型的任务是学习精确的像素映射,改变标签颜色或模糊度会破坏 Ground Truth。
  2. 多进程读取陷阱 :如果在 Windows 系统下调试代码遇到 DataLoader 报错,请将 num_workers 临时设为 0。在 Linux 服务器上进行正式训练时,建议将其设为 48 并配合 pin_memory=True,以最大化 GPU 吞吐量。
  3. 路径转义问题 :Windows 系统下填写 data_root 路径时,务必在字符串前加 r(如 r"D:\datasets..."),防止 \n\t 被错误转义。

总结

构建一个稳健的数据流水线是训练深度学习模型的第一步。针对 SurgiSR4K 这样的原生 4K 数据集,普通的 DataLoader 是无法胜任的。通过递归路径匹配联合随机裁剪 以及底层函数级的数据增强同步,我们成功解决了显存溢出与标签对齐的难题。

将上述代码保存到你的工程中,调整好 Batch Size,你的超分网络现在就可以畅快地吸入这些高质量的 4K 医疗影像数据了。

你的下一个超分辨率模型准备用哪种网络架构(SwinIR / HAT / RCAN)呢?欢迎在评论区交流!

相关推荐
junjunzai1232 小时前
设置cuda:1但是cuda:0在波动的问题
人工智能·深度学习
智算菩萨4 小时前
多目标超启发式算法系统文献综述:人机协同大语言模型方法论深度精读
论文阅读·人工智能·深度学习·ai·多目标·综述
简单光学4 小时前
ISDM: 基于生成扩散模型的散射介质成像重建技术报告
深度学习·扩散模型·散射成像·分数匹配·随机微分方程
IT阳晨。4 小时前
PyTorch深度学习实践
人工智能·pytorch·深度学习
智算菩萨5 小时前
【How Far Are We From AGI】5 AGI的“道德罗盘“——价值对齐的技术路径与伦理边界
论文阅读·人工智能·深度学习·ai·接口·agi·对齐技术
Sakuraba Ema5 小时前
从零理解 MoE(Mixture of Experts)混合专家:原理、数学、稀疏性、专家数量影响与手写 PyTorch 实现
人工智能·pytorch·python·深度学习·数学·llm·latex
freewlt5 小时前
科技热点速递:AI技术集中爆发
人工智能·深度学习·计算机视觉
南宫乘风6 小时前
LLaMA-Factory 给 Qwen1.5 做 LoRA 微调 实战
人工智能·深度学习·llama
小陈phd6 小时前
多模态大模型学习笔记(二十一)—— 基于 Scaling Law方法 的大模型训练算力估算与 GPU 资源配置
笔记·深度学习·学习·自然语言处理·transformer