【项目实战】——深度学习.全连接神经网络

目录

1.使用全连接网络训练和验证MNIST数据集

2.使用全连接网络训练和验证CIFAR10数据集


1.使用全连接网络训练和验证MNIST数据集

Python 代码如下:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from PIL import Image
import os

# Preprocessing: convert each PIL image to a float tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Datasets: stored under ./data and downloaded there when missing.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
eval_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Only the training loader is shuffled. Evaluation metrics do not depend on
# sample order, so the evaluation loader iterates in a fixed order, which
# avoids a pointless permutation and keeps per-batch results reproducible.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=False)


# Network definition
class MyNet(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 128 -> 10.

    Each of the two hidden linear layers is followed by batch normalisation
    and then ReLU. The final linear layer returns raw class scores; no
    softmax is applied inside the network.
    """

    def __init__(self):
        super().__init__()
        # Hidden layer 1 (input is the flattened 28 x 28 image).
        self.fc1 = nn.Linear(28 * 28, 256)
        self.bn1 = nn.BatchNorm1d(256)
        # One ReLU module is reused after both hidden layers.
        self.relu = nn.ReLU()
        # Hidden layer 2.
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        # Output layer: one score per class.
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return a (rows, 10) score tensor."""
        flat = x.view(-1, self.fc1.in_features)
        hidden = self.relu(self.bn1(self.fc1(flat)))
        hidden = self.relu(self.bn2(self.fc2(hidden)))
        return self.fc3(hidden)


# Network instance used by the training, evaluation and saving steps.
model = MyNet()
# Loss function: cross-entropy between the network's scores and the labels.
criterion = nn.CrossEntropyLoss()
# Optimiser: Adam over all of the model's parameters, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training
def train(model, train_loader, epochs):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Relies on the module-level ``criterion`` and ``optimizer``. After every
    epoch one line is printed with the epoch index, the mean per-sample loss
    over the whole epoch, and the training accuracy as a fraction of the
    samples seen in that epoch.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(data, target)`` batches.
        epochs: Number of passes over ``train_loader``.

    Raises:
        ValueError: If an epoch yields no samples.
    """
    model.train()

    for epoch in range(epochs):
        loss_sum = 0.0
        correct = 0
        seen = 0
        for data, target in train_loader:
            output = model(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = target.size(0)
            # Assumes ``criterion`` returns the batch mean (the default for
            # nn.CrossEntropyLoss); weighting by batch size gives a true
            # per-sample average even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            _, predicted = torch.max(output, 1)
            correct += (predicted.eq(target)).sum().item()
            seen += batch_size

        if seen == 0:
            raise ValueError('train_loader yielded no samples')

        epoch_loss = loss_sum / seen
        accuracy = correct / seen
        print(f'Train Epoch:  {epoch} , loss: {epoch_loss:.4f},acc:{accuracy:.4f}')

# Evaluation
def eval(model, eval_loader):
    """Evaluate ``model`` on ``eval_loader`` and print mean loss and accuracy.

    Relies on the module-level ``criterion``. The printed loss is the mean
    per-sample loss and the accuracy is a percentage, both over the samples
    the loader yielded. Gradients are disabled during the pass.

    Note: the function name shadows the built-in ``eval``; it is kept so
    existing callers continue to work.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        eval_loader: Iterable yielding ``(data, target)`` batches.

    Raises:
        ValueError: If ``eval_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for data, target in eval_loader:
            output = model(data)
            batch_size = target.size(0)
            # Assumes ``criterion`` returns the batch mean (the default for
            # nn.CrossEntropyLoss). Summing those means and dividing by the
            # sample count would shrink the result by about the batch size,
            # so each batch mean is weighted by its own size instead.
            loss_sum += criterion(output, target).item() * batch_size

            _, predicted = torch.max(output, 1)
            correct += (predicted.eq(target)).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError('eval_loader yielded no samples')

    eval_loss = loss_sum / seen
    acc = 100.0 * correct / seen
    print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')

# Save the model
def save_model():
    """Write the module-level ``model``'s state dict to ``mnist_fc_model.pt``."""
    checkpoint_path = 'mnist_fc_model.pt'
    weights = model.state_dict()
    torch.save(weights, checkpoint_path)

# Prediction
def predict(img_path, model_path='mnist_fc_model.pt'):
    """Classify one image file with a saved ``MyNet`` checkpoint.

    The image is converted to single-channel greyscale, resized to 28 x 28
    and fed through a freshly constructed ``MyNet`` loaded from
    ``model_path``. The input tensor's shape and the predicted class index
    are printed, and the index is also returned.

    NOTE(review): no colour inversion or normalisation is applied here, so
    the image presumably needs the same polarity as the training data
    (light digit on a dark background) - confirm with real inputs.

    Args:
        img_path: Path of the image file to classify.
        model_path: Path of the state-dict checkpoint to load. Defaults to
            the file written by ``save_model``.

    Returns:
        int: Index of the highest-scoring class.
    """
    model = MyNet()
    # map_location keeps the load working on a CPU-only machine even if the
    # checkpoint tensors were saved from another device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor()
    ])
    # The context manager closes the underlying file once the pixel data has
    # been converted, instead of leaving the handle open until collection.
    with Image.open(img_path) as img:
        t_img = transform(img.convert('L')).unsqueeze(0)
    print(t_img.shape)

    with torch.no_grad():
        output = model(t_img)
        _, predicted = torch.max(output, 1)

    label = predicted.item()
    print(label)
    return label


# Run the pipeline: train, evaluate, write the checkpoint, classify one image.
epochs = 5

train(model, train_loader, epochs=epochs)
eval(model, eval_loader=eval_loader)

save_model()

img_path = './img/7.png'
predict(img_path=img_path)

2.使用全连接网络训练和验证CIFAR10数据集

Python 代码如下:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim

# Preprocessing: convert to a tensor, then standardise each of the three
# channels with the mean / std tuples below.
# NOTE(review): these std values are the ones commonly copied between
# CIFAR-10 examples; verify them against statistics computed on the
# training set before relying on them.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# Datasets: stored under ./cifar10 and downloaded there when missing.
train_dataset = datasets.CIFAR10(root='./cifar10', train=True, transform=transform, download=True)
eval_dataset = datasets.CIFAR10(root='./cifar10', train=False, transform=transform, download=True)

# Only the training loader is shuffled. Evaluation metrics do not depend on
# sample order, so the evaluation loader iterates in a fixed order, which
# avoids a pointless permutation and keeps per-batch results reproducible.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=False)


# Network definition
class MyNet(nn.Module):
    """Fully connected classifier: 3072 -> 1024 -> 512 -> 256 -> 10.

    Every hidden linear layer is followed by batch normalisation and ReLU.
    The first two hidden layers also apply dropout (p = 0.3) between the
    batch normalisation and the ReLU. The final linear layer returns raw
    class scores; no softmax is applied inside the network.
    """

    def __init__(self):
        super().__init__()
        # Hidden layer 1 (input is the flattened 32 x 32 x 3 image).
        self.fc1 = nn.Linear(3 * 32 * 32, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(p=0.3)
        # Hidden layer 2.
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(p=0.3)
        # Hidden layer 3 (the added third layer; no dropout here).
        self.fc3 = nn.Linear(512, 256)
        self.bn3 = nn.BatchNorm1d(256)
        # Output layer: one score per class.
        self.fc4 = nn.Linear(256, 10)
        # One ReLU module is reused after every hidden layer.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten ``x`` to rows of 3072 values and return a (rows, 10) score tensor."""
        flat = x.view(-1, self.fc1.in_features)
        h1 = self.relu(self.dropout1(self.bn1(self.fc1(flat))))
        h2 = self.relu(self.dropout2(self.bn2(self.fc2(h1))))
        h3 = self.relu(self.bn3(self.fc3(h2)))
        return self.fc4(h3)


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the network and move its parameters and buffers onto that device.
model = MyNet().to(device)

# Loss function: cross-entropy between the network's scores and the labels.
criterion = nn.CrossEntropyLoss()
# Optimiser: Adam over all of the model's parameters, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


def train(model, train_loader, epochs):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Relies on the module-level ``criterion``, ``optimizer`` and ``device``.
    Each batch is moved to ``device`` before the forward pass. After every
    epoch one line is printed with the epoch index, the mean per-sample loss
    over the whole epoch, and the training accuracy as a fraction of the
    samples seen in that epoch.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(data, target)`` batches.
        epochs: Number of passes over ``train_loader``.

    Raises:
        ValueError: If an epoch yields no samples.
    """
    model.train()

    for epoch in range(epochs):
        loss_sum = 0.0
        correct = 0
        seen = 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = target.size(0)
            # Assumes ``criterion`` returns the batch mean (the default for
            # nn.CrossEntropyLoss); weighting by batch size gives a true
            # per-sample average even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            _, predicted = torch.max(output, 1)
            correct += (predicted.eq(target)).sum().item()
            seen += batch_size

        if seen == 0:
            raise ValueError('train_loader yielded no samples')

        epoch_loss = loss_sum / seen
        accuracy = correct / seen
        print(f'Train Epoch:  {epoch} , loss: {epoch_loss:.4f},acc:{accuracy:.4f}')


def eval(model, eval_loader):
    """Evaluate ``model`` on ``eval_loader`` and print mean loss and accuracy.

    Relies on the module-level ``criterion`` and ``device``. Each batch is
    moved to ``device`` before the forward pass. The printed loss is the
    mean per-sample loss and the accuracy is a percentage, both over the
    samples the loader yielded. Gradients are disabled during the pass.

    Note: the function name shadows the built-in ``eval``; it is kept so
    existing callers continue to work.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        eval_loader: Iterable yielding ``(data, target)`` batches.

    Raises:
        ValueError: If ``eval_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for data, target in eval_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_size = target.size(0)
            # Assumes ``criterion`` returns the batch mean (the default for
            # nn.CrossEntropyLoss). Summing those means and dividing by the
            # sample count would shrink the result by about the batch size,
            # so each batch mean is weighted by its own size instead.
            loss_sum += criterion(output, target).item() * batch_size

            _, predicted = torch.max(output, 1)
            correct += (predicted.eq(target)).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError('eval_loader yielded no samples')

    eval_loss = loss_sum / seen
    acc = 100.0 * correct / seen
    print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')


# Run the pipeline: train for the configured number of epochs, then evaluate.
epochs = 25

train(model, train_loader, epochs=epochs)
eval(model, eval_loader=eval_loader)

思考:为什么全连接网络在 CIFAR10 数据集上的准确率远低于 MNIST?(提示:把图像展平成一维向量会丢失像素之间的空间结构信息,而 CIFAR10 是背景复杂的彩色图像。)