PyTorch应用实战二:实现卷积神经网络进行图像分类

文章目录

实验环境

python3.6 + pytorch1.8.0

python 复制代码
import torch
print(torch.__version__)
1.8.0

MNIST数据集

MNIST数字数据集是一组手写数字图像的数据集,用于机器学习中的图像分类任务。

该数据集包含60,000张训练图像和10,000张测试图像,每张图像都是28x28像素大小的灰度图像。每张图像都被标记为0到9中的一个数字。

该数据集是由美国国家标准与技术研究所(NIST)收集和创建,因此得名为MNIST(Modified National Institute of Standards and Technology)。它已成为机器学习领域中广泛使用的基准数据集之一。

1.网络结构

网络层 参数 输出尺寸

Input N×1×28×28

Conv1 ksize=5, C_out=4, pad=0, stride=2 N×4×12×12

ReLU N×4×12×12

Conv2 ksize=3, C_out=8, pad=0, stride=2 N×8×5×5

ReLU N×8×5×5

Flatten N×200

Linear num_out=10 N×10
输出图像的维度计算方法:(图像大小 - 卷积核大小)÷步长+1

2.程序实现

2.1 导入相关库

python 复制代码
import torch
from torch.nn import ReLU
from torch.nn.functional import conv2d, cross_entropy 
from torchvision import datasets, transforms

2.2 构建卷积神经网络模型

定义ReLU激活函数

python 复制代码
def relu(x):
    return torch.clamp(x, min=0)

torch.clamp函数是一个张量操作函数,在PyTorch中实现。该函数的作用是将输入的张量每个元素都限制在指定的区间内,返回一个新的张量。

可以通过指定min或max参数的值来实现对张量元素的限制,也可以同时指定同时限制上下限。若min和max均不指定,则默认为min=0,max=1。

定义线性函数

python 复制代码
def linear(x, weight, bias):
    out = torch.matmul(x, weight) + bias.view(1, -1)
    return out

定义神经网络模型

conv2d的五大参数:x, w, b, stride, pad

python 复制代码
def model(x, params):
    # 卷积层1
    # w=(4,1,5,5),输出通道为4,输入通道为1,卷积核为(5,5)
    x = conv2d(x, params[0], params[1], 2, 0)
    # (N×4×12×12)
    x = relu(x)
    # 卷积层2
    # w=(8,4,3,3),输出通道为8,输入通道为4,卷积核为(3,3)
    x = conv2d(x, params[2], params[3], 2, 0)
    x = relu(x)
    # (N×8×5×5)
    x = x.view(-1, 200)
    # 全连接层
    x = linear(x, params[4], params[5])
    return x

该程序定义了一个卷积神经网络模型,输入参数包括待处理的数据x和模型参数params。程序通过使用卷积层、ReLU激活函数和全连接层构建了一个简单的卷积神经网络模型。其中,程序使用了两个卷积层和一个全连接层。具体的模型参数包括两个卷积层的卷积核权重、偏置、以及一个全连接层的权重和偏置。程序返回模型的预测结果x。

初始化模型

python 复制代码
init_std = 0.1
params = [
    torch.randn(4, 1, 5, 5) * init_std,
    torch.zeros(4),
    torch.randn(8, 4, 3, 3) * init_std,
    torch.zeros(8),
    torch.randn(200, 10) * init_std,
    torch.zeros(10)
]
for p in params:
    p.requires_grad=True   #自动微分
python 复制代码
params
[tensor([[[[ 0.1062, -0.0596, -0.0730,  0.0613,  0.1273],
           [ 0.0751, -0.0852,  0.0648, -0.0774, -0.0355],
           [ 0.1565, -0.0221,  0.0574, -0.1055,  0.0350],
           [-0.1299,  0.0169,  0.0297,  0.1494, -0.0993],
           [ 0.0464,  0.0628, -0.0133, -0.0545,  0.1265]]],
 
 
         [[[ 0.0291,  0.1538, -0.0692, -0.0637,  0.0829],
           [ 0.0735, -0.0594, -0.1185, -0.0026, -0.0351],
           [ 0.0697,  0.1032, -0.1001, -0.0212, -0.0946],
           [ 0.0311,  0.1461,  0.0641, -0.0407,  0.1615],
           [-0.0517,  0.0298, -0.0482,  0.0984, -0.0602]]],
 
 
         [[[ 0.0102, -0.1541, -0.1040,  0.0335,  0.0115],
           [-0.1167,  0.1155,  0.0832,  0.0561,  0.0435],
           [ 0.0429, -0.1574,  0.0323, -0.1353, -0.1211],
           [-0.2472, -0.2379, -0.0963,  0.0105, -0.0845],
           [ 0.0059,  0.0433,  0.0111,  0.0422, -0.0131]]],
 
 
         [[[-0.0741,  0.1411,  0.0006,  0.1485,  0.1257],
           [ 0.0446,  0.0822, -0.0458,  0.1525,  0.0695],
           [ 0.0616,  0.1892,  0.1525, -0.0594,  0.1515],
           [-0.0490,  0.1179, -0.1175,  0.1448,  0.0811],
           [-0.0641, -0.0494, -0.0980, -0.1119,  0.0599]]]], requires_grad=True),
 tensor([0., 0., 0., 0.], requires_grad=True),
 tensor([[[[-0.1564,  0.1481,  0.0995],
           [ 0.0157,  0.0025, -0.0513],
           [ 0.1011, -0.0417, -0.1049]],
 
          [[-0.2052, -0.0514, -0.0995],
           [ 0.1915, -0.0586, -0.0985],
           [-0.1371, -0.2874, -0.0977]],
 
          [[-0.0367, -0.2326,  0.0306],
           [ 0.0193,  0.0762,  0.0243],
           [-0.1507, -0.1265, -0.1493]],
 
          [[ 0.0769, -0.1014,  0.0888],
           [-0.0632, -0.0782, -0.1765],
           [ 0.0521,  0.2349, -0.0833]]],
 
 
         [[[ 0.0041, -0.0487,  0.1597],
           [-0.0210, -0.1051,  0.0374],
           [ 0.1981,  0.1395,  0.0108]],
 
          [[ 0.0418,  0.1592,  0.0219],
           [ 0.1168,  0.0305, -0.0702],
           [ 0.2217, -0.0670, -0.0037]],
 
          [[ 0.0501, -0.2496,  0.0381],
           [-0.1487, -0.0202,  0.0236],
           [-0.0738,  0.0733, -0.1244]],
 
          [[-0.0825,  0.0158, -0.0877],
           [ 0.0337,  0.2011,  0.1339],
           [-0.1452, -0.1665,  0.0141]]],
 
 
         [[[-0.0779, -0.1749,  0.0731],
           [-0.0936,  0.0519,  0.1093],
           [ 0.1049,  0.0406, -0.0594]],
 
          [[ 0.1012,  0.0804, -0.0153],
           [ 0.0899, -0.0954, -0.0520],
           [ 0.0724, -0.0487, -0.0048]],
 
          [[-0.0814, -0.0918,  0.0481],
           [-0.0482,  0.1069, -0.1442],
           [-0.0863,  0.0290,  0.0701]],
 
          [[ 0.1068,  0.1104, -0.0105],
           [ 0.1059,  0.0294,  0.2377],
           [ 0.0855, -0.0029,  0.0322]]],
 
 
         [[[ 0.0467, -0.1335, -0.0698],
           [-0.0683,  0.0323, -0.0197],
           [ 0.1748,  0.1601,  0.0385]],
 
          [[ 0.0100, -0.0644,  0.0374],
           [ 0.2065, -0.0637, -0.1515],
           [-0.1963,  0.0413,  0.0476]],
 
          [[ 0.0600, -0.0431,  0.0280],
           [ 0.0428,  0.0220, -0.0793],
           [-0.0876,  0.0013, -0.0618]],
 
          [[ 0.3182,  0.0358,  0.1933],
           [-0.0536,  0.1208, -0.0318],
           [ 0.0144,  0.0707, -0.0400]]],
 
 
         [[[ 0.1354,  0.0365,  0.0468],
           [-0.0399,  0.0050,  0.0589],
           [ 0.0335,  0.0415, -0.0896]],
 
          [[-0.2733, -0.0715, -0.0500],
           [-0.0501, -0.0715, -0.0644],
           [ 0.0130,  0.0041, -0.0051]],
 
          [[-0.0304, -0.0240, -0.0137],
           [ 0.0931,  0.0814,  0.0466],
           [ 0.1731,  0.0708,  0.1007]],
 
          [[-0.2261,  0.0397, -0.1092],
           [-0.0359, -0.1310, -0.0156],
           [ 0.1561, -0.0511,  0.0229]]],
 
 
         [[[ 0.1716,  0.1097,  0.0117],
           [-0.1617,  0.2321,  0.1619],
           [ 0.0721,  0.0619,  0.0418]],
 
          [[ 0.0714, -0.0578, -0.0358],
           [-0.0079,  0.0858,  0.1151],
           [ 0.0559,  0.1615, -0.1431]],
 
          [[ 0.0542,  0.0115, -0.0027],
           [-0.0479, -0.0977, -0.0463],
           [-0.0527,  0.1211, -0.3093]],
 
          [[-0.0261, -0.0288, -0.0048],
           [ 0.1089,  0.1711,  0.1310],
           [ 0.0552, -0.0664, -0.0463]]],
 
 
         [[[-0.0420,  0.1513, -0.1023],
           [-0.1159, -0.1849, -0.0109],
           [-0.1040, -0.0087,  0.0621]],
 
          [[ 0.0287,  0.1017,  0.0153],
           [-0.0875, -0.0534, -0.0378],
           [ 0.0076, -0.0756, -0.3164]],
 
          [[ 0.1255,  0.0647, -0.0439],
           [ 0.0973, -0.1135, -0.1946],
           [ 0.0965,  0.0432,  0.1767]],
 
          [[ 0.0021,  0.0215,  0.0423],
           [-0.0698, -0.0198, -0.0033],
           [-0.0363, -0.0935,  0.0433]]],
 
 
         [[[-0.0605, -0.0037,  0.0548],
           [ 0.0856, -0.0815,  0.0877],
           [ 0.0714,  0.1070, -0.1506]],
 
          [[ 0.1159,  0.0648, -0.0505],
           [-0.2376,  0.0337,  0.1003],
           [ 0.1495,  0.0696,  0.0500]],
 
          [[ 0.1888,  0.1441,  0.0078],
           [ 0.1022, -0.0609, -0.2317],
           [-0.0881, -0.0983,  0.1033]],
 
          [[ 0.0118,  0.0604,  0.0070],
           [ 0.0215,  0.0990,  0.1829],
           [-0.0226,  0.1505,  0.0927]]]], requires_grad=True),
 tensor([0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True),
 tensor([[ 0.1092, -0.0588, -0.0678,  ..., -0.0369,  0.1425, -0.1710],
         [-0.2353, -0.1057,  0.0156,  ...,  0.0120,  0.0512,  0.1080],
         [-0.0902,  0.1036,  0.1558,  ..., -0.1726,  0.1594, -0.0046],
         ...,
         [-0.1067,  0.1090,  0.0657,  ...,  0.1041, -0.1314,  0.0274],
         [ 0.0735, -0.0332,  0.0949,  ...,  0.0044, -0.1386,  0.0113],
         [ 0.0212, -0.0620,  0.1167,  ...,  0.0424,  0.0393, -0.0940]],
        requires_grad=True),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)]

2.3 加载MNIST数据集

python 复制代码
train_batch_size = 100  #一共60000个数据,100个数据一组,一共600组(N=100)
test_batch_size = 100   #一共10000个数据,100个数据一组,一共100组(N=100)
# 训练集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data',train=True,download=True,  
        transform=transforms.Compose([
            transforms.ToTensor(),  #转为tensor
            transforms.Normalize((0.5),(0.5)) #正则化
        ])
    ),
    batch_size = train_batch_size, shuffle=True  #打乱位置
)
# 测试集
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        './data',train=False,download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5),(0.5))
        ])
    ),
    batch_size = test_batch_size, shuffle=False
)

这段代码主要是用来读取MNIST数据集并进行数据预处理,将其转换为可以在神经网络中使用的格式。其中train_batch_size和test_batch_size分别为训练集和测试集每批处理的数据大小。通过torch.utils.data.DataLoader函数,将MNIST数据集读入,并通过transforms.Compose函数对数据进行预处理,包括将其转换为tensor以及进行正则化处理。在训练集和测试集中,还通过shuffle参数打乱数据集的顺序,以增加数据集的多样性。最终将处理好的数据通过train_loader和test_loader返回。

2.4 训练模型

python 复制代码
alpha = 0.1  #学习率
epochs = 100  #训练次数
interval = 100  #打印间隔
for epoch in range(epochs):
    for i, (data, label) in enumerate(train_loader):
        output = model(data, params)
        loss = cross_entropy(output, label)  #交叉熵函数
        for p in params:   
            if p.grad is not None:  #如果梯度不为零
                p.grad.zero_()      #梯度置零
        loss.backward()   #反向求导
        
        for p in params:  #更新参数
            p.data = p.data - alpha * p.grad.data
        
        if i % interval == 0:
            print("Epoch %03d [%03d/%03d]\tLoss:%.4f"%(epoch, i, len(train_loader), loss.item()))
    correct_num = 0
    total_num = 0
    with torch.no_grad():
        for data, label in test_loader:
            output = model(data, params)
            pred = output.max(1)[1]
            correct_num += (pred == label).sum().item()
            total_num += len(data)
    acc = correct_num / total_num
    print('...Testing @ Epoch %03d\tAcc:%.4f'%(epoch, acc))
Epoch 000 [000/600]	Loss:2.3276
Epoch 000 [100/600]	Loss:0.4883
Epoch 000 [200/600]	Loss:0.2200
Epoch 000 [300/600]	Loss:0.1809
Epoch 000 [400/600]	Loss:0.1448
Epoch 000 [500/600]	Loss:0.2617
...Testing @ Epoch 000	Acc:0.9469
Epoch 001 [000/600]	Loss:0.1120
Epoch 001 [100/600]	Loss:0.1738
Epoch 001 [200/600]	Loss:0.1212
Epoch 001 [300/600]	Loss:0.1924
Epoch 001 [400/600]	Loss:0.0743
Epoch 001 [500/600]	Loss:0.3233
...Testing @ Epoch 001	Acc:0.9613
Epoch 002 [000/600]	Loss:0.0553
Epoch 002 [100/600]	Loss:0.2300
Epoch 002 [200/600]	Loss:0.0489
Epoch 002 [300/600]	Loss:0.0816
Epoch 002 [400/600]	Loss:0.1188
Epoch 002 [500/600]	Loss:0.1968
...Testing @ Epoch 002	Acc:0.9670
Epoch 003 [000/600]	Loss:0.2476
Epoch 003 [100/600]	Loss:0.1316
Epoch 003 [200/600]	Loss:0.2912
Epoch 003 [300/600]	Loss:0.0378
Epoch 003 [400/600]	Loss:0.1028
Epoch 003 [500/600]	Loss:0.1199
...Testing @ Epoch 003	Acc:0.9693
Epoch 004 [000/600]	Loss:0.0853
Epoch 004 [100/600]	Loss:0.1661
Epoch 004 [200/600]	Loss:0.1067
Epoch 004 [300/600]	Loss:0.0579
Epoch 004 [400/600]	Loss:0.2422
Epoch 004 [500/600]	Loss:0.0573
...Testing @ Epoch 004	Acc:0.9695
Epoch 005 [000/600]	Loss:0.0816
Epoch 005 [100/600]	Loss:0.0845
Epoch 005 [200/600]	Loss:0.0607
Epoch 005 [300/600]	Loss:0.0548
Epoch 005 [400/600]	Loss:0.1351
Epoch 005 [500/600]	Loss:0.0569
...Testing @ Epoch 005	Acc:0.9726
Epoch 006 [000/600]	Loss:0.0441
Epoch 006 [100/600]	Loss:0.0912
Epoch 006 [200/600]	Loss:0.1213
Epoch 006 [300/600]	Loss:0.0405
Epoch 006 [400/600]	Loss:0.0311
Epoch 006 [500/600]	Loss:0.0755
...Testing @ Epoch 006	Acc:0.9743
Epoch 007 [000/600]	Loss:0.0342
Epoch 007 [100/600]	Loss:0.0480
Epoch 007 [200/600]	Loss:0.1276
Epoch 007 [300/600]	Loss:0.1255
Epoch 007 [400/600]	Loss:0.0572
Epoch 007 [500/600]	Loss:0.0186
...Testing @ Epoch 007	Acc:0.9724
Epoch 008 [000/600]	Loss:0.0515
Epoch 008 [100/600]	Loss:0.0291
Epoch 008 [200/600]	Loss:0.0362
Epoch 008 [300/600]	Loss:0.0979
Epoch 008 [400/600]	Loss:0.0950
Epoch 008 [500/600]	Loss:0.0263
...Testing @ Epoch 008	Acc:0.9768
Epoch 009 [000/600]	Loss:0.0766
Epoch 009 [100/600]	Loss:0.1084
Epoch 009 [200/600]	Loss:0.0495
Epoch 009 [300/600]	Loss:0.0260
Epoch 009 [400/600]	Loss:0.0708
Epoch 009 [500/600]	Loss:0.0809
...Testing @ Epoch 009	Acc:0.9775
Epoch 010 [000/600]	Loss:0.0179
Epoch 010 [100/600]	Loss:0.0293
Epoch 010 [200/600]	Loss:0.0534
Epoch 010 [300/600]	Loss:0.0983
Epoch 010 [400/600]	Loss:0.1292
Epoch 010 [500/600]	Loss:0.0820
...Testing @ Epoch 010	Acc:0.9781
Epoch 011 [000/600]	Loss:0.0542
Epoch 011 [100/600]	Loss:0.0978
Epoch 011 [200/600]	Loss:0.0596
Epoch 011 [300/600]	Loss:0.0725
Epoch 011 [400/600]	Loss:0.1350
Epoch 011 [500/600]	Loss:0.0166
...Testing @ Epoch 011	Acc:0.9784
Epoch 012 [000/600]	Loss:0.1395
Epoch 012 [100/600]	Loss:0.0577
Epoch 012 [200/600]	Loss:0.0437
Epoch 012 [300/600]	Loss:0.1207
Epoch 012 [400/600]	Loss:0.0795
Epoch 012 [500/600]	Loss:0.0415
...Testing @ Epoch 012	Acc:0.9797
Epoch 013 [000/600]	Loss:0.0897
Epoch 013 [100/600]	Loss:0.0530
Epoch 013 [200/600]	Loss:0.0421
Epoch 013 [300/600]	Loss:0.0982
Epoch 013 [400/600]	Loss:0.0646
Epoch 013 [500/600]	Loss:0.0215
...Testing @ Epoch 013	Acc:0.9773
Epoch 014 [000/600]	Loss:0.0544
Epoch 014 [100/600]	Loss:0.0554
Epoch 014 [200/600]	Loss:0.0236
Epoch 014 [300/600]	Loss:0.0154
Epoch 014 [400/600]	Loss:0.0262
Epoch 014 [500/600]	Loss:0.1322
...Testing @ Epoch 014	Acc:0.9764
Epoch 015 [000/600]	Loss:0.0204
Epoch 015 [100/600]	Loss:0.0563
Epoch 015 [200/600]	Loss:0.0309
Epoch 015 [300/600]	Loss:0.0844
Epoch 015 [400/600]	Loss:0.1167
Epoch 015 [500/600]	Loss:0.0856
...Testing @ Epoch 015	Acc:0.9788
Epoch 016 [000/600]	Loss:0.0052
Epoch 016 [100/600]	Loss:0.1451
Epoch 016 [200/600]	Loss:0.0279
Epoch 016 [300/600]	Loss:0.0199
Epoch 016 [400/600]	Loss:0.0765
Epoch 016 [500/600]	Loss:0.1028
...Testing @ Epoch 016	Acc:0.9807
Epoch 017 [000/600]	Loss:0.0149
Epoch 017 [100/600]	Loss:0.0160
Epoch 017 [200/600]	Loss:0.0454
Epoch 017 [300/600]	Loss:0.1048
Epoch 017 [400/600]	Loss:0.1013
Epoch 017 [500/600]	Loss:0.0834
...Testing @ Epoch 017	Acc:0.9801
Epoch 018 [000/600]	Loss:0.0062
Epoch 018 [100/600]	Loss:0.0501
Epoch 018 [200/600]	Loss:0.0478
Epoch 018 [300/600]	Loss:0.0331
Epoch 018 [400/600]	Loss:0.0333
Epoch 018 [500/600]	Loss:0.0163
...Testing @ Epoch 018	Acc:0.9795
Epoch 019 [000/600]	Loss:0.0331
Epoch 019 [100/600]	Loss:0.0200
Epoch 019 [200/600]	Loss:0.0242
Epoch 019 [300/600]	Loss:0.0566
Epoch 019 [400/600]	Loss:0.0917

这是一个简单的PyTorch训练循环,主要包括以下内容:

1.设置学习率(alpha)和迭代次数(epochs)

2.通过循环迭代训练数据集(train_loader),获得模型预测输出(output),并计算损失(loss)

3.将参数的梯度清零,进行反向传播求导(loss.backward())

4.使用梯度下降法更新参数的数值,即 p.data = p.data - alpha * p.grad.data

5.按照指定的间隔(interval)打印训练信息,包括训练次数(epoch),当前训练进度(i/len(train_loader))和当前损失(loss.item())

6.在测试数据集(test_loader)上对模型进行测试,获得预测结果(output),并计算准确率(acc)

7.打印测试结果,包括训练次数(epoch)和测试准确率(acc)

需要注意的是,这里采用的是交叉熵函数(cross_entropy),用于计算loss。同时也清空了参数梯度,防止梯度累加。

cross-entropy(交叉熵)是一种度量概率分布之间的差异性的方法,常用于机器学习中的分类问题。在分类问题中,预测结果通常表示为概率分布,而交叉熵则表示这个预测结果和实际类别分布之间的差异。交叉熵越小,则模型预测的概率分布越接近实际类别分布,模型的性能也越好。交叉熵公式如下:

H ( p , q ) = − ∑ i p ( i ) log ⁡ q ( i ) H(p,q)=-\sum_{i}p(i)\log q(i) H(p,q)=−∑ip(i)logq(i)

其中, p p p 表示实际的类别分布, q q q 表示模型预测的概率分布。

附:系列文章

序号 文章目录 直达链接
1 PyTorch应用实战一:实现卷积操作 https://want595.blog.csdn.net/article/details/132575530
2 PyTorch应用实战二:实现卷积神经网络进行图像分类 https://want595.blog.csdn.net/article/details/132575702
3 PyTorch应用实战三:构建神经网络 https://want595.blog.csdn.net/article/details/132575758
4 PyTorch应用实战四:基于PyTorch构建复杂应用 https://want595.blog.csdn.net/article/details/132625270
5 PyTorch应用实战五:实现二值化神经网络 https://want595.blog.csdn.net/article/details/132625348
6 PyTorch应用实战六:利用LSTM实现文本情感分类 https://want595.blog.csdn.net/article/details/132625382
相关推荐
静心问道4 小时前
WGAN算法
深度学习·算法·机器学习
清纯世纪5 小时前
基于深度学习的图像分类或识别系统(含全套项目+PyQt5界面)
开发语言·python·深度学习
AIPaPerPass写论文5 小时前
写论文去哪个网站?2024最佳五款AI毕业论文学术网站
人工智能·深度学习·chatgpt·powerpoint·ai写作
5pace6 小时前
PyTorch深度学习快速入门教程【土堆】基础知识篇
人工智能·pytorch·深度学习
aWty_6 小时前
机器学习--卷积神经网络(包括python实现)
人工智能·机器学习·cnn
AI完全体8 小时前
AI小项目4-用Pytorch从头实现Transformer(详细注解)
人工智能·pytorch·深度学习·机器学习·语言模型·transformer·注意力机制
AI知识分享官9 小时前
智能绘画Midjourney AIGC在设计领域中的应用
人工智能·深度学习·语言模型·chatgpt·aigc·midjourney·llama
天南星9 小时前
PaddleOCR和PaddleLite的关联和区别
深度学习·图像识别
十有久诚9 小时前
TaskRes: Task Residual for Tuning Vision-Language Models
人工智能·深度学习·提示学习·视觉语言模型
PD我是你的真爱粉10 小时前
GPTo1论文详解
人工智能·深度学习