PyTorch入门之【CNN】

参考:https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620

书接上回的 MLP,故本章不再详细解释。

目录

train

python 复制代码
import torch
from torchvision.transforms import ToTensor
from torchvision import datasets
import torch.nn as nn

# MNIST training split, converted to tensors; download=True fetches the files
# into the data directory when they are not present yet.
training_data = datasets.MNIST(
    '../02_dataset/data',
    train=True,
    transform=ToTensor(),
    download=True,
)

# Serve the training set in reshuffled mini-batches of 64 samples.
train_data_loader = torch.utils.data.DataLoader(
    dataset=training_data,
    batch_size=64,
    shuffle=True,
)

# define a CNN model
class CNN(nn.Module):
    """Small convolutional classifier that emits 10 raw logits per sample.

    Layout: two unpadded 3x3 convolution stages (each followed by batch
    normalisation and ReLU), one 2x2 max-pool, then a two-layer fully
    connected head. The first linear layer takes 9216 features, which is
    64 channels * 12 * 12 -- the size a 1x28x28 input reaches after the
    two convolutions (28 -> 26 -> 24) and the pooling step (24 -> 12).
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor: 1 -> 32 -> 64 channels.
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
        )

        # Spatial down-sampling, then collapse everything but the batch axis.
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()

        # Classification head: 9216 -> 128 -> 10.
        self.fc_1 = nn.Sequential(
            nn.Linear(in_features=9216, out_features=128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(),
        )
        self.fc_2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Run ``x`` through every stage in order and return the logits.

        No softmax is applied here; the output is the raw result of the
        final linear layer.
        """
        pipeline = (
            self.conv_1,
            self.conv_2,
            self.maxpool,
            self.flatten,
            self.fc_1,
            self.fc_2,
        )
        for stage in pipeline:
            x = stage(x)
        return x

# Create the model on the GPU when one is available, otherwise on the CPU,
# together with its optimizer and loss function.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# train the model
num_epochs = 20

# The dataset size is only needed for the progress output and does not change
# between batches, so it is looked up once rather than on every iteration.
size = len(train_data_loader.dataset)

# Modules start out in training mode; stating it explicitly keeps the
# BatchNorm layers on batch statistics even if this code is later run after
# an evaluation pass.
cnn.train()

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}\n-------------------------------')
    for idx, (img, label) in enumerate(train_data_loader):
        img, label = img.to(device), label.to(device)

        # compute prediction error
        pred = cnn(img)
        loss = loss_fn(pred, label)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 400 batches. The scalar gets its own name so
        # the `loss` tensor is not rebound to a float.
        if idx % 400 == 0:
            loss_value, current = loss.item(), idx * len(img)
            print(f'loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]')

# save the model (parameters and buffers only, via state_dict)
torch.save(cnn.state_dict(), 'cnn.pth')
print('Saved PyTorch Model State to cnn.pth')

test

python 复制代码
import torch
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
import torch.nn as nn

# MNIST test split (train=False), converted to tensors; downloaded into the
# data directory when it is not present yet.
test_data = datasets.MNIST(
    '../02_dataset/data',
    train=False,
    transform=ToTensor(),
    download=True,
)
test_data_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
)

# Custom digit images read from a folder tree. Each image is converted to
# grayscale, rotated by a random angle within +/-10 degrees, and turned into
# a tensor.
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ]
)
my_mnist = ImageFolder('../02_dataset/my-mnist', transform=transform)
my_mnist_loader = torch.utils.data.DataLoader(
    dataset=my_mnist,
    batch_size=64,
    shuffle=True,
)

# define a CNN model
class CNN(nn.Module):
    """Small convolutional classifier that emits 10 raw logits per sample.

    Layout: two unpadded 3x3 convolution stages (each followed by batch
    normalisation and ReLU), one 2x2 max-pool, then a two-layer fully
    connected head. The first linear layer takes 9216 features, which is
    64 channels * 12 * 12 -- the size a 1x28x28 input reaches after the
    two convolutions (28 -> 26 -> 24) and the pooling step (24 -> 12).

    Attribute names and layer order determine the ``state_dict`` keys, so
    they have to stay as they are for a saved checkpoint to load.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor: 1 -> 32 -> 64 channels.
        self.conv_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(),
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
        )

        # Spatial down-sampling, then collapse everything but the batch axis.
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()

        # Classification head: 9216 -> 128 -> 10.
        self.fc_1 = nn.Sequential(
            nn.Linear(in_features=9216, out_features=128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(),
        )
        self.fc_2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Run ``x`` through every stage in order and return the logits.

        No softmax is applied here; the output is the raw result of the
        final linear layer.
        """
        pipeline = (
            self.conv_1,
            self.conv_2,
            self.maxpool,
            self.flatten,
            self.fc_1,
            self.fc_2,
        )
        for stage in pipeline:
            x = stage(x)
        return x

# Load the pretrained weights and switch the model to evaluation mode.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN()
cnn.load_state_dict(torch.load('cnn.pth', map_location=device))
cnn.eval().to(device)


def _accuracy(model, data_loader, device):
    """Return the fraction of samples in ``data_loader`` classified correctly.

    Args:
        model: Callable mapping an image batch to per-class scores; the
            predicted class is the argmax along dimension 1.
        data_loader: Iterable of ``(img, label)`` batches exposing a
            ``dataset`` attribute whose length is the total sample count.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Number of correct predictions divided by ``len(data_loader.dataset)``.
    """
    hits = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for img, label in data_loader:
            img, label = img.to(device), label.to(device)
            pred = model(img)
            hits += (pred.argmax(1) == label).sum().item()
    return hits / len(data_loader.dataset)


# test the pretrained model on MNIST test data
correct = _accuracy(cnn, test_data_loader, device)
print(f'Accuracy on MNIST: {(100*correct):>0.1f}%')

# test the pretrained model on my MNIST test data
correct = _accuracy(cnn, my_mnist_loader, device)
print(f'Accuracy on my MNIST: {(100*correct):>0.1f}%')
相关推荐
xier_ran22 分钟前
深度学习:Mini-Batch 梯度下降(Mini-Batch Gradient Descent)
人工智能·深度学习·batch
Microvision维视智造33 分钟前
变速箱阀芯上料易错漏?通用 2D 视觉方案高效破局,成汽车制造检测优选!
人工智能
AAA小肥杨34 分钟前
探索K8s与AI的结合:PyTorch训练任务在k8s上调度实践
人工智能·pytorch·docker·ai·云原生·kubernetes
飞哥数智坊1 小时前
TRAE Friends 落地济南!首场线下活动圆满结束
人工智能·trae·solo
m0_527653901 小时前
NVIDIA Orin NX使用Jetpack安装CUDA、cuDNN、TensorRT、VPI时的error及解决方法
linux·人工智能·jetpack·nvidia orin nx
wbzuo1 小时前
Clip:Learning Transferable Visual Models From Natural Language Supervision
论文阅读·人工智能·transformer
带土11 小时前
2. YOLOv5 搭建一个完整的目标检测系统核心步骤
人工智能·yolo·目标检测
1***Q7842 小时前
PyTorch图像分割实战,U-Net模型训练与部署
人工智能·pytorch·python
阿十六2 小时前
OUC AI Lab 第六章:基于卷积的注意力机制
人工智能
努力の小熊2 小时前
基于tensorflow框架的MSCNN-LSTM模型在CWRU轴承故障诊断的应用
人工智能·tensorflow·lstm