【动手学深度学习】(十四)数据增广+微调

文章目录

一、数据增强

1.理论知识

  • 增加一个已有数据集,使得有更多的多样性
    • 在语言里面加入各种不同的背景噪音
    • 改变图片的颜色和形状

使用增强数据训练
翻转

  • 左右翻转
  • 上下翻转
    • 不总是可行

切割

  • 从图片中切割一块,然后变形到固定形状
    • 随机高宽比
    • 随机大小
    • 随机位置

颜色

  • 改变色调,饱和度,明亮度

总结

  • 数据增广通过变形数据来获取多样性从而使得模型泛化性能更好
  • 常见图片增广包括翻转、切割、变色

2.代码

1.读取图像

python 复制代码
%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l


d2l.set_figsize()
img = d2l.Image.open('../img/test.png')
d2l.plt.imshow(img);
python 复制代码
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    d2l.show_images(Y, num_rows, num_cols, scale=scale)

水平翻转

python 复制代码
apply(img, torchvision.transforms.RandomHorizontalFlip())
# 在水平方向进行随机翻转

上下翻转图像

python 复制代码
# 上下翻转图像
apply(img, torchvision.transforms.RandomVerticalFlip())

随机裁剪

python 复制代码
shape_aug = torchvision.transforms.RandomResizedCrop(
    (200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

随机更改图片亮度

python 复制代码
apply(img, torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0, saturation=0, hue=0))

随机更改图片的色调,亮度(brightness)对比度(contrast)饱和度(saturation)色调(hue)

python 复制代码
# 随机更改图片的色调,亮度(brightness)对比度(contrast)饱和度(saturation)色调(hue)
color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

结合多种图像增广方法

python 复制代码
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    color_aug, shape_aug])
apply(img, augs)
python 复制代码
all_images = torchvision.datasets.CIFAR10(
    train=True, root="../data", download=True)
d2l.show_images([
    all_images[i][0] for i in range(32)], 4, 8, scale=0.8);
python 复制代码
# 只使用最简单的随机左右翻转
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()])
python 复制代码
# 定义一个辅助函数,以便于读取图像和应用图像增广
def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(
        root="../data", train=is_train,
        transform=augs, download=True)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=is_train,
        num_workers=0)
    return dataloader

二、微调

1.理论知识

标注一个数据集很贵
网络架构

  • 一个神经网络一般可以分成两块
    • 特征抽取将原始像素变成容易线性分割的特征
    • 线性分类器来做分类

微调

微调中的权重初始化

相关推荐
Niuguangshuo4 分钟前
自编码器与变分自编码器:【2】自编码器的局限性
pytorch·深度学习·机器学习
摘星编程9 分钟前
CANN内存管理机制:从分配策略到性能优化
人工智能·华为·性能优化
likerhood16 分钟前
3. pytorch中数据集加载和处理
人工智能·pytorch·python
Robot侠17 分钟前
ROS1从入门到精通 10:URDF机器人建模(从零构建机器人模型)
人工智能·机器人·ros·机器人操作系统·urdf机器人建模
haiyu_y18 分钟前
Day 46 TensorBoard 使用介绍
人工智能·深度学习·神经网络
阿里云大数据AI技术23 分钟前
DataWorks 又又又升级了,这次我们通过 Arrow 列存格式让数据同步速度提升10倍!
大数据·人工智能
做科研的周师兄24 分钟前
中国土壤有机质数据集
人工智能·算法·机器学习·分类·数据挖掘
IT一氪25 分钟前
一款 AI 驱动的 Word 文档翻译工具
人工智能·word
lovingsoft28 分钟前
Vibe coding 氛围编程
人工智能
百***074533 分钟前
GPT-Image-1.5 极速接入全流程及关键要点
人工智能·gpt·计算机视觉