1. 读取数据
使用 torchvision 加载 CIFAR-10 图像数据集,并为训练集和测试集分别创建数据加载器。
python
import torch
from torchvision import datasets, transforms
# Image preprocessing pipeline applied to every sample.
_to_tensor = transforms.ToTensor()  # PIL image -> tensor
_normalize = transforms.Normalize(
    mean=(0.5, 0.5, 0.5),  # per-channel (RGB) mean
    std=(0.5, 0.5, 0.5),  # per-channel (RGB) standard deviation
)
image_transform = transforms.Compose([_to_tensor, _normalize])
# CIFAR-10 training and test splits. Both are stored under ./data,
# downloaded when not already present locally, and preprocessed with
# image_transform.
_cifar10_common = dict(
    root='./data',  # dataset storage directory
    transform=image_transform,
    download=True,  # fetch the data if it is missing locally
)
trainset = datasets.CIFAR10(train=True, **_cifar10_common)  # training split
testset = datasets.CIFAR10(train=False, **_cifar10_common)  # test split
# Mini-batch loaders for the two splits; both use the same batch size.
_BATCH_SIZE = 128  # samples per batch
_make_loader = torch.utils.data.DataLoader

# Training batches are drawn in shuffled order.
train_loader = _make_loader(trainset, batch_size=_BATCH_SIZE, shuffle=True)
# Test batches keep the dataset's original order.
test_loader = _make_loader(testset, batch_size=_BATCH_SIZE, shuffle=False)
2. 模型建立
(1)建立 CNN 模型:两个“卷积 + ReLU + 最大池化”模块,后接两层全连接层,输出 10 个类别的得分。
python
import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Small two-stage convolutional image classifier.

    Layer stack::

        conv(3 -> 16, 3x3, padding 1) -> ReLU -> 2x2 max-pool
        conv(16 -> 32, 3x3, padding 1) -> ReLU -> 2x2 max-pool
        flatten -> linear(32 * 8 * 8 -> 256) -> ReLU -> linear(256 -> num_classes)

    The first linear layer is sized for 32 x 8 x 8 feature maps, which is
    what the two pooling stages produce from a 3-channel 32x32 input.

    Args:
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # shared by both conv stages
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalized) class scores of shape (batch, num_classes).

        Args:
            x: Image batch of shape (batch, 3, H, W); H = W = 32 matches the
                input size expected by ``fc1``.

        Raises:
            RuntimeError: If the flattened feature size does not match the
                input size of ``fc1``.
        """
        x = self.pool(self.relu(self.conv1(x)))  # (batch, 16, 16, 16) for 32x32 input
        x = self.pool(self.relu(self.conv2(x)))  # (batch, 32, 8, 8) for 32x32 input
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 32 * 8 * 8), this never folds a feature-size mismatch into
        # the batch dimension, and it also accepts non-contiguous activations.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)