目录
1.使用全连接网络训练和验证MNIST数据集
python
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
from PIL import Image
import os
# Preprocessing: ToTensor is the only transform applied to MNIST images.
transform = transforms.Compose([transforms.ToTensor()])

# Data: MNIST train/test splits stored under ./data (downloaded if missing).
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

eval_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=True)
# Network definition
class MyNet(nn.Module):
    """Fully connected classifier: 784 -> 256 -> 128 -> 10.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU; the final Linear
    layer's output is returned as-is (no softmax), one score per class.
    Every input sample must contain 28 * 28 = 784 values.
    """

    def __init__(self):
        super().__init__()
        # Hidden block 1: 784 -> 256
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)
        # Parameter-free activation shared by both hidden blocks.
        self.relu = nn.ReLU()
        # Hidden block 2: 256 -> 128
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        # Output layer: 128 -> 10 class scores.
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        # Flatten each sample into a 784-long row.
        hidden = x.view(-1, 28 * 28)
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = self.relu(norm(linear(hidden)))
        return self.fc3(hidden)
# Model, cross-entropy classification loss, and an Adam optimizer over all
# of the model's parameters.
model = MyNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Training
def train(model, train_loader, epochs):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Relies on the module-level ``criterion`` and ``optimizer``. After each
    epoch, prints the mean training loss over all samples seen in that epoch
    and the fraction of those samples that were classified correctly.

    Args:
        model: network to optimize; ``optimizer`` must have been built on
            this model's parameters.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        epochs: number of full passes over ``train_loader``.
    """
    model.train()
    num_samples = len(train_loader.dataset)
    for epoch in range(epochs):
        correct = 0
        loss_sum = 0.0
        for data, target in train_loader:
            output = model(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # criterion is nn.CrossEntropyLoss() with its default 'mean'
            # reduction, so scale by the batch size to accumulate a true sum;
            # this keeps the epoch average exact when the last batch is smaller.
            loss_sum += loss.item() * target.size(0)
            # Accuracy uses outputs computed before this batch's update, so it
            # reflects a model that changes during the epoch.
            correct += output.argmax(dim=1).eq(target).sum().item()
        epoch_loss = loss_sum / num_samples
        acc = correct / num_samples
        print(f'Train Epoch: {epoch} , loss: {epoch_loss:.4f},acc:{acc:.4f}')
# Evaluation
# NOTE: this name shadows the builtin eval(); it is kept for existing callers.
def eval(model, eval_loader):
    """Evaluate ``model`` on ``eval_loader`` and print loss and accuracy.

    Relies on the module-level ``criterion``. The printed loss is the mean
    per-sample loss over the whole dataset; the accuracy is a percentage.

    Args:
        model: network to evaluate; it is switched to eval mode.
        eval_loader: DataLoader yielding ``(data, target)`` batches.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in eval_loader:
            output = model(data)
            # criterion returns the batch mean (default 'mean' reduction);
            # weight it by the batch size so that dividing by the dataset
            # size below gives a per-sample average rather than
            # (sum of batch means) / N.
            loss_sum += criterion(output, target).item() * target.size(0)
            correct += output.argmax(dim=1).eq(target).sum().item()
    num_samples = len(eval_loader.dataset)
    eval_loss = loss_sum / num_samples
    acc = 100.0 * correct / num_samples
    print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')
# Save the model
def save_model():
    """Write the module-level ``model``'s state_dict to 'mnist_fc_model.pt'.

    The path is relative, so the file lands in the current working directory.
    """
    weights = model.state_dict()
    torch.save(weights, 'mnist_fc_model.pt')
# Prediction
def predict(img_path):
    """Classify one image file using the weights stored in 'mnist_fc_model.pt'.

    Builds a fresh ``MyNet``, loads the saved state_dict into it, converts the
    image at ``img_path`` to single-channel ('L') mode, resizes it to 28x28,
    and prints the input tensor's shape followed by the predicted class index.

    NOTE(review): the training pipeline applies only ToTensor, so images whose
    foreground/background brightness differs from the training images (e.g. a
    dark digit on a light background) may be misclassified. Confirm the
    polarity of the files under ./img.

    Args:
        img_path: path to an image file that PIL can open.
    """
    model = MyNet()
    # map_location='cpu': the model built above lives on the CPU, and this
    # lets a checkpoint saved from a GPU run load on a CPU-only machine.
    # torch.load unpickles the file, so only load checkpoints you trust.
    model.load_state_dict(torch.load('mnist_fc_model.pt', map_location='cpu'))
    model.eval()
    # The context manager closes the file handle; convert() returns an
    # independent image that stays usable afterwards.
    with Image.open(img_path) as raw:
        img = raw.convert('L')
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    t_img = transform(img).unsqueeze(0)  # prepend a batch dimension of 1
    print(t_img.shape)
    with torch.no_grad():
        output = model(t_img)
    predicted = output.argmax(dim=1)
    print(predicted.item())
# Pipeline: 5 training epochs, test-set evaluation, checkpointing, then a
# single-image prediction.
epochs = 5
img_path = './img/7.png'
train(model, train_loader, epochs=epochs)
eval(model, eval_loader)
save_model()
predict(img_path=img_path)
2.使用全连接网络训练和验证CIFAR10数据集
python
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
# Preprocessing: tensor conversion, then per-channel normalization using the
# mean triple and std triple below (one value per channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Data: CIFAR10 train/test splits stored under ./cifar10 (downloaded if missing).
train_dataset = datasets.CIFAR10(root='./cifar10', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

eval_dataset = datasets.CIFAR10(root='./cifar10', train=False, transform=transform, download=True)
eval_loader = DataLoader(dataset=eval_dataset, batch_size=512, shuffle=True)
# Network definition
class MyNet(nn.Module):
    """Fully connected classifier: 3072 -> 1024 -> 512 -> 256 -> 10.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU. The first two hidden
    layers also apply Dropout(0.3) between the BatchNorm and the ReLU. The
    final Linear layer's output is returned as-is (no softmax).
    Every input sample must contain 32 * 32 * 3 = 3072 values.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 1024)
        self.bn1 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.dropout2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(512, 256)  # third hidden layer
        self.bn3 = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = x.view(-1, 32 * 32 * 3)
        stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, None),  # last hidden layer: no dropout
        )
        for linear, norm, drop in stages:
            out = norm(linear(out))
            if drop is not None:
                out = drop(out)
            out = self.relu(out)
        return self.fc4(out)
# Pick the GPU when available, build the model there, and set up the
# cross-entropy loss and an Adam optimizer over the model's parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyNet()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
def train(model, train_loader, epochs):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Relies on the module-level ``criterion``, ``optimizer`` and ``device``;
    every batch is moved to ``device`` before the forward pass. After each
    epoch, prints the mean training loss over all samples seen in that epoch
    and the fraction of those samples that were classified correctly.

    Args:
        model: network to optimize; it must already live on ``device``, and
            ``optimizer`` must have been built on its parameters.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        epochs: number of full passes over ``train_loader``.
    """
    model.train()
    num_samples = len(train_loader.dataset)
    for epoch in range(epochs):
        correct = 0
        loss_sum = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # criterion is nn.CrossEntropyLoss() with its default 'mean'
            # reduction, so scale by the batch size to accumulate a true sum;
            # this keeps the epoch average exact when the last batch is smaller.
            loss_sum += loss.item() * target.size(0)
            # Accuracy uses outputs computed before this batch's update, so it
            # reflects a model that changes during the epoch.
            correct += output.argmax(dim=1).eq(target).sum().item()
        epoch_loss = loss_sum / num_samples
        acc = correct / num_samples
        print(f'Train Epoch: {epoch} , loss: {epoch_loss:.4f},acc:{acc:.4f}')
# NOTE: this name shadows the builtin eval(); it is kept for existing callers.
def eval(model, eval_loader):
    """Evaluate ``model`` on ``eval_loader`` and print loss and accuracy.

    Relies on the module-level ``criterion`` and ``device``; every batch is
    moved to ``device``. The printed loss is the mean per-sample loss over the
    whole dataset; the accuracy is a percentage.

    Args:
        model: network to evaluate; it is switched to eval mode.
        eval_loader: DataLoader yielding ``(data, target)`` batches.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in eval_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # criterion returns the batch mean (default 'mean' reduction);
            # weight it by the batch size so that dividing by the dataset
            # size below gives a per-sample average rather than
            # (sum of batch means) / N.
            loss_sum += criterion(output, target).item() * target.size(0)
            correct += output.argmax(dim=1).eq(target).sum().item()
    num_samples = len(eval_loader.dataset)
    eval_loss = loss_sum / num_samples
    acc = 100.0 * correct / num_samples
    print(f'loss: {eval_loss:.4f}, acc: {acc:.4f}')
# Pipeline: 25 training epochs, then test-set loss/accuracy.
epochs = 25
train(model, train_loader, epochs=epochs)
eval(model, eval_loader)
思考：为什么全连接网络在 CIFAR10 数据集上的准确率很低？