PyTorch深度学习进阶(四)(数据增广)

数据增广

对图片做不同处理,如去掉部分像素,对颜色变换,对亮度变换

一般是将不同的生成方法随机的用在数据上

总结

代码

基础操作

读取图片

python 复制代码
img = d2l.Image.open('01_Data/02_cat.jpg')

显示图片

python 复制代码
d2l.plt.imshow(img)

传入aug图片增广方法

python 复制代码
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5)

用aug方法对图片作用八次

python 复制代码
Y = [aug(img) for _ in range(num_rows * num_cols)]

生成结果用num_cols行,num_cols列展示

python 复制代码
d2l.show_images(Y, num_rows, num_cols, scale=scale) 

水平随机翻转

python 复制代码
apply(img, torchvision.transforms.RandomHorizontalFlip())

上下随机翻转

python 复制代码
apply(img, torchvision.transforms.RandomVerticalFlip())

随机剪裁,剪裁后的大小为(200,200)

(0.1,1)使得随即剪裁原始图片的10%到100%区域里的大小,ratio=(0.5,2)使得高宽比为2:1,下面是显示时显示的1:1

python 复制代码
shape_aug = torchvision.transforms.RandomResizedCrop((200,200),scale=(0.1,1),ratio=(0.5,2))     
apply(img,shape_aug)

随即更改图像的亮度

python 复制代码
apply(img,torchvision.transforms.ColorJitter(brightness=0.5,contrast=0,saturation=0,hue=0))

随即改变色调

python 复制代码
apply(img,torchvision.transforms.ColorJitter(brightness=0,contrast=0,saturation=0,hue=0.5))

随机更改图像的亮度(brightness)、对比度(constrast)、饱和度(saturation)和色调(hue)

python 复制代码
color_aug = torchvision.transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5)
apply(img,color_aug)

结合多种图像增广方法

先随即水平翻转,再做颜色增广,再做形状增广

python 复制代码
augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),color_aug,shape_aug])   
apply(img,augs)

训练

下载图片,并显示部分图片

python 复制代码
all_images = torchvision.datasets.CIFAR10(train=True, root='01_Data/03_CIFAR10', download=True)    
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8)

只使用最简单的随机左右翻转

python 复制代码
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()])  

定义一个辅助函数,以便于读取图像和应用图像增广

python 复制代码
def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root='01_Data/03_CIFAR10',train=is_train,
                                         transform=augs, download=True)
    dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=is_train,
                                            num_workers = 0)   
    return dataloader

定义一个函数,使用多GPU模式进行训练和评估

python 复制代码
def train_batch_ch13(net, X, y, loss, trainer, devices):

如果X是一个list,则把数据一个接一个都挪到devices[0]上

python 复制代码
if isinstance(X, list):
    X = [x.to(devices[0]) for x in X]

训练一个batch

如果X不是一个list,则把X挪到devices[0]上

python 复制代码
else:
    X = X.to(devices[0])
python 复制代码
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred, y)
    return train_loss_sum, train_acc_sum
python 复制代码
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0,1],
                           legend=['train loss', 'train acc', 'test acc'])
    # nn.DataParallel使用多GPU
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(net,features,labels,loss,trainer,devices)   
            metric.add(l,acc,labels.shape[0],labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches -1:
                animator.add(
                    epoch + (i + 1) / num_batches,
                    (metric[0] / metric[2], metric[1] / metric[3], None))              
        test_acc = d2l.evaluate_accuracy_gpu(net,test_iter)
        animator.add(epoch+1,(None,None,test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc'
         f' {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f' {metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
         f' {str(devices)}') 

定义train_with_data_aug函数,使用图像增广来训练模型

python 复制代码
batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10,3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)
        
net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    # Adam优化器算是一个比较平滑的SGD,它对学习率调参不是很敏感
    trainer = torch.optim.Adam(net.parameters(),lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
    
train_with_data_aug(train_augs, test_augs, net)

结果

相关推荐
强盛小灵通专卖员1 小时前
基于RT-DETR的电力设备过热故障红外图像检测
人工智能·目标检测·sci·研究生·小论文·大论文·延毕
倔强青铜三1 小时前
AI编程革命:React + shadcn/ui 将终结前端框架之战
前端·人工智能·ai编程
喵个咪1 小时前
基于 Go-Kratos 与 MCP 的推荐服务实战指南
后端·深度学习·微服务
sali-tec1 小时前
C# 基于halcon的视觉工作流-章62 点云采样
开发语言·图像处理·人工智能·算法·计算机视觉
EAIReport1 小时前
通过数据分析自动化产品实现AI生成PPT的完整流程
人工智能·数据分析·自动化
swanwei1 小时前
量子科技对核心产业的颠覆性影响及落地时间表(全文2500字)
大数据·网络·人工智能·程序人生·量子计算
AKAMAI2 小时前
从 Cloudflare 服务中断,看建立多维度风险应对机制的必要
人工智能·云原生·云计算
道可云2 小时前
道可云人工智能每日资讯|2025青岛虚拟现实创新大会即将举行
人工智能·vr
酷雷曼VR全景2 小时前
身边的变化丨从“尝鲜”到“刚需”,VR全景让生活“立体化”
人工智能·生活·vr·vr全景·酷雷曼·合作商