Getting Started with PyTorch: CNN

Reference: https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620

This post picks up from the previous MLP post, so the parts already covered there are not explained in detail again.
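Compared with that MLP, the main change on the input side is that the CNN keeps each image as a (N, 1, 28, 28) tensor so Conv2d can use the spatial layout, instead of flattening it to a 784-dimensional vector the way a plain MLP does. A minimal sketch of the contrast (tensor names here are illustrative only):

python
import torch

batch = torch.randn(64, 1, 28, 28)  # what the MNIST DataLoader below yields: (N, C, H, W)
mlp_input = batch.flatten(1)        # (64, 784): one flat vector per image, as a plain MLP expects
cnn_input = batch                   # (64, 1, 28, 28): fed to Conv2d unchanged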

train

python
import torch
from torchvision.transforms import ToTensor
from torchvision import datasets
import torch.nn as nn

# load MNIST dataset
training_data = datasets.MNIST(
    root='../02_dataset/data',
    train=True,
    download=True,
    transform=ToTensor()
)

train_data_loader = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True)

# define a CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc_1 = nn.Sequential(
            nn.Linear(9216, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.fc_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc_1(x)
        logits = self.fc_2(x)
        return logits

# create a CNN model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# train the model
num_epochs = 20

size = len(train_data_loader.dataset)

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}\n-------------------------------')
    for idx, (img, label) in enumerate(train_data_loader):
        img, label = img.to(device), label.to(device)

        # compute prediction error
        pred = cnn(img)
        loss = loss_fn(pred, label)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 400 == 0:
            loss, current = loss.item(), idx*len(img)
            print(f'loss: {loss:>7f} [{current:>5d}/{size:>5d}]')

# save the model
torch.save(cnn.state_dict(), 'cnn.pth')
print('Saved PyTorch Model State to cnn.pth')
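The in_features=9216 of fc_1 follows from the layer shapes: each 3x3 convolution without padding shrinks the 28x28 input by 2 (28 -> 26 -> 24), the 2x2 max-pool then halves it to 12x12, and 64 channels * 12 * 12 = 9216. A minimal shape check that traces a dummy batch through a fresh model (the check variable and the batch size of 2 are illustrative only):

python
# hypothetical shape check: trace a dummy 28x28 batch through a fresh, untrained model
check = CNN().to(device).eval()
x = torch.zeros(2, 1, 28, 28, device=device)
with torch.no_grad():
    x = check.conv_1(x)   # 3x3 conv, no padding: 28 -> 26, shape (2, 32, 26, 26)
    x = check.conv_2(x)   # 26 -> 24, shape (2, 64, 24, 24)
    x = check.maxpool(x)  # 2x2 max-pool: 24 -> 12, shape (2, 64, 12, 12)
    x = check.flatten(x)  # 64 * 12 * 12 = 9216, shape (2, 9216)
print(x.shape)            # torch.Size([2, 9216])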

test

python
import torch
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
import torch.nn as nn

# load test data
test_data = datasets.MNIST(
    root='../02_dataset/data',
    train=False,
    download=True,
    transform=ToTensor()
)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)

# preprocessing for the custom my-mnist images (note: RandomRotation is applied even at test time)
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])
my_mnist = ImageFolder(root='../02_dataset/my-mnist', transform=transform)
my_mnist_loader = torch.utils.data.DataLoader(my_mnist, batch_size=64, shuffle=True)

# define a CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc_1 = nn.Sequential(
            nn.Linear(9216, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.fc_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc_1(x)
        logits = self.fc_2(x)
        return logits

# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN()
cnn.load_state_dict(torch.load('cnn.pth', map_location=device))
cnn.eval().to(device)

# test the pretrained model on MNIST test data
size = len(test_data_loader.dataset)
correct = 0

with torch.no_grad():
    for img, label in test_data_loader:
        img, label = img.to(device), label.to(device)
        pred = cnn(img)

        correct += (pred.argmax(1) == label).type(torch.float).sum().item()

correct /= size
print(f'Accuracy on MNIST: {(100*correct):>0.1f}%')

# test the pretrained model on my MNIST test data
size = len(my_mnist_loader.dataset)
correct = 0

with torch.no_grad():
    for img, label in my_mnist_loader:
        img, label = img.to(device), label.to(device)
        pred = cnn(img)

        correct += (pred.argmax(1) == label).type(torch.float).sum().item()

correct /= size
print(f'Accuracy on my MNIST: {(100*correct):>0.1f}%')
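Beyond the batch-level accuracy, a single-sample check on the custom dataset makes the predictions easier to inspect. A rough sketch, assuming the my-mnist images are 28x28 like MNIST (which the model already requires) and that the subfolders are named after the digits so that the ImageFolder class index matches the label:

python
# hypothetical single-sample check on the custom ImageFolder dataset
print(my_mnist.classes)                      # subfolder names, e.g. ['0', '1', ..., '9']

img, label = my_mnist[0]                     # one transformed sample, expected shape (1, 28, 28)
with torch.no_grad():
    pred = cnn(img.unsqueeze(0).to(device))  # add a batch dimension -> (1, 1, 28, 28)
pred_idx = pred.argmax(1).item()
print(f'predicted: {my_mnist.classes[pred_idx]}, actual: {my_mnist.classes[label]}')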