猫狗识别(PyTorch)

任务介绍

数据结构为:

big_data

├── train

│ └── cat

│ └── XXX.jpg(每个文件夹含若干张图像)

│ └── dog

│ └── XXX.jpg(每个文件夹含若干张图像)

├── val

│ └── cat

│ └── XXX.jpg(每个文件夹含若干张图像)

│ └── dog

└── ─── └── XXX.jpg(每个文件夹含若干张图像)

需要对train数据集进行训练,达到给定val数据集中的一张猫 / 狗的图片,识别其是猫还是狗

数据预处理

引入头文件

python 复制代码
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import numpy as np

数据读取与预处理

在这里使用torch.utils.data.DataLoader()方法,把训练数据分成多个小组,此函数每次抛出一组数据,直至把所有的数据都抛出

常见参数介绍如下(借鉴自:传送门):

1、dataset:(数据类型 dataset)

输入的数据类型,这里是原始数据的输入,PyTorch内也有这种数据结构

2、batch_size:(数据类型 int)

批训练数据量的大小,根据具体情况设置即可(默认:1)

PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的,这里就是定义每次喂给神经网络多少行数据,其每次随机读取大小为batch_size,如果dataset中的数据个数不是batch_size的整数倍,则最后一次把剩余的数据全部输出,若想把剩下的不足batch size个的数据丢弃,则将drop_last设置为True,会将多出来不足一个batch的数据丢弃

3、shuffle:(数据类型 bool)

洗牌,默认设置为False

在每次迭代训练时是否将数据洗牌,将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了

4、collate_fn:(数据类型 callable)

将一小段数据合并成数据列表,默认设置是False

如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中

5、batch_sampler:(数据类型 Sampler)

批量采样,默认设置为None

但每次返回的是一批数据的索引(注意:不是数据),其和batch_size、shuffle 、sampler and drop_last参数是不兼容的

6、sampler:(数据类型 Sampler)

采样,默认设置为None

根据定义的策略从数据集中采样输入,如果定义采样规则,则洗牌(shuffle)设置必须为False。

7、drop_last:(数据类型 bool)

丢弃最后数据,默认为False

设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些,选择这你是否需要丢弃这批数据

代码:

python 复制代码
train_datadir = './big_data/train/'
test_datadir  = './big_data/val/'

# https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-transform/
# https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
# https://pytorch.org/vision/stable/transforms.html
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.RandomRotation(degrees=(-10, 10)),  #随机旋转,-10到10度之间随机选
    # transforms.RandomHorizontalFlip(p=0.5),  #随机水平翻转 选择一个概率概率
    # transforms.RandomVerticalFlip(p=0.5),  #随机垂直翻转(效果可能会变差)
    transforms.RandomPerspective(distortion_scale=0.6, p=1.0), # 随机视角
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),  #随机选择的高斯模糊模糊图像
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的
])

test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

train_data = datasets.ImageFolder(train_datadir,transform=train_transforms)

test_data  = datasets.ImageFolder(test_datadir,transform=test_transforms)

train_loader = torch.utils.data.DataLoader(train_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=1)
test_loader  = torch.utils.data.DataLoader(test_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=1)

查看每组数据的格式:

python 复制代码
for X, y in test_loader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

输出:

python 复制代码
Shape of X [N, C, H, W]:  torch.Size([4, 3, 224, 224])
Shape of y:  torch.Size([4]) torch.int64

图像可视化

python 复制代码
def im_convert(tensor):
    """ 展示数据"""
#     tensor.clone()  返回tensor的拷贝,返回的新tensor和原来的tensor具有同样的大小和数据类型
#     tensor.detach() 从计算图中脱离出来。
    image = tensor.to("cpu").clone().detach()
    
#     numpy.squeeze()这个函数的作用是去掉矩阵里维度为1的维度
    image = image.numpy().squeeze()
    
#     将npimg的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),
#     进行格式的转换后方可进行显示
    image = image.transpose(1,2,0)
    
#     和标准差操作正好相反即可
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    
#     使用image.clip(0, 1) 将数据 限制在0和1之间
    image = image.clip(0, 1)

    return image

fig=plt.figure(figsize=(20, 20))
columns = 2
rows = 2

dataiter = iter(train_loader)
inputs, classes = dataiter.next()

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    if classes[idx] == 0:
        ax.set_title("cat", fontsize = 35)
    else:
        ax.set_title("dog", fontsize = 35)
    plt.imshow(im_convert(inputs[idx]))
plt.savefig('pic1.jpg', dpi=600) #指定分辨率保存
plt.show()

输出:

建立模型

python 复制代码
import torch.nn.functional as F

# 找到可以用于训练的 GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# 定义模型
class LeNet(nn.Module):
    # 一般在__init__中定义网络需要的操作算子,比如卷积、全连接算子等等
    def __init__(self):
        super(LeNet, self).__init__()
        # Conv2d的第一个参数是输入的channel数量,第二个是输出的channel数量,第三个是kernel size
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 由于上一层有16个channel输出,每个feature map大小为53*53,所以全连接层的输入是16*53*53
        self.fc1 = nn.Linear(16*53*53, 120)
        self.fc2 = nn.Linear(120, 84)
        # 最终有10类,所以最后一个全连接层输出数量是10
        self.fc3 = nn.Linear(84, 2)
        self.pool = nn.MaxPool2d(2, 2)
    # forward这个函数定义了前向传播的运算,只需要像写普通的python算数运算那样就可以了
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        # 下面这步把二维特征图变为一维,这样全连接层才能处理
        x = x.view(-1, 16*53*53)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device)
print(model)

输出:

python 复制代码
Using cuda device
LeNet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=44944, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=2, bias=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

损失函数与优化器:

python 复制代码
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

定义训练函数

python 复制代码
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

定义测试函数

python 复制代码
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

进行训练

在单个训练循环中,模型对训练数据集进行预测(分批提供给它),并反向传播预测误差从而调整模型的参数

python 复制代码
epochs = 20
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_loader, model, loss_fn, optimizer)
    test(test_loader, model, loss_fn)
print("Done!")

输出:

python 复制代码
Epoch 1
-------------------------------
loss: 0.697169  [    0/ 2819]
loss: 0.717117  [  400/ 2819]
loss: 0.692842  [  800/ 2819]
loss: 0.663078  [ 1200/ 2819]
loss: 0.687370  [ 1600/ 2819]
loss: 0.727401  [ 2000/ 2819]
loss: 0.715081  [ 2400/ 2819]
loss: 0.640342  [ 2800/ 2819]
Test Error: 
 Accuracy: 71.1%, Avg loss: 0.662693 

Epoch 2
-------------------------------
loss: 0.603678  [    0/ 2819]
loss: 0.574420  [  400/ 2819]
loss: 0.609898  [  800/ 2819]
loss: 0.756877  [ 1200/ 2819]
loss: 0.612366  [ 1600/ 2819]
loss: 0.538443  [ 2000/ 2819]
loss: 0.406403  [ 2400/ 2819]
loss: 0.762628  [ 2800/ 2819]
Test Error: 
 Accuracy: 73.8%, Avg loss: 0.603772 

Epoch 3
-------------------------------
loss: 0.532889  [    0/ 2819]
loss: 0.627735  [  400/ 2819]
loss: 0.573161  [  800/ 2819]
loss: 0.369012  [ 1200/ 2819]
loss: 0.450293  [ 1600/ 2819]
loss: 0.562393  [ 2000/ 2819]
loss: 0.436040  [ 2400/ 2819]
loss: 0.464255  [ 2800/ 2819]
Test Error: 
 Accuracy: 63.2%, Avg loss: 0.625167 

Epoch 4
-------------------------------
loss: 0.561073  [    0/ 2819]
loss: 0.684834  [  400/ 2819]
loss: 0.419241  [  800/ 2819]
loss: 0.829945  [ 1200/ 2819]
loss: 0.486006  [ 1600/ 2819]
loss: 0.258347  [ 2000/ 2819]
loss: 0.322795  [ 2400/ 2819]
loss: 0.767529  [ 2800/ 2819]
Test Error: 
 Accuracy: 72.6%, Avg loss: 0.572731 

Epoch 5
-------------------------------
loss: 0.542559  [    0/ 2819]
loss: 0.662963  [  400/ 2819]
loss: 0.488824  [  800/ 2819]
loss: 0.542599  [ 1200/ 2819]
loss: 0.584228  [ 1600/ 2819]
loss: 0.807718  [ 2000/ 2819]
loss: 1.087557  [ 2400/ 2819]
loss: 0.384040  [ 2800/ 2819]
Test Error: 
 Accuracy: 65.2%, Avg loss: 0.610963 

Epoch 6
-------------------------------
loss: 0.372240  [    0/ 2819]
loss: 0.391429  [  400/ 2819]
loss: 0.355201  [  800/ 2819]
loss: 0.504742  [ 1200/ 2819]
loss: 0.190237  [ 1600/ 2819]
loss: 0.808446  [ 2000/ 2819]
loss: 0.189117  [ 2400/ 2819]
loss: 0.520030  [ 2800/ 2819]
Test Error: 
 Accuracy: 78.8%, Avg loss: 0.481039 

Epoch 7
-------------------------------
loss: 0.849491  [    0/ 2819]
loss: 0.462627  [  400/ 2819]
loss: 0.917399  [  800/ 2819]
loss: 1.132419  [ 1200/ 2819]
loss: 0.308156  [ 1600/ 2819]
loss: 0.763916  [ 2000/ 2819]
loss: 0.360583  [ 2400/ 2819]
loss: 0.164551  [ 2800/ 2819]
Test Error: 
 Accuracy: 79.0%, Avg loss: 0.454252 

Epoch 8
-------------------------------
loss: 0.255556  [    0/ 2819]
loss: 0.510223  [  400/ 2819]
loss: 0.271078  [  800/ 2819]
loss: 0.164843  [ 1200/ 2819]
loss: 0.536149  [ 1600/ 2819]
loss: 0.621327  [ 2000/ 2819]
loss: 0.644994  [ 2400/ 2819]
loss: 0.227701  [ 2800/ 2819]
Test Error: 
 Accuracy: 82.3%, Avg loss: 0.397003 

Epoch 9
-------------------------------
loss: 1.072562  [    0/ 2819]
loss: 0.368804  [  400/ 2819]
loss: 0.297367  [  800/ 2819]
loss: 0.511167  [ 1200/ 2819]
loss: 0.244405  [ 1600/ 2819]
loss: 0.233891  [ 2000/ 2819]
loss: 0.174815  [ 2400/ 2819]
loss: 0.492107  [ 2800/ 2819]
Test Error: 
 Accuracy: 81.8%, Avg loss: 0.405467 

Epoch 10
-------------------------------
loss: 0.676339  [    0/ 2819]
loss: 0.423010  [  400/ 2819]
loss: 0.472313  [  800/ 2819]
loss: 0.124012  [ 1200/ 2819]
loss: 0.132490  [ 1600/ 2819]
loss: 0.374766  [ 2000/ 2819]
loss: 0.202931  [ 2400/ 2819]
loss: 0.639156  [ 2800/ 2819]
Test Error: 
 Accuracy: 86.4%, Avg loss: 0.352187 

Epoch 11
-------------------------------
loss: 0.164241  [    0/ 2819]
loss: 0.402599  [  400/ 2819]
loss: 0.075091  [  800/ 2819]
loss: 0.253864  [ 1200/ 2819]
loss: 0.227414  [ 1600/ 2819]
loss: 0.188128  [ 2000/ 2819]
loss: 0.437947  [ 2400/ 2819]
loss: 0.231940  [ 2800/ 2819]
Test Error: 
 Accuracy: 86.7%, Avg loss: 0.354747 

Epoch 12
-------------------------------
loss: 0.411373  [    0/ 2819]
loss: 0.596428  [  400/ 2819]
loss: 0.419576  [  800/ 2819]
loss: 0.983684  [ 1200/ 2819]
loss: 0.713979  [ 1600/ 2819]
loss: 0.491828  [ 2000/ 2819]
loss: 0.196907  [ 2400/ 2819]
loss: 0.087960  [ 2800/ 2819]
Test Error: 
 Accuracy: 86.4%, Avg loss: 0.325611 

Epoch 13
-------------------------------
loss: 0.368461  [    0/ 2819]
loss: 0.276991  [  400/ 2819]
loss: 0.715205  [  800/ 2819]
loss: 0.151266  [ 1200/ 2819]
loss: 0.474812  [ 1600/ 2819]
loss: 0.868296  [ 2000/ 2819]
loss: 0.097645  [ 2400/ 2819]
loss: 0.232329  [ 2800/ 2819]
Test Error: 
 Accuracy: 89.3%, Avg loss: 0.286453 

Epoch 14
-------------------------------
loss: 0.297342  [    0/ 2819]
loss: 0.082402  [  400/ 2819]
loss: 0.270308  [  800/ 2819]
loss: 0.064577  [ 1200/ 2819]
loss: 0.175540  [ 1600/ 2819]
loss: 0.373381  [ 2000/ 2819]
loss: 0.661588  [ 2400/ 2819]
loss: 0.223183  [ 2800/ 2819]
Test Error: 
 Accuracy: 86.6%, Avg loss: 0.303336 

Epoch 15
-------------------------------
loss: 0.435931  [    0/ 2819]
loss: 0.115896  [  400/ 2819]
loss: 0.582083  [  800/ 2819]
loss: 0.094237  [ 1200/ 2819]
loss: 0.191783  [ 1600/ 2819]
loss: 0.229429  [ 2000/ 2819]
loss: 0.657790  [ 2400/ 2819]
loss: 0.540002  [ 2800/ 2819]
Test Error: 
 Accuracy: 89.7%, Avg loss: 0.263210 

Epoch 16
-------------------------------
loss: 0.054571  [    0/ 2819]
loss: 0.298561  [  400/ 2819]
loss: 0.043506  [  800/ 2819]
loss: 0.380746  [ 1200/ 2819]
loss: 0.167052  [ 1600/ 2819]
loss: 0.391230  [ 2000/ 2819]
loss: 0.057670  [ 2400/ 2819]
loss: 0.131060  [ 2800/ 2819]
Test Error: 
 Accuracy: 88.3%, Avg loss: 0.279989 

Epoch 17
-------------------------------
loss: 0.064786  [    0/ 2819]
loss: 0.745590  [  400/ 2819]
loss: 0.211468  [  800/ 2819]
loss: 0.314401  [ 1200/ 2819]
loss: 0.139008  [ 1600/ 2819]
loss: 0.461783  [ 2000/ 2819]
loss: 0.654826  [ 2400/ 2819]
loss: 0.506236  [ 2800/ 2819]
Test Error: 
 Accuracy: 90.7%, Avg loss: 0.226764 

Epoch 18
-------------------------------
loss: 0.643280  [    0/ 2819]
loss: 0.049473  [  400/ 2819]
loss: 0.027292  [  800/ 2819]
loss: 0.074732  [ 1200/ 2819]
loss: 0.054178  [ 1600/ 2819]
loss: 0.484333  [ 2000/ 2819]
loss: 0.103221  [ 2400/ 2819]
loss: 0.521000  [ 2800/ 2819]
Test Error: 
 Accuracy: 90.4%, Avg loss: 0.238688 

Epoch 19
-------------------------------
loss: 0.045256  [    0/ 2819]
loss: 0.360988  [  400/ 2819]
loss: 0.541000  [  800/ 2819]
loss: 0.484565  [ 1200/ 2819]
loss: 0.060333  [ 1600/ 2819]
loss: 0.108621  [ 2000/ 2819]
loss: 0.137073  [ 2400/ 2819]
loss: 0.015121  [ 2800/ 2819]
Test Error: 
 Accuracy: 91.0%, Avg loss: 0.225625 

Epoch 20
-------------------------------
loss: 0.140982  [    0/ 2819]
loss: 0.092246  [  400/ 2819]
loss: 0.072535  [  800/ 2819]
loss: 0.215292  [ 1200/ 2819]
loss: 0.099996  [ 1600/ 2819]
loss: 0.126051  [ 2000/ 2819]
loss: 0.009734  [ 2400/ 2819]
loss: 0.050575  [ 2800/ 2819]
Test Error: 
 Accuracy: 89.7%, Avg loss: 0.244002 

Done!

保存完整模型

python 复制代码
torch.save(model, "model.pth")
# 读取
# model = torch.load("model.pth")
相关推荐
code04号2 小时前
df = pd.DataFrame(data)中的data可以是什么类型的数据?
python·pandas
机械&编程攻城狮(好哥)4 小时前
伺服电机控制驱动器选择
python·modbus tcp·驱动器·伺服电机驱动·canable·modbus rtu
Algorithm_Engineer_4 小时前
机器学习中常用的降维方法-主成分分析法(PCA)
python·机器学习
最爱番茄味4 小时前
Python之字符串基础篇
python
西岭千秋雪_7 小时前
设计模式の装饰者&组合&外观模式
java·python·设计模式·组合模式·装饰器模式·外观模式
爱写代码的小朋友7 小时前
Python模块导入:import与from...import的深度解析
python
最强菜鸟8 小时前
python爬虫爬取淘宝热销(热门)台式电脑商品信息(课程设计;提供源码、使用说明文档及相关文档;)
爬虫·python·课程设计·淘宝·drissionpage·电脑数据
乐茵安全9 小时前
基于python绘制数据表(上)
java·前端·python
程序员大金9 小时前
基于python+django+vue的高校成绩管理系统
vue.js·python·django