Python Study — Day 45

简单CNN

@疏锦行

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# ---------------------------
# Basic setup
# ---------------------------
# Run on CUDA when a GPU is visible to PyTorch, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
print(f"Device: {device}")

# Matplotlib styling (optional). All plot labels are English, so a stock font
# is enough; axes.unicode_minus=False draws the minus sign as an ASCII hyphen.
plt.rcParams.update({
    "font.family": ["DejaVu Sans"],
    "axes.unicode_minus": False,
})

# ---------------------------
# Data transforms
# ---------------------------
# Per-channel (R, G, B) normalization statistics shared by both pipelines.
# NOTE(review): these std values are the widely copied ones; the per-channel
# std computed over the full CIFAR-10 training set is usually quoted as about
# (0.247, 0.243, 0.261). Confirm before changing, since it rescales inputs.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop/flip/colour/rotation augmentation, then
# conversion to a tensor and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

# Evaluation pipeline: no augmentation, identical normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])

# ---------------------------
# Load CIFAR-10
# ---------------------------
# NOTE: this runs at import time; download=True fetches CIFAR-10 into ./data
# if it is not already there.
batch_size = 64
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_dataset  = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)

# Pinned (page-locked) host memory only helps host-to-GPU copies. Without an
# accelerator, PyTorch warns that pin_memory has no effect, so enable it only
# when training on CUDA.
_pin_memory = device.type == "cuda"

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=_pin_memory)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=_pin_memory)

# ---------------------------
# CNN Models
# ---------------------------
class CNN_A(nn.Module):
    """3 conv blocks (baseline).

    Each block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool(2),
    so the spatial size goes 32 -> 16 -> 8 -> 4 for a 32x32 input. The
    128x4x4 feature map is flattened into a 512-unit hidden layer with
    dropout (p=0.5), followed by the output layer.

    Args:
        num_classes: number of output logits. The default of 10 matches
            CIFAR-10 and the original fixed-size head.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1   = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2   = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3   = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.relu = nn.ReLU()
        # Sized for a 4x4 map after three 2x poolings, i.e. 32x32 inputs.
        self.fc1  = nn.Linear(128 * 4 * 4, 512)
        self.drop = nn.Dropout(0.5)
        self.fc2  = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, 3, H, W) with H and W such that three
        2x poolings leave a 4x4 map (e.g. 32x32); fc1 is sized for that.
        """
        x = self.pool1(self.relu(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu(self.bn3(self.conv3(x))))
        # flatten, unlike view, also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = self.drop(self.relu(self.fc1(x)))
        x = self.fc2(x)
        return x


class CNN_B(nn.Module):
    """4 conv blocks (deeper).

    Four Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool(2) blocks
    take a 32x32 input down to 2x2 with 256 channels. A 512-unit hidden layer
    with dropout (p=0.5) then produces the logits.

    Args:
        num_classes: number of output logits. The default of 10 matches
            CIFAR-10 and the original fixed-size head.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1 -> 16x16
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Block 2 -> 8x8
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Block 3 -> 4x4
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            # Block 4 -> 2x2
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Layer order is kept so state_dict keys (classifier.0/.3) are unchanged.
        self.classifier = nn.Sequential(
            nn.Linear(256 * 2 * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, 3, H, W) with H and W such that four
        2x poolings leave a 2x2 map (e.g. 32x32); the classifier is sized
        for that.
        """
        x = self.features(x)
        # flatten, unlike view, also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# ---------------------------
# Utilities: evaluate + plotting
# ---------------------------
@torch.no_grad()
def evaluate(model, loader, criterion):
    """Compute the average loss and accuracy of ``model`` over ``loader``.

    Switches the model to eval mode and runs without autograd. Each batch is
    moved to the module-level ``device`` before the forward pass.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of ``(data, target)`` batches.
        criterion: loss returning the batch *mean*, e.g.
            ``nn.CrossEntropyLoss()`` with its default reduction (assumed;
            TODO confirm for other losses).

    Returns:
        ``(avg_loss, acc)``: the loss averaged per sample and the accuracy in
        percent. Both are NaN when ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        batch_n = target.size(0)
        # criterion returns a batch mean. Re-weight by batch size so a short
        # final batch counts proportionally, not as much as a full batch.
        total_loss += criterion(out, target).item() * batch_n
        correct += (out.argmax(dim=1) == target).sum().item()
        total += batch_n
    if total == 0:
        # Nothing to evaluate: report undefined metrics instead of dividing by zero.
        return float("nan"), float("nan")
    avg_loss = total_loss / total
    acc = 100.0 * correct / total
    return avg_loss, acc

def plot_iteration_loss(history, title):
    """Plot the per-iteration training loss stored in ``history`` and show it.

    Uses ``history["iter_idx"]`` for the x-axis and ``history["iter_loss"]``
    for the y-axis, in a new 10x4-inch figure.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(history["iter_idx"], history["iter_loss"], alpha=0.8, label="Train Iter Loss")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title(title + " - Iteration Loss")
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()

def _plot_train_test(ax, epochs, train_vals, test_vals, train_label, test_label, ylabel, panel_title):
    """Draw one train-vs-test line panel on ``ax`` with grid and legend."""
    ax.plot(epochs, train_vals, label=train_label)
    ax.plot(epochs, test_vals, label=test_label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(panel_title)
    ax.grid(True)
    ax.legend()

def plot_epoch_curves(history, title):
    """Show per-epoch accuracy (left) and loss (right), train vs. test.

    Epochs are numbered from 1, one per entry of ``history["train_loss"]``.
    Reads the ``train_acc``, ``test_acc``, ``train_loss`` and ``test_loss``
    lists from ``history``.
    """
    epoch_axis = list(range(1, len(history["train_loss"]) + 1))
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
    _plot_train_test(acc_ax, epoch_axis, history["train_acc"], history["test_acc"],
                     "Train Acc", "Test Acc", "Accuracy (%)", title + " - Accuracy")
    _plot_train_test(loss_ax, epoch_axis, history["train_loss"], history["test_loss"],
                     "Train Loss", "Test Loss", "Loss", title + " - Loss")
    fig.tight_layout()
    plt.show()

def plot_lr_curve(history, title):
    """Show the learning rate recorded per epoch in ``history["lr"]``.

    Epochs are numbered from 1 on the x-axis; the figure is 8x4 inches.
    """
    lr_values = history["lr"]
    epoch_axis = list(range(1, len(lr_values) + 1))
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(epoch_axis, lr_values, label="Learning Rate")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("LR")
    ax.set_title(title + " - Learning Rate")
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    plt.show()

# ---------------------------
# Train function (supports both schedulers)
# ---------------------------
def train_one_experiment(model, train_loader, test_loader, criterion, optimizer, scheduler, epochs, exp_name):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each.

    Args:
        model: network to train; batches are moved to the module-level ``device``.
        train_loader: iterable of ``(data, target)`` batches with ``len()``
            (used for progress output).
        test_loader: iterable of ``(data, target)`` batches passed to ``evaluate``.
        criterion: loss returning the batch mean (assumed; TODO confirm for
            other losses).
        optimizer: optimizer over ``model``'s parameters. The LR is read from
            ``optimizer.param_groups[0]``.
        scheduler: stepped once per epoch. ``ReduceLROnPlateau`` receives the
            epoch's test loss; any other scheduler is stepped with no argument.
        epochs: number of epochs to run.
        exp_name: tag prefixed to every log line.

    Returns:
        dict of lists: ``iter_loss``/``iter_idx`` per iteration, and
        ``train_loss``, ``test_loss``, ``train_acc``, ``test_acc``, ``lr`` per
        epoch. ``lr`` is the LR used *during* that epoch, i.e. before the
        scheduler step.

    NOTE(review): the test loss drives ReduceLROnPlateau, so the test set is
    not fully held out; a separate validation split would avoid that.
    """
    history = {
        "iter_loss": [],
        "iter_idx": [],
        "train_loss": [],
        "test_loss": [],
        "train_acc": [],
        "test_acc": [],
        "lr": []
    }

    global_iter = 0
    num_batches = len(train_loader)

    for epoch in range(1, epochs + 1):
        model.train()
        # The scheduler only changes the LR after the epoch, so this is the
        # rate used for every update below. Reading it after scheduler.step()
        # would tag each epoch with the *next* epoch's LR.
        epoch_lr = optimizer.param_groups[0]["lr"]
        loss_sum = 0.0  # sum of per-sample losses (batch mean * batch size)
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader, start=1):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()

            # .item() synchronizes with the device, so call it once per iteration.
            batch_loss = loss.item()
            batch_n = target.size(0)

            # record iter loss
            global_iter += 1
            history["iter_loss"].append(batch_loss)
            history["iter_idx"].append(global_iter)

            # stats (per-sample weighting so a short last batch is not over-counted)
            loss_sum += batch_loss * batch_n
            correct += (out.argmax(dim=1) == target).sum().item()
            total += batch_n

            if batch_idx % 100 == 0:
                print(f"[{exp_name}] Epoch {epoch}/{epochs} | Batch {batch_idx}/{num_batches} "
                      f"| BatchLoss {batch_loss:.4f} | AvgLoss {loss_sum/total:.4f}")

        if total == 0:
            # Empty training loader: metrics are undefined rather than a crash.
            train_loss = train_acc = float("nan")
        else:
            train_loss = loss_sum / total
            train_acc = 100.0 * correct / total

        test_loss, test_acc = evaluate(model, test_loader, criterion)

        # scheduler step (handle both types)
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(test_loss)
        else:
            scheduler.step()

        # record epoch metrics + the LR this epoch actually trained with
        history["lr"].append(epoch_lr)
        history["train_loss"].append(train_loss)
        history["test_loss"].append(test_loss)
        history["train_acc"].append(train_acc)
        history["test_acc"].append(test_acc)

        print(f"[{exp_name}] Epoch {epoch}/{epochs} DONE | "
              f"TrainAcc {train_acc:.2f}% | TestAcc {test_acc:.2f}% | "
              f"TrainLoss {train_loss:.4f} | TestLoss {test_loss:.4f} | LR {epoch_lr:.6f}")

        # Announce scheduler-driven LR changes for the upcoming epoch.
        next_lr = optimizer.param_groups[0]["lr"]
        if next_lr != epoch_lr:
            print(f"[{exp_name}] LR {epoch_lr:.6f} -> {next_lr:.6f} for next epoch")

    return history

# ---------------------------
# Experiment runner (4 combos)
# ---------------------------
def build_model(model_name):
    """Instantiate the architecture called ``model_name`` and move it to ``device``.

    Args:
        model_name: ``"CNN_A"`` or ``"CNN_B"``.

    Raises:
        ValueError: for any other name.
    """
    if model_name not in ("CNN_A", "CNN_B"):
        raise ValueError("Unknown model_name")
    model_cls = CNN_A if model_name == "CNN_A" else CNN_B
    return model_cls().to(device)

def build_optimizer(model, lr=1e-3, weight_decay=1e-4):
    """Return a ``torch.optim.Adam`` optimizer over all of ``model``'s parameters.

    Args:
        model: module whose ``parameters()`` are optimized.
        lr: learning rate.
        weight_decay: Adam's ``weight_decay`` (an L2 term added to the
            gradient, not AdamW-style decoupled decay).
    """
    params = model.parameters()
    return optim.Adam(params, weight_decay=weight_decay, lr=lr)

def build_scheduler(scheduler_name, optimizer):
    """Create the learning-rate scheduler named ``scheduler_name`` for ``optimizer``.

    Args:
        scheduler_name: ``"StepLR"`` multiplies the LR by 0.1 every 10 scheduler
            steps. ``"Plateau"`` (``ReduceLROnPlateau``, mode "min") halves the
            LR once the monitored value has failed to improve for more than 3
            consecutive steps.
        optimizer: the optimizer whose LR is scheduled.

    Raises:
        ValueError: for an unknown ``scheduler_name``.
    """
    if scheduler_name == "StepLR":
        return optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    elif scheduler_name == "Plateau":
        # ``verbose`` is deprecated for LR schedulers in recent PyTorch releases
        # (it emits a warning), so it is not passed. LR changes can be observed
        # through optimizer.param_groups, which the training loop prints.
        return optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)
    else:
        raise ValueError("Unknown scheduler_name")

def _run_single_experiment(model_name, sched_name, criterion, epochs):
    """Train one fresh model/optimizer/scheduler combo, plot it, and return (name, history)."""
    exp_name = f"{model_name}+{sched_name}"
    banner = "=" * 80
    print("\n" + banner)
    print(f"Start Experiment: {exp_name}")
    print(banner)

    model = build_model(model_name)
    optimizer = build_optimizer(model, lr=1e-3, weight_decay=1e-4)
    scheduler = build_scheduler(sched_name, optimizer)

    history = train_one_experiment(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=epochs,
        exp_name=exp_name
    )
    return exp_name, history

def run_all_experiments(epochs=20):
    """Train every (model, scheduler) pair for ``epochs`` epochs and return the histories.

    Order: CNN_A+StepLR, CNN_A+Plateau, CNN_B+StepLR, CNN_B+Plateau. Each
    run gets a fresh model, Adam optimizer and scheduler, and uses the
    module-level ``train_loader``/``test_loader``. After each run its three
    plots are shown and its final test accuracy is printed; a summary of all
    runs is printed at the end.

    NOTE(review): the Plateau runs step their scheduler on test loss, so
    their final test accuracy is not a strictly held-out estimate.

    Returns:
        dict mapping ``"<model>+<scheduler>"`` to that run's history dict.
    """
    criterion = nn.CrossEntropyLoss()
    combos = [(m, s) for m in ("CNN_A", "CNN_B") for s in ("StepLR", "Plateau")]

    results = {}
    for model_name, sched_name in combos:
        exp_name, history = _run_single_experiment(model_name, sched_name, criterion, epochs)
        results[exp_name] = history

        # plots per experiment
        for plot_fn in (plot_iteration_loss, plot_epoch_curves, plot_lr_curve):
            plot_fn(history, exp_name)

        print(f"[{exp_name}] Final Test Accuracy: {history['test_acc'][-1]:.2f}%")

    # summary print
    rule = "#" * 80
    print("\n" + rule)
    print("Summary (Final Test Accuracy):")
    for exp_name, hist in results.items():
        print(f"{exp_name:20s} -> {hist['test_acc'][-1]:.2f}%")
    print(rule)

    return results

# ---------------------------
# Main
# ---------------------------
if __name__ == "__main__":
    # Script entry point: run all four model/scheduler experiments for 20 epochs each.
    epochs = 20
    results = run_all_experiments(epochs)
相关推荐
Yu_iChan2 小时前
Day03 公共字段填充与菜品管理
java·开发语言
独自破碎E2 小时前
如何防止接口被恶意刷量?
java·开发语言
期待のcode3 小时前
Java的单例模式
java·开发语言·单例模式
Aliex_git3 小时前
内存堆栈分析笔记
开发语言·javascript·笔记
AI手记叨叨3 小时前
Python数学:统计运算
python·数学·统计运算·描述统计·概率运算
Brian Xia3 小时前
从0开始手写AI Agent框架:nano-agentscope(一)项目介绍
人工智能·python·ai
Sui_Network3 小时前
Sui 2025→2026 直播回顾中文版
大数据·前端·人工智能·深度学习·区块链
LYOBOYI1233 小时前
qml练习:创建地图玩家并且实现人物移动(2)
开发语言·qt
电商API&Tina3 小时前
【电商API接口】多电商平台数据API接入方案(附带实例)
运维·开发语言·数据库·chrome·爬虫·python·jenkins