PyTorch入门之【CNN】

参考:https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from=333.999.0.0\&vd_source=98d31d5c9db8c0021988f2c2c25a9620

书接上回的MLP故本章就不详细解释了

目录

train

python 复制代码
import torch
from torchvision.transforms import ToTensor
from torchvision import datasets
import torch.nn as nn

# load MNIST dataset
training_data = datasets.MNIST(
    root='../02_dataset/data',
    train=True,
    download=True,
    transform=ToTensor()
)

train_data_loader = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True)

# define a CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc_1 = nn.Sequential(
            nn.Linear(9216, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.fc_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc_1(x)
        logits = self.fc_2(x)
        return logits

# create a CNN model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# train the model
num_epochs = 20

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}\n-------------------------------')
    for idx, (img, label) in enumerate(train_data_loader):
        size = len(train_data_loader.dataset)
        img, label = img.to(device), label.to(device)

        # compute prediction error
        pred = cnn(img)
        loss = loss_fn(pred, label)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if idx % 400 == 0:
            loss, current = loss.item(), idx*len(img)
            print(f'loss: {loss:>7f} [{current:>5d}/{size:>5d}]')

# save the model
torch.save(cnn.state_dict(), 'cnn.pth')
print('Saved PyTorch Model State to cnn.pth')

test

python 复制代码
import torch
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
import torch.nn as nn

# load test data
test_data = datasets.MNIST(
    root='../02_dataset/data',
    train=False,
    download=True,
    transform=ToTensor()
)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)

transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])
my_mnist = ImageFolder(root='../02_dataset/my-mnist', transform=transform)
my_mnist_loader = torch.utils.data.DataLoader(my_mnist, batch_size=64, shuffle=True)

# define a CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.conv_2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc_1 = nn.Sequential(
            nn.Linear(9216, 128),
            nn.BatchNorm1d(128),
            nn.ReLU()
        )
        self.fc_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc_1(x)
        logits = self.fc_2(x)
        return logits

# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN()
cnn.load_state_dict(torch.load('cnn.pth', map_location=device))
cnn.eval().to(device)

# test the pretrained model on MNIST test data
size = len(test_data_loader.dataset)
correct = 0

with torch.no_grad():
    for img, label in test_data_loader:
        img, label = img.to(device), label.to(device)
        pred = cnn(img)

        correct += (pred.argmax(1) == label).type(torch.float).sum().item()

correct /= size
print(f'Accuracy on MNIST: {(100*correct):>0.1f}%')

# test the pretrained model on my MNIST test data
size = len(my_mnist_loader.dataset)
correct = 0

with torch.no_grad():
    for img, label in my_mnist_loader:
        img, label = img.to(device), label.to(device)
        pred = cnn(img)

        correct += (pred.argmax(1) == label).type(torch.float).sum().item()

correct /= size
print(f'Accuracy on my MNIST: {(100*correct):>0.1f}%')
相关推荐
Coder_Boy_8 分钟前
技术让开发更轻松的底层矛盾
java·大数据·数据库·人工智能·深度学习
啊森要自信13 分钟前
CANN ops-cv:面向计算机视觉的 AI 硬件端高效算子库核心架构与开发逻辑
人工智能·计算机视觉·架构·cann
2401_8362358616 分钟前
中安未来SDK15:以AI之眼,解锁企业档案的数字化基因
人工智能·科技·深度学习·ocr·生活
njsgcs19 分钟前
llm使用 AgentScope-Tuner 通过 RL 训练 FrozenLake 智能体
人工智能·深度学习
董董灿是个攻城狮24 分钟前
AI 视觉连载2:灰度图
人工智能
yunfuuwqi1 小时前
OpenClaw✅真·喂饭级教程:2026年OpenClaw(原Moltbot)一键部署+接入飞书最佳实践
运维·服务器·网络·人工智能·飞书·京东云
九河云1 小时前
5秒开服,你的应用部署还卡在“加载中”吗?
大数据·人工智能·安全·机器学习·华为云
人工智能培训1 小时前
具身智能视觉、触觉、力觉、听觉等信息如何实时对齐与融合?
人工智能·深度学习·大模型·transformer·企业数字化转型·具身智能
wenzhangli71 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
后端小肥肠2 小时前
别再盲目抽卡了!Seedance 2.0 成本太高?教你用 Claude Code 100% 出片
人工智能·aigc·agent