import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# ---------------------------
# Basic setup
# ---------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Matplotlib setup (optional): all plot labels below are in English, so a standard font suffices.
plt.rcParams["font.family"] = ["DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
# ---------------------------
# Data transforms
# ---------------------------
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
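# The mean/std above are the commonly quoted CIFAR-10 training-set statistics.
# Illustrative sketch (defined but never called by this pipeline) of one common
# way to estimate such per-channel statistics; exact published values may
# differ slightly depending on the estimator used:
def compute_cifar10_stats():
    raw = datasets.CIFAR10(root="./data", train=True, download=True,
                           transform=transforms.ToTensor())
    loader = DataLoader(raw, batch_size=1000, shuffle=False)
    n, mean, sq_mean = 0, 0.0, 0.0
    for x, _ in loader:
        n += x.size(0)
        mean = mean + x.mean(dim=(0, 2, 3)) * x.size(0)
        sq_mean = sq_mean + (x ** 2).mean(dim=(0, 2, 3)) * x.size(0)
    mean = mean / n
    std = (sq_mean / n - mean ** 2).sqrt()  # std via E[x^2] - E[x]^2
    return mean, std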
# ---------------------------
# Load CIFAR-10
# ---------------------------
batch_size = 64
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
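# Optional sanity check (illustrative, commented out so it does not spawn
# worker processes at import time): each training batch should be [64, 3, 32, 32].
# images, labels = next(iter(train_loader))
# print(images.shape, labels.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])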
# ---------------------------
# CNN Models
# ---------------------------
class CNN_A(nn.Module):
    """3 conv blocks (baseline)."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.drop = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool1(self.relu(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu(self.bn3(self.conv3(x))))
        x = x.view(x.size(0), -1)
        x = self.drop(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x
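# Shape trace for CNN_A (illustrative): each conv block halves the spatial
# size, 32x32 -> 16x16 -> 8x8 -> 4x4, so the flattened feature size is
# 128 * 4 * 4 = 2048, matching fc1's input. Quick check:
# assert CNN_A()(torch.randn(2, 3, 32, 32)).shape == (2, 10)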
class CNN_B(nn.Module):
    """4 conv blocks (deeper)."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1 -> 16x16
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 2 -> 8x8
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 3 -> 4x4
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 4 -> 2x2
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 2 * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
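# Illustrative helper (not used by the pipeline itself): compare the two model
# sizes by counting trainable parameters, e.g.
# print(count_params(CNN_A()), count_params(CNN_B()))
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)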
# ---------------------------
# Utilities: evaluate + plotting
# ---------------------------
@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        total_loss += loss.item()
        pred = out.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)
    avg_loss = total_loss / len(loader)
    acc = 100.0 * correct / total
    return avg_loss, acc
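# Note: dividing by len(loader) weights the (possibly smaller) last batch the
# same as full batches. For an exact per-sample average one could accumulate
# loss.item() * target.size(0) and divide by total instead; with batch_size=64
# on the 10k test images the difference is negligible.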
def plot_iteration_loss(history, title):
    plt.figure(figsize=(10, 4))
    plt.plot(history["iter_idx"], history["iter_loss"], alpha=0.8, label="Train Iter Loss")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.title(title + " - Iteration Loss")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()
def plot_epoch_curves(history, title):
    epochs = list(range(1, len(history["train_loss"]) + 1))
    plt.figure(figsize=(12, 4))
    # Accuracy
    plt.subplot(1, 2, 1)
    plt.plot(epochs, history["train_acc"], label="Train Acc")
    plt.plot(epochs, history["test_acc"], label="Test Acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.title(title + " - Accuracy")
    plt.grid(True)
    plt.legend()
    # Loss
    plt.subplot(1, 2, 2)
    plt.plot(epochs, history["train_loss"], label="Train Loss")
    plt.plot(epochs, history["test_loss"], label="Test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(title + " - Loss")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()
def plot_lr_curve(history, title):
    epochs = list(range(1, len(history["lr"]) + 1))
    plt.figure(figsize=(8, 4))
    plt.plot(epochs, history["lr"], label="Learning Rate")
    plt.xlabel("Epoch")
    plt.ylabel("LR")
    plt.title(title + " - Learning Rate")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()
# ---------------------------
# Train function (supports both schedulers)
# ---------------------------
def train_one_experiment(model, train_loader, test_loader, criterion, optimizer, scheduler, epochs, exp_name):
    history = {
        "iter_loss": [],
        "iter_idx": [],
        "train_loss": [],
        "test_loss": [],
        "train_acc": [],
        "test_acc": [],
        "lr": []
    }
    global_iter = 0
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader, start=1):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
            # Record per-iteration loss
            global_iter += 1
            history["iter_loss"].append(loss.item())
            history["iter_idx"].append(global_iter)
            # Running statistics
            running_loss += loss.item()
            pred = out.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
            if batch_idx % 100 == 0:
                print(f"[{exp_name}] Epoch {epoch}/{epochs} | Batch {batch_idx}/{len(train_loader)} "
                      f"| BatchLoss {loss.item():.4f} | AvgLoss {running_loss/batch_idx:.4f}")
        train_loss = running_loss / len(train_loader)
        train_acc = 100.0 * correct / total
        test_loss, test_acc = evaluate(model, test_loader, criterion)
        # Scheduler step (ReduceLROnPlateau needs a metric; the others do not)
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(test_loss)
        else:
            scheduler.step()
        # Record epoch metrics + LR. Note: the LR is read after step(), so the
        # recorded value is the one the *next* epoch will train with.
        current_lr = optimizer.param_groups[0]["lr"]
        history["lr"].append(current_lr)
        history["train_loss"].append(train_loss)
        history["test_loss"].append(test_loss)
        history["train_acc"].append(train_acc)
        history["test_acc"].append(test_acc)
        print(f"[{exp_name}] Epoch {epoch}/{epochs} DONE | "
              f"TrainAcc {train_acc:.2f}% | TestAcc {test_acc:.2f}% | "
              f"TrainLoss {train_loss:.4f} | TestLoss {test_loss:.4f} | LR {current_lr:.6f}")
    return history
# ---------------------------
# Experiment runner (4 combos)
# ---------------------------
def build_model(model_name):
    if model_name == "CNN_A":
        return CNN_A().to(device)
    elif model_name == "CNN_B":
        return CNN_B().to(device)
    else:
        raise ValueError(f"Unknown model_name: {model_name}")

def build_optimizer(model, lr=1e-3, weight_decay=1e-4):
    return optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

def build_scheduler(scheduler_name, optimizer):
    if scheduler_name == "StepLR":
        # Multiply the LR by 0.1 every 10 epochs
        return optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    elif scheduler_name == "Plateau":
        # Halve the LR when the monitored loss stops improving for 3 epochs;
        # the training loop already prints the current LR each epoch
        return optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)
    else:
        raise ValueError(f"Unknown scheduler_name: {scheduler_name}")
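# Dry-run sketch (illustrative, commented out): with lr=1e-3, StepLR trains at
# 1e-3 for the first 10 epochs and 1e-4 afterwards, while ReduceLROnPlateau
# halves the LR only when its metric stalls for 3 consecutive epochs.
# opt = optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
# sch = build_scheduler("StepLR", opt)
# for epoch in range(1, 21):
#     sch.step()
#     print(epoch, opt.param_groups[0]["lr"])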
def run_all_experiments(epochs=20):
    criterion = nn.CrossEntropyLoss()
    experiments = [
        ("CNN_A", "StepLR"),
        ("CNN_A", "Plateau"),
        ("CNN_B", "StepLR"),
        ("CNN_B", "Plateau"),
    ]
    results = {}
    for model_name, sched_name in experiments:
        exp_name = f"{model_name}+{sched_name}"
        print("\n" + "=" * 80)
        print(f"Start Experiment: {exp_name}")
        print("=" * 80)
        model = build_model(model_name)
        optimizer = build_optimizer(model, lr=1e-3, weight_decay=1e-4)
        scheduler = build_scheduler(sched_name, optimizer)
        history = train_one_experiment(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            epochs=epochs,
            exp_name=exp_name
        )
        results[exp_name] = history
        # Per-experiment plots
        plot_iteration_loss(history, exp_name)
        plot_epoch_curves(history, exp_name)
        plot_lr_curve(history, exp_name)
        print(f"[{exp_name}] Final Test Accuracy: {history['test_acc'][-1]:.2f}%")
    # Summary
    print("\n" + "#" * 80)
    print("Summary (Final Test Accuracy):")
    for exp_name, hist in results.items():
        print(f"{exp_name:20s} -> {hist['test_acc'][-1]:.2f}%")
    print("#" * 80)
    return results
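# Note: the summary reports last-epoch accuracy; tracking the best epoch as
# well, e.g. max(hist["test_acc"]), can be more robust since the final epoch
# is not always the best one.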
# ---------------------------
# Main
# ---------------------------
if __name__ == "__main__":
    epochs = 20
    results = run_all_experiments(epochs=epochs)