【pytorch】数据增强与时俱进,未来的改进和功能将仅添加到 torchvision.transforms.v2 转换中

  • 大多数转换都接受 PIL 图像和张量输入。支持 CPU 和 CUDA 张量。两种后端(PIL 或张量)的结果应该非常接近。一般来说,我们建议依赖张量后端以获得性能。可以使用转换转换在 PIL 图像之间或转换 dtypes 和范围。张量图像的形状应为 (C, H, W),其中 C 是通道数,H 和 W 分别指高度和宽度。大多数转换支持批量张量输入。一批张量图像的形状为 (N, C, H, W),其中 N 是批次中的图像数量。v2 转换通常接受任意数量的前导维度 (..., C, H, W),并且可以处理批量图像或批量视频 。张量图像值的期望范围由张量 dtype 隐式定义。具有 float dtype 的张量图像的值应在 [0, 1] 范围内。具有 integer dtype 的张量图像的值应在 [0, MAX_DTYPE] 范围内,其中 MAX_DTYPE 是该 dtype 中可表示的最大值。通常,dtype 为 torch.uint8 的图像的值应在 [0, 255] 范围内。

  • torchvision.transforms.v2 命名空间中发布了一套新的转换。与 v1(在 torchvision.transforms 中)相比,这些转换具有许多优势:它们不仅可以转换图像,还可以 转换边界框、掩码或视频。这为图像分类以外的任务(如检测、分割、视频分类等)提供了支持。它们支持更多转换,例如 CutMixMixUp。支持任意输入结构(dicts、lists、tuples 等)。更快。未来的改进和功能将仅添加到 v2 转换中 。推荐以下指南以从转换中获得最佳性能:依赖 torchvision.transforms.v2 中的 v2 转换,使用张量而不是 PIL 图像,使用 torch.uint8 dtype,特别是对于调整大小操作,使用 bilinear 或 bicubic 模式进行调整大小。在依赖于 torch.utils.data.DataLoader 且 num_workers > 0 的典型训练环境中,以上建议应能提供最佳性能。

    python 复制代码
    from torchvision.transforms import v2
    transforms = v2.Compose([
        v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
        v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
        # ...
        v2.RandomResizedCrop(size=(224, 224), antialias=True),  # Or Resize(antialias=True)
        # ...
        v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
  • 转换往往对输入 strides / 内存格式敏感。一些转换对于 channels-first 图像会更快,而另一些则更喜欢 channels-last。像 [torch] 运算符一样,大多数转换将保留输入的内存格式,但这可能由于实现细节而无法始终得到遵守。如果您追求最佳性能,可能需要进行一些实验。在单独的转换(例如 [Normalize] )上使用 [torzch.compile()] 也可能有助于排除内存格式变量的影响。

V2 API 参考

几何转换
v2.Resize(size[, interpolation, max_size, ...]) 将输入调整为给定大小。
v2.ScaleJitter(target_size[, scale_range, ...]) 根据"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation"对输入执行大尺度抖动。
v2.RandomShortestSize(min_size[, max_size, ...]) 随机调整输入大小。
v2.RandomResize(min_size, max_size[, ...]) 随机调整输入大小。
v2.RandomCrop(size[, padding, ...]) 在随机位置裁剪输入。
v2.RandomResizedCrop(size[, scale, ratio, ...]) 裁剪输入的随机部分并将其调整为给定大小。
v2.RandomIoUCrop([min_scale, max_scale, ...]) 来自"SSD: Single Shot MultiBox Detector"的随机 IoU 裁剪转换。
v2.CenterCrop(size) 在中心裁剪输入。
v2.FiveCrop(size) 将图像或视频裁剪为四个角和中心区域。
v2.TenCrop(size[, vertical_flip]) 将图像或视频裁剪为四个角和中心区域,以及这些区域的翻转版本(默认使用水平翻转)。
v2.RandomHorizontalFlip([p]) 以给定概率水平翻转输入。
v2.RandomVerticalFlip([p]) 以给定概率垂直翻转输入。
v2.Pad(padding[, fill, padding_mode]) 使用给定的"填充"值在所有侧面填充输入。
v2.RandomZoomOut([fill, side_range, p]) 来自"SSD: Single Shot MultiBox Detector"的"缩小 (Zoom out)"转换。
v2.RandomRotation(degrees[, interpolation, ...]) 按角度旋转输入。
v2.RandomAffine(degrees[, translate, scale, ...]) 对输入进行随机仿射变换,保持中心不变。
v2.RandomPerspective([distortion_scale, p, ...]) 以给定概率对输入执行随机透视变换。
v2.ElasticTransform([alpha, sigma, ...]) 使用弹性变换对输入进行转换。
  • 仿射变换就是线性变换 + 平移。变换后直线依然是直线,平行线依然是平行线,直线间的相对位置关系不变,因此非共线的三个对应点便可确定唯一的一个仿射变换 ,线性变换 4 个自由度 + 平移 2 个自由度 →仿射变换自由度为 6 。用cv2.getAffineTransform()生成变换矩阵,接下来再用cv2.warpAffine()实现变换。

    python 复制代码
    import cv2
    import numpy as np
    import matplotlib.pyplot as plt
    img = cv2.imread('drawing.jpg')
    rows, cols = img.shape[:2]
    # 变换前的三个点
    pts1 = np.float32([[50, 65], [150, 65], [210, 210]])
    # 变换后的三个点
    pts2 = np.float32([[50, 100], [150, 65], [100, 250]])
    # 生成变换矩阵
    M = cv2.getAffineTransform(pts1, pts2)
    dst = cv2.warpAffine(img, M, (cols, rows))
    plt.subplot(121), plt.imshow(img), plt.title('input')
    plt.subplot(122), plt.imshow(dst), plt.title('output')
    plt.show()
    • 平移就是 x 和 y 方向上的直接移动,可以上下/左右移动,自由度为 2,变换矩阵可以表示为:

    • u v \] = \[ 1 0 0 1 \] \[ x y \] + \[ t x t y \] \\begin{bmatrix}u \\\\ v \\end{bmatrix} = \\begin{bmatrix} 1 \& 0 \\\\ 0 \& 1\\end{bmatrix}\\begin{bmatrix} x \\\\ y\\end{bmatrix}+ \\begin{bmatrix} t_x \\\\ t_y\\end{bmatrix} \[uv\]=\[1001\]\[xy\]+\[txty

    • 旋转是坐标轴方向饶原点旋转一定的角度 θ,自由度为 1,不包含平移,如顺时针旋转可以表示为:

    • u v \] = \[ c o s θ − s i n θ s i n θ c o s θ \] \[ x y \] + \[ 0 0 \] \\begin{bmatrix}u \\\\ v\\end{bmatrix} = \\begin{bmatrix} cosθ \& -sinθ \\\\ sinθ \& cosθ\\end{bmatrix} \\begin{bmatrix} x \\\\ y\\end{bmatrix}+ \\begin{bmatrix} 0 \\\\ 0 \\end{bmatrix} \[uv\]=\[cosθsinθ−sinθcosθ\]\[xy\]+\[00

    • 翻转是 x 或 y 某个方向或全部方向上取反,自由度为 2,比如这里以垂直翻转为例:

    • u v \] = \[ 1 0 0 − 1 \] \[ x y \] + \[ t x t y \] \\begin{bmatrix} u \\\\ v\\end{bmatrix} = \\begin{bmatrix} 1 \& 0 \\\\ 0 \& -1 \\end{bmatrix} \\begin{bmatrix} x \\\\ y \\end{bmatrix} + \\begin{bmatrix} t_x \\\\ t_y \\end{bmatrix} \[uv\]=\[100−1\]\[xy\]+\[txty

    • 旋转 + 平移也称刚体变换(Rigid Transform),就是说如果图像变换前后两点间的距离仍然保持不变,那么这种变化就称为刚体变换。刚体变换包括了平移、旋转和翻转,自由度为 3。由于只是旋转和平移,刚体变换保持了直线间的长度不变,所以也称欧式变换(变化前后保持欧氏距离)。变换矩阵可以表示为:

    • u v \] = \[ c o s θ − s i n θ s i n θ c o s θ \] \[ x y \] + \[ t x t y \] \\begin{bmatrix} u \\\\ v \\end{bmatrix} =\\begin{bmatrix} cosθ \& -sinθ \\\\ sinθ \& cosθ \\end{bmatrix} \\begin{bmatrix} x \\\\ y \\end{bmatrix} + \\begin{bmatrix} t_x \\\\ t_y \\end{bmatrix} \[uv\]=\[cosθsinθ−sinθcosθ\]\[xy\]+\[txty

    • 缩放是 x 和 y 方向的尺度(倍数)变换,在有些资料上非等比例的缩放也称为拉伸/挤压,等比例缩放自由度为 1,非等比例缩放自由度为 2,矩阵可以表示为:

    • u v \] = \[ s x 0 0 s y \] \[ x y \] + \[ 0 0 \] \\begin{bmatrix}u \\\\ v\\end{bmatrix}=\\begin{bmatrix} s_x \& 0 \\\\ 0 \& s_y \\end{bmatrix} \\begin{bmatrix} x \\\\ y \\end{bmatrix} + \\begin{bmatrix} 0 \\\\ 0 \\end{bmatrix} \[uv\]=\[sx00sy\]\[xy\]+\[00

    • 相似变换又称缩放旋转,相似变换包含了旋转、等比例缩放和平移等变换,自由度为 4。相似变换相比刚体变换加了缩放,所以并不会保持欧氏距离不变,但直线间的夹角依然不变。在 OpenCV 中,旋转就是用相似变换实现的。若缩放比例为 scale,旋转角度为 θ,旋转中心是(centerx,centery),则仿射变换可以表示为:

    • u v \] = \[ s c a l e ⋅ c o s θ s c a l e ⋅ s i n θ − s c a l e ⋅ s i n θ s c a l e ⋅ c o s θ \] \[ x y \] + \[ ( 1 − s c a l e ⋅ c o s θ ) c e n t e r x − ( s c a l e ⋅ s i n θ ) c e n t e r y ( s c a l e ⋅ s i n θ ) c e n t e r x + ( 1 − s c a l e ⋅ c o s θ ) c e n t e r y \] \\begin{bmatrix} u \\\\ v\\end{bmatrix} = \\begin{bmatrix} scale⋅cosθ \& scale⋅sinθ \\\\ -scale⋅sinθ \& scale⋅cosθ \\end{bmatrix} \\begin{bmatrix} x \\\\ y \\end{bmatrix} + \\begin{bmatrix} (1−scale⋅cosθ)center_x−(scale⋅sinθ)center_y \\\\ (scale⋅sinθ)center_x+(1−scale⋅cosθ)center_y \\end{bmatrix} \[uv\]=\[scale⋅cosθ−scale⋅sinθscale⋅sinθscale⋅cosθ\]\[xy\]+\[(1−scale⋅cosθ)centerx−(scale⋅sinθ)centery(scale⋅sinθ)centerx+(1−scale⋅cosθ)centery

  • 透视变换(Perspective Transformation)是将二维的图片投影到一个三维视平面上,然后再转换到二维坐标下,所以也称为投影映射(Projective Mapping)。简单来说就是二维 → 三维 → 二维的一个过程。透视变换相比仿射变换更加灵活,变换后会产生一个新的四边形,但不一定是平行四边形,所以需要非共线的四个点才能唯一确定 ,原图中的直线变换后依然是直线。因为四边形包括了所有的平行四边形,所以透视变换包括了所有的仿射变换。OpenCV 中首先根据变换前后的四个点用cv2.getPerspectiveTransform()生成 3×3 的变换矩阵,然后再用cv2.warpPerspective()进行透视变换。

    python 复制代码
    img = cv2.imread('card.jpg')
    # 原图中卡片的四个角点
    pts1 = np.float32([[148, 80], [437, 114], [94, 247], [423, 288]])
    # 变换后分别在左上、右上、左下、右下四个点
    pts2 = np.float32([[0, 0], [320, 0], [0, 178], [320, 178]])
    # 生成透视变换矩阵
    M = cv2.getPerspectiveTransform(pts1, pts2)
    # 进行透视变换,参数 3 是目标图像大小
    dst = cv2.warpPerspective(img, M, (320, 178))
    plt.subplot(121), plt.imshow(img[:, :, ::-1]), plt.title('input')
    plt.subplot(122), plt.imshow(dst[:, :, ::-1]), plt.title('output')
    plt.show()
    • 写成齐次矩阵的形式:

    • X Y Z \] = \[ a 1 b 1 c 1 a 2 b 2 c 2 a 3 b 3 c 3 \] \[ x y 1 \] \\begin{bmatrix} X \\\\ Y \\\\ Z \\end{bmatrix} = \\begin{bmatrix} a_1 \& b_1 \& c_1\\\\ a_2 \& b_2 \& c_2\\\\ a_3 \& b_3 \& c_3 \\end{bmatrix} \\begin{bmatrix} x \\\\ y \\\\ 1 \\end{bmatrix} XYZ = a1a2a3b1b2b3c1c2c3 xy1

颜色转换
v2.ColorJitter([brightness, contrast, ...]) 随机更改图像或视频的亮度、对比度、饱和度和色调。
v2.RandomChannelPermutation() 随机置换图像或视频的通道。
v2.RandomPhotometricDistort([brightness, ...]) 随机扭曲图像或视频,如"SSD: Single Shot MultiBox Detector"中所用。
v2.Grayscale([num_output_channels]) 将图像或视频转换为灰度图。
v2.RGB() 将图像或视频转换为 RGB(如果它们不是 RGB)。
v2.RandomGrayscale([p]) 以概率 p(默认 0.1)随机将图像或视频转换为灰度图。
v2.GaussianBlur(kernel_size[, sigma]) 使用随机选择的高斯模糊核模糊图像。
v2.GaussianNoise([mean, sigma, clip]) 为图像或视频添加高斯噪声。
v2.RandomInvert([p]) 以给定概率反转给定图像或视频的颜色。
v2.RandomPosterize(bits[, p]) 以给定概率通过减少每个颜色通道的位数来对图像或视频进行色调分离(Posterize)。
v2.RandomSolarize(threshold[, p]) 以给定概率通过反转高于阈值的所有像素值来对图像或视频进行曝光过度(Solarize)。
v2.RandomAdjustSharpness(sharpness_factor[, p]) 以给定概率调整图像或视频的锐度。
v2.RandomAutocontrast([p]) 以给定概率对给定图像或视频的像素进行自动对比度调整(Autocontrast)。
v2.RandomEqualize([p]) 以给定概率均衡给定图像或视频的直方图。
进阶
v2.LinearTransformation(...) 使用离线计算的方阵变换矩阵和均值向量来变换张量图像或视频。
v2.Normalize(mean, std[, inplace]) 使用均值和标准差对张量图像或视频进行归一化。
v2.RandomErasing([p, scale, ratio, value, ...]) 在输入的图像或视频中随机选择一个矩形区域并擦除其像素。
v2.Lambda(lambd, *types) 将用户定义的函数作为变换应用。
v2.SanitizeBoundingBoxes([min_size, ...]) 移除退化/无效的边界框及其对应的标签和掩码。
v2.ClampBoundingBoxes() 将边界框限制在对应的图像尺寸内。
v2.UniformTemporalSubsample(num_samples) 从视频的时间维度中均匀采样 num_samples 个索引。
v2.JPEG(quality) 对给定的图像应用 JPEG 压缩和解压缩。
v2.Compose(transforms) 将多个变换组合在一起。
v2.RandomApply(transforms[, p]) 以给定的概率随机应用一系列变换。
v2.RandomChoice(transforms[, p]) 从列表中随机选择并应用单个变换。
v2.RandomOrder(transforms) 以随机顺序应用一系列变换。
  • 以下一些转换变换在执行转换时会对值进行缩放,而另一些则不会。缩放是指例如将 uint8 -> float32 的 [0, 255] 范围映射到 [0, 1](反之亦然)。

    v2.ToImage() 将张量、ndarray 或 PIL 图像转换为 Image;此操作不缩放值。
    v2.ToPureTensor() 将所有 TVTensors 转换为纯张量,移除相关元数据(如有)。
    v2.PILToTensor() 将 PIL 图像转换为相同类型的张量 - 此操作不缩放值。
    v2.ToPILImage([mode]) 将张量或 ndarray 转换为 PIL 图像
    v2.ToDtype(dtype[, scale]) 将输入转换为特定的 dtype,可选地对图像或视频的值进行缩放。
    v2.ConvertBoundingBoxFormat(format) 将边界框坐标转换为给定的 format,例如从 "CXCYWH" 转换为 "XYXY"。
    v2.Transform() 实现自定义 v2 变换的基类。
    v2.query_size(flat_inputs) 返回高度和宽度。
    v2.query_chw(flat_inputs) 返回通道数、高度和宽度。
    v2.get_bounding_boxes(flat_inputs) 返回输入中的边界框。
v2 弃用
v2.ToTensor() [已弃用] 请改用 v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
v2.functional.to_tensor(inpt) [已弃用] 请改用 to_image() 和 to_dtype()。
v2.ConvertImageDtype([dtype]) [已弃用] 请改用 v2.ToDtype(dtype, scale=True)
v2.functional.convert_image_dtype(image[, dtype]) [已弃用] 请改用 to_dtype()。
自动增强
增强 作用 参数建议
RandomResizedCrop 模拟不同距离拍摄 scale=(0.8,1.0) 避免裁剪过多
ColorJitter 模拟光照、白平衡变化 hue=0.1 防止颜色失真过大
RandomEqualize / Invert 增加对比度鲁棒性 低概率 p=0.1~0.2
GaussianBlur 模拟运动模糊、对焦不准 sigma=(0.1,2.0)
RandomErasing 模拟局部遮挡(如树枝、手指) scale=(0.02,0.2) 避免大面积擦除
JPEG 模拟网络传输压缩 quality=(60,95)
Mixup/CutMix 增强泛化,防止过拟合 小样本场景特别有效
  • 在图像分类任务中,Mixup 和 CutMix 是 标签混合(soft label) 技术,它们不使用 one-hot 标签,而是生成 两个样本的线性组合 ,从而提升泛化、防止过拟合。Mixup 原理,给定两个样本 ( x i , y i x_i, y_i xi,yi) 和 ( x j , y j x_j, y_j xj,yj),Mixup 生成:

    python 复制代码
    lambda ~ Beta(α, α)
    x_mix = λ * x_i + (1 - λ) * x_j
    y_mix = λ * y_i_onehot + (1 - λ) * y_j_onehot
    • y_i_onehot:类别标签转为 one-hot 向量;y_mix 是一个 soft label (如 [0.2, 0.0, 0.8, ...])。

    • 类似 Mixup,但不是线性混合像素,而是:从 x_j 中裁剪一个区域,粘贴到 x_i 上,标签按 面积比例 混合:

    python 复制代码
    area_ratio = cut_area / total_area
    y_mix = area_ratio * y_j_onehot + (1 - area_ratio) * y_i_onehot
  • 图像分类任务中,在 torchvision.transforms.v2 中使用 Mixup

    python 复制代码
    num_classes = 10  # 替换为你的类别数
    # 定义 Mixup 和 CutMix
    mixup = T.RandomMixup(alpha=0.2, num_classes=num_classes)
    cutmix = T.RandomCutmix(alpha=1.0, num_classes=num_classes)
    # 可以组合使用
    transform_with_mixup = T.RandomChoice([mixup, cutmix])
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        # 应用 Mixup/CutMix
        images, labels = transform_with_mixup(images, labels)
        # 此时 labels 是 (B, num_classes) 的 soft label
        # 模型前向
        outputs = model(images)
        loss = criterion(outputs, labels)  # 使用 BCEWithLogitsLoss 或 KLDivLoss
        loss.backward()
    • 由于标签是 soft label(概率分布),不能用 CrossEntropyLoss(它要求 hard label)。

相比 v1 在 v2 中的独特实现

v2.ScaleJitter(target_size[, scale_range, ...]) 根据"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation"对输入执行大尺度抖动。
v2.RandomShortestSize(min_size[, max_size, ...]) 随机调整输入大小。
v2.RandomResize(min_size, max_size[, ...]) 随机调整输入大小。
v2.RandomIoUCrop([min_scale, max_scale, ...]) 来自"SSD: Single Shot MultiBox Detector"的随机 IoU 裁剪转换。
v2.RandomZoomOut([fill, side_range, p]) 来自"SSD: Single Shot MultiBox Detector"的"缩小 (Zoom out)"转换。
v2.RandomChannelPermutation() 随机置换图像或视频的通道。
v2.RandomPhotometricDistort([brightness, ...]) 随机扭曲图像或视频,如"SSD: Single Shot MultiBox Detector"中所用。
v2.RGB() 将图像或视频转换为 RGB(如果它们不是 RGB)。
v2.GaussianNoise([mean, sigma, clip]) 为图像或视频添加高斯噪声。
v2.SanitizeBoundingBoxes([min_size, ...]) 移除退化/无效的边界框及其对应的标签和掩码。
v2.ClampBoundingBoxes() 将边界框限制在对应的图像尺寸内。
v2.UniformTemporalSubsample(num_samples) 从视频的时间维度中均匀采样 num_samples 个索引。
v2.JPEG(quality) 对给定的图像应用 JPEG 压缩和解压缩。
v2.ToImage() 将张量、ndarray 或 PIL 图像转换为 Image;此操作不缩放值。
v2.ToPureTensor() 将所有 TVTensors 转换为纯张量,移除相关元数据(如有)。
v2.ConvertBoundingBoxFormat(format) 将边界框坐标转换为给定的 format,例如从 "CXCYWH" 转换为 "XYXY"。
v2.CutMix(*[, alpha, num_classes, labels_getter]) 对提供的图像和标签批次应用 CutMix。
v2.MixUp(*[, alpha, num_classes, labels_getter]) 对提供的图像和标签批次应用 MixUp。
v2.Transform() 实现自定义 v2 变换的基类。
v2.functional.register_kernel(functional, ...) 装饰一个 kernel 以便将其注册到 functional 和(自定义的)tv_tensor 类型。
v2.query_size(flat_inputs) 返回高度和宽度。
v2.query_chw(flat_inputs) 返回通道数、高度和宽度。
v2.get_bounding_boxes(flat_inputs) 返回输入中的边界框。

V1 API 参考

几何变换
Resize(size[, interpolation, max_size, ...]) 将输入图像调整到给定大小。
RandomCrop(size[, padding, pad_if_needed, ...]) 在随机位置裁剪给定的图像。
RandomResizedCrop(size[, scale, ratio, ...]) 裁剪图像的随机部分并将其调整到给定大小。
CenterCrop(size) 在中心裁剪给定的图像。
FiveCrop(size) 将给定的图像裁剪为四个角和中心部分。
TenCrop(size[, vertical_flip]) 将给定的图像裁剪为四个角和中心部分,以及这些部分的翻转版本(默认为水平翻转)。
Pad(padding[, fill, padding_mode]) 使用给定的"pad"值在给定图像的四周填充。
RandomRotation(degrees[, interpolation, ...]) 按角度旋转图像。
RandomAffine(degrees[, translate, scale, ...]) 对图像进行随机仿射变换,同时保持中心不变。
RandomPerspective([distortion_scale, p, ...]) 以给定概率对给定图像执行随机透视变换。
ElasticTransform([alpha, sigma, ...]) 使用弹性变换对张量图像进行变换。
RandomHorizontalFlip([p]) 以给定概率随机水平翻转给定的图像。
RandomVerticalFlip([p]) 以给定概率随机垂直翻转给定的图像。
色彩变换
ColorJitter([brightness, contrast, ...]) 随机改变图像的亮度、对比度、饱和度和色相。
Grayscale([num_output_channels]) 将图像转换为灰度图。
RandomGrayscale([p]) 以概率 p(默认 0.1)随机将图像转换为灰度图。
GaussianBlur(kernel_size[, sigma]) 使用随机选择的高斯模糊对图像进行模糊。
RandomInvert([p]) 以给定概率随机反转给定图像的颜色。
RandomPosterize(bits[, p]) 以给定概率通过减少每个颜色通道的位数随机对图像进行色调分离(posterize)。
RandomSolarize(threshold[, p]) 以给定概率通过反转阈值以上的所有像素值随机对图像进行曝光过度(solarize)。
RandomAdjustSharpness(sharpness_factor[, p]) 以给定概率随机调整图像的锐度。
RandomAutocontrast([p]) 以给定概率随机对给定图像的像素进行自动对比度调整。
RandomEqualize([p]) 以给定概率随机对给定图像的直方图进行均衡化。
进阶
Compose(transforms) 将多个变换组合在一起。
RandomApply(transforms[, p]) 以给定的概率随机应用一系列变换。
RandomChoice(transforms[, p]) 从列表中随机选择并应用单个变换。
RandomOrder(transforms) 以随机顺序应用一系列变换。
LinearTransformation(transformation_matrix, ...) 使用离线计算的方阵变换矩阵和均值向量对张量图像进行变换。
Normalize(mean, std[, inplace]) 使用均值和标准差对张量图像进行归一化。
RandomErasing([p, scale, ratio, value, inplace]) 在 torch.Tensor 图像中随机选择一个矩形区域并擦除其像素。
Lambda(lambd) 将用户定义的 lambda 函数作为变换应用。
ToPILImage([mode]) 将张量或 ndarray 转换为 PIL 图像
ToTensor() 将PIL图像或ndarray转换为张量并相应地缩放值。
PILToTensor() 将 PIL 图像转换为相同类型的张量 - 此操作不缩放值。
ConvertImageDtype(dtype) 将张量图像转换为给定的 dtype 并相应地缩放值。
  • 一些转换变换在执行转换时会对值进行缩放,而另一些则不会。缩放是指例如将 uint8 -> float32 的 [0, 255] 范围映射到 [0, 1](反之亦然)。AutoAugment 是一种常用的数据增强技术,可以提高图像分类模型的准确性。尽管数据增强策略与其训练的数据集直接相关,但实证研究表明,ImageNet 策略应用于其他数据集时也能带来显著改进。在 TorchVision 中,我们实现了从以下数据集学到的 3 种策略:ImageNet、CIFAR10 和 SVHN。这个新的变换既可以单独使用,也可以与现有变换混合搭配使用。

AutoAugmentPolicy(value) 在不同数据集上学习到的AutoAugment策略。
AutoAugment([policy, interpolation, fill]) 基于 "AutoAugment: Learning Augmentation Strategies from Data" 的 AutoAugment 数据增强方法。
RandAugment([num_ops, magnitude, ...]) 基于 "RandAugment: Practical automated data augmentation with a reduced search space" 的 RandAugment 数据增强方法。
TrivialAugmentWide([num_magnitude_bins, ...]) 基于 "TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" 中描述的与数据集无关的 TrivialAugment Wide 数据增强。
AugMix([severity, mixture_width, ...]) 基于 "AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" 的 AugMix 数据增强方法。
  • Few-Shot Class-Incremental Learning via Class-Aware Bilateral Distillation 是一个基于双分支蒸馏的小样本类增量学习项目,主要包含以下几个核心模块:

    • 数据处理模块(dataloader目录),data_utils.py: 基础数据加载工具,包括基础类和增量类的数据加载,samplers.py: 数据采样器,用于构建few-shot任务,exemplar_set.py: 样本集管理,用于保存和管理历史类别的样本
    • 模型结构模块(models目录),resnet18_encoder.py: ResNet-18特征提取器,resnet20_cifar.py: 用于CIFAR数据集的ResNet-20,cosine_classifier.py: 基于余弦相似度的分类器;binary_classifier.py: 二分类器;mlp_models.py: 多层感知机模型
    • 核心方法模块(methods目录),cosine_classifier.py: 余弦分类器实现,包含基础训练和增量学习的处理逻辑
    • 工具模块(utils目录),fsl_inc.py: Few-shot增量学习相关工具函数;utils.py: 通用工具函数
  • 项目采用标准的小样本类增量学习设置:**基础训练阶段(Session 0)**使用基础类别(如MiniImageNet中的60类)进行预训练,每个类别有大量样本用于训练基础分类器。增量学习阶段(Session 1-N):每个增量session引入少量新类别(如每轮5个新类),每个新类别只有少量样本(如5个)用于微调,使用历史类别的样本(exemplars)防止遗忘。 Exemplar存储机制,为了防止遗忘,项目会为每个已学习的类别存储代表性样本(exemplars),在增量学习时使用这些样本来进行知识蒸馏。

  • 基础模型训练使用train.py进行基础类别训练,采用余弦分类器,通过特征与类别权重的余弦相似度进行分类,基础训练参数:--epoch: 训练轮数,--batch_size: 批处理大小,--init_lr: 初始学习率,--milestones: 学习率衰减的里程碑。使用基础类别数据训练特征提取器和基础分类器,保存模型权重用于后续增量学习。

  • Stable (稳定): 这些功能将长期维护,通常不会有主要的性能限制或文档空白。我们也期望保持向后兼容性(尽管可能发生破坏性变更,但会提前一个版本发出通知)。Beta (测试版): 功能被标记为 Beta 版,是因为其 API 可能根据用户反馈而变化,性能需要改进,或者对算子的覆盖尚未完全。对于 Beta 功能,我们致力于将其推进到 Stable 分类。但我们不承诺向后兼容性。Prototype (原型): 这些功能通常不作为 PyPI 或 Conda 等二进制发行版的一部分提供,除非有时通过运行时标志启用,并且处于早期阶段,用于收集反馈和进行测试。

  • 增量学习过程使用test.py进行增量测试,每个session使用少量新样本进行微调,采用双边蒸馏策略保持旧知识并学习新知识。增量学习参数:--ft_iters: 微调迭代次数,--ft_lr: 微调学习率,--w_d: 知识蒸馏损失权重;--bilateral: 是否使用双边分支;--BC_hidden_dim: 双分支网络隐藏层维度;--w_BC_binary: 双分支二分类损失权重。每个session加载之前训练好的模型,使用新类别的少量样本和历史类别的exemplar样本进行微调,采用知识蒸馏保持旧知识不被遗忘,双边分支结构增强模型泛化能力。

shell 复制代码
python train.py --dataset mydataset --dataroot /path/to/your/data/root --exp_dir experiment --epoch 200 --batch_size 256 --init_lr 0.1 --milestones 120 160 --val_start 100 --change_val_interval 160
  • --dataset: 数据集名称,--dataroot: 你的数据集根目录路径,--base_class: 基础类别数,--num_classes: 总类别数,--n_way: 每个增量session的新类别数,--n_shot: 每个新类别的样本数,--n_sessions: 增量session数。

yaml 复制代码
args.base_class = 100      # 基础类别数(你的数据集为100)
args.num_classes = 150     # 总类别数(根据你的数据集总类别数调整)
args.n_way = 5             # 每次增量学习的新类别数
args.n_shot = 5            # 每个新类别的样本数(可以设为最大值)
args.n_sessions = 10       # 增量学习的总轮数(根据你的总类别数计算)
  • 微调相关参数: --needs_finetune:是否需要微调;--ft_iters 100: 微调迭代次数;--ft_lr 0.01:微调学习率;--ft_T 4.0:知识蒸馏温度参数;损失函数权重:--w_d 100 :知识蒸馏损失权重;--w_BC_binary 50:二分类损失权重;--EMA_logits:使用指数移动平均logits;双边分支参数:--bilateral : 使用双边分支;--BC_hidden_dim 64 :双分支网络隐藏层维度;--BC_lr 0.01:双分支学习率

yaml 复制代码
--dataroot /path/to/data/           # 数据集根目录
--dataset mydataset                 # 数据集名称
--method imprint                    # 分类方法,imprint表示使用印记学习
--base_mode avg_cos                 # 基础模式,avg_cos表示使用平均特征作为类别权重
--norm_first                        # 是否先进行特征归一化
--exp_dir experiment                # 实验结果保存目录
# 训练参数 (Training Arguments)
--epoch 200                         # 训练轮数
--batch_size 8                      # 训练批次大小
--batch_size_new 0                  # 增量学习时新类别的批次大小,0表示使用全部数据
--batch_size_test 100               # 测试批次大小
--init_lr 0.1                       # 初始学习率
--schedule Milestone                # 学习率调度策略
--milestones 120 160                # 学习率衰减的里程碑
--step 40                           # Step调度策略的步长
--gamma 0.1                         # 学习率衰减因子
--val_start 100                     # 开始验证的轮数
--val_interval 5                    # 验证间隔
--change_val_interval 160           # 改变验证间隔的轮数
--num_workers 8                     # 数据加载时的线程数
--report_binary                     # 是否报告二分类结果
# 微调参数 (Finetuning Arguments)
--needs_finetune                    # 是否需要微调
--using_exemplars                   # 是否使用样本集
--save_all_data_base                # 是否保存所有基础类别的数据作为样本
--save_all_data_novel               # 是否保存所有新类别的数据作为样本
--imprint_ft_weight                 # 是否使用印记学习初始化微调权重
--bn_eval                           # 是否固定BN层的均值和方差
--part_frozen                       # 是否只更新最后的残差块
--ft_optimizer SGD                  # 微调优化器
--ft_lr 0.01                        # 微调学习率
--ft_factor 0.1                     # 骨干网络学习率因子
--ft_iters 100                      # 微调迭代次数
--ft_n_repeat 1                     # 构造批次时重复采样的次数
--ft_T 4.0                          # 知识蒸馏温度
--ft_teacher fixed                  # 教师模型选择:fixed, prev, ema
--ft_momentum 0.9                   # EMA动量因子
--ft_momentum_type 1                # EMA动量类型
--ft_KD_all                         # 是否蒸馏所有logits还是仅之前的logits
--ft_reinit                         # 每个session开始时是否重新初始化学生模型
--w_cls 1                           # 新类别交叉熵损失权重
--w_e 5                             # 额外损失函数权重
--w_d 100                           # 知识蒸馏损失权重
--w_l 0                             # L1/L2归一化损失权重
--w_l_order 1                       # L1/L2归一化损失阶数
--margin 0                          # 使用margin-based softmax损失函数
--triplet                           # 是否使用triplet损失
--triplet_gap 0                     # triplet损失间隔
--KD_rectified                      # 是否修正教师模型的logits
--KD_rectified_factor 0.8           # KD修正因子
--weighted_kd                       # 是否对不同类别使用不同的蒸馏权重
--w_kd_novel 1.0                    # 新类别的蒸馏权重
--vis_exemplars                     # 微调后是否保存和可视化样本
--vis_exemplars_nrow 10             # 网格显示中每行的样本数
--vis_logits                        # 微调后是否保存logits
--logits_tag saved_logits           # logits保存标签
# EMA Logits 参数
--EMA_logits                        # 是否在教师模型logits中使用指数移动平均
--EMA_prob                          # 是否对softmax分布而不是logits使用EMA
--EMA_type learnable_mlp_b          # EMA类型:linear, window, linear_t, learnable_mpl_b
--EMA_w_size 3                      # EMA_type=window时的窗口大小
--EMA_scalar 0                      # EMA_type=learnable时的可学习参数初始化
--EMA_scalar_lr 0.01                # EMA_type=learnable时的可学习参数学习率
--EMA_factor_b_1 1.0                # EMA_type=linear/window/linear_t时基础logits起始因子
--EMA_factor_b_2 1.0                # EMA_type=linear/window/linear_t时基础logits结束因子
--EMA_factor_n_1 0.5                # EMA_type=linear/window/linear_t时新logits起始因子
--EMA_factor_n_2 0.5                # EMA_type=linear/window/linear_t时新logits结束因子
--EMA_top_k 1                       # EMA_type=learnable_s时计算基础相似度的topk
--EMA_FC_dim 64                     # EMA_type=learnable_mpl_v/c/b时的隐藏维度
--EMA_FC_lr 0.01                    # EMA_type=learnable_s/mpl_v/c时的学习率
--EMA_FC_K 1                        # EMA_type=learnable_s时的K*x+b中的K
--EMA_FC_b 1                        # EMA_type=learnable_s时的K*x+b中的b
--EMA_s_type 0                      # EMA_type=learnable_s时的类型
--EMA_reinit                        # EMA_type=learnable_s时是否重新初始化K和b
# 双分支参数 (Bilateral-branch Arguments)
--bilateral                         # 测试时是否使用双分支
--report_binary                     # 是否报告二分类结果
--main_branch current               # 主分支:current, ema
--second_branch fixed               # 辅助分支:fixed, ema
--merge_strategy attn               # 两个分支的融合策略
--branch_selector logits_current    # 分支选择器
--masking_novel                     # 是否遮蔽新类别
--branch_weights 0.5                # 分支权重
--BC_hidden_dim 64                  # 二分类网络隐藏维度
--BC_lr 0.01                        # 二分类网络学习率
--BC_flatten org                    # 特征展平方式
--BC_detach                         # 是否分离二分类网络梯度
--BC_detach_f                       # 是否分离特征梯度
--BC_binary_factor 1.0              # 二分类因子
--w_BC_cls 5                        # 二分类交叉熵损失权重
--w_BC_binary 50                    # 二分类损失权重
  • best_model_avg_cos.tar文件是一个PyTorch模型保存文件,通常包含以下内容:模型权重 :训练好的模型参数;优化器状态 :优化器的状态信息(如果保存了的话);训练轮数 :训练的epoch数;其他元数据:可能包括学习率调度器状态等。要进行增量学习,你需要使用测试脚本加载这个 best_model_avg_cos.tar 预训练模型,然后进行增量学习。

    python 复制代码
    # 加载预训练模型
    model = CosClassifier(args, phase='meta_test')
    tmp = torch.load(args.model_path)
    model.load_state_dict(tmp['state'], strict=False)
    model.cuda()
    • 准备增量学习数据,确保你已经创建了session_1.txt文件,其中包含5个新类别的名称。

    • 可以使用my_test.py脚本进行增量学习:

    shell 复制代码
    python my_test.py \
      --dataset mydataset \
      --dataroot /path/to/your/data \
      --exp_dir experiment \
      --load_tag avg_cos \
      --needs_finetune \
      --ft_iters 50 \
      --ft_lr 0.001 \
      --ft_T 4.0 \
      --w_d 50 \
      --bilateral \
      --BC_hidden_dim 64 \
      --BC_lr 0.01 \
      --w_BC_binary 30
    • 执行增量学习

    python 复制代码
    # 在所有增量session上进行评估
    acc_list = model_test.test_inc_loop()
    • test_inc_loop()方法会依次处理每个session,包括:加载对应session的数据;为新类别生成权重(imprint);如果需要微调,则进行微调;评估模型性能。如果你只想对特定session进行增量学习,可以修改测试脚本中的循环范围。例如,只对session 2进行增量学习:

    python 复制代码
    # 在test_inc_loop方法中修改循环范围
    for session in range(0, 3):  # 只处理session 0, 1, 2
  • 在增量学习过程中,基础模型的100个类别权重已经保存在best_model_avg_cos.tar文件中,不需要再从基础训练数据目录中读取这些类别的数据。你只需要为新类别提供数据即可。算法会使用知识蒸馏等技术来保持对基础类别的识别能力。

相关推荐
可触的未来,发芽的智生3 小时前
新奇特:负权重橡皮擦,让神经网络学会主动遗忘
人工智能·python·神经网络·算法·架构
咖啡Beans3 小时前
Python常用系统自带库之json解析
python
付玉祥3 小时前
第 6 章 异常处理与文件操作
python
森诺Alyson3 小时前
前沿技术借鉴研讨-2025.9.23 (数据不平衡)
论文阅读·人工智能·经验分享·深度学习·论文笔记
AI原吾3 小时前
ClaudeCode真经第二章:核心功能深度解析
python·ai编程·claudecode
东方芷兰4 小时前
LLM 笔记 —— 03 大语言模型安全性评定
人工智能·笔记·python·语言模型·自然语言处理·nlp·gpt-3
MediaTea4 小时前
Python 库手册:keyword 关键字查询
开发语言·python
java1234_小锋4 小时前
Scikit-learn Python机器学习 - 模型保存及加载
python·机器学习·scikit-learn
睿思达DBA_WGX4 小时前
使用 python-docx 库操作 word 文档(1):文件操作
开发语言·python·word