pythonstudy Day39

模型可视化和推理


@疏锦行

clike 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import random

# =====================
# 1. 固定随机种子(保证可复现)
# =====================
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# =====================
# 2. 加载与预处理数据
# =====================
iris = load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=seed
)

scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.LongTensor(y_test)

# =====================
# 3. 定义 MLP 模型
# =====================
class MLP(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 3)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# =====================
# 4. 准确率函数
# =====================
def accuracy(logits, y):
    preds = torch.argmax(logits, dim=1)
    return (preds == y).float().mean().item()

# =====================
# 5. 训练函数(核心)
# =====================
def train_model(hidden_dim, optimizer_name, lr, weight_decay, epochs=2000):
    model = MLP(hidden_dim)
    criterion = nn.CrossEntropyLoss()

    if optimizer_name == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=0.9,
            weight_decay=weight_decay
        )
    else:
        optimizer = optim.Adam(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

    best_test_acc = 0.0
    losses = []

    for epoch in range(epochs):
        # forward
        outputs = model(X_train)
        loss = criterion(outputs, y_train)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        # evaluation
        if (epoch + 1) % 200 == 0:
            with torch.no_grad():
                train_acc = accuracy(model(X_train), y_train)
                test_acc = accuracy(model(X_test), y_test)
                best_test_acc = max(best_test_acc, test_acc)

    return best_test_acc, losses

# =====================
# 6. 超参数组合(对比实验)
# =====================
experiments = [
    # 学习率对比(SGD)
    {"hidden": 10, "opt": "SGD", "lr": 0.1,   "wd": 0.0},
    {"hidden": 10, "opt": "SGD", "lr": 0.01,  "wd": 0.0},
    {"hidden": 10, "opt": "SGD", "lr": 0.001, "wd": 0.0},

    # 优化器对比
    {"hidden": 10, "opt": "Adam", "lr": 0.001, "wd": 0.0},

    # 隐藏层大小对比
    {"hidden": 4,  "opt": "Adam", "lr": 0.001, "wd": 0.0},
    {"hidden": 32, "opt": "Adam", "lr": 0.001, "wd": 0.0},

    # 正则化对比
    {"hidden": 10, "opt": "Adam", "lr": 0.001, "wd": 1e-4},
    {"hidden": 10, "opt": "Adam", "lr": 0.001, "wd": 1e-3},
]

# =====================
# 7. 运行实验
# =====================
results = []

for i, cfg in enumerate(experiments):
    print(f"\nRunning Experiment {i+1}/{len(experiments)}: {cfg}")
    best_acc, _ = train_model(
        hidden_dim=cfg["hidden"],
        optimizer_name=cfg["opt"],
        lr=cfg["lr"],
        weight_decay=cfg["wd"]
    )
    results.append({
        **cfg,
        "best_test_acc": best_acc
    })

# =====================
# 8. 打印最终对比结果
# =====================
print("\n========== Final Results ==========")
print("hidden | opt  | lr     | weight_decay | best_test_acc")
print("-" * 55)
for r in results:
    print(
        f"{r['hidden']:>6} | "
        f"{r['opt']:<4} | "
        f"{r['lr']:<6} | "
        f"{r['wd']:<12} | "
        f"{r['best_test_acc']:.4f}"
    )


相关推荐
阿尔的代码屋4 小时前
[大模型实战 07] 基于 LlamaIndex ReAct 框架手搓全自动博客监控 Agent
人工智能·python
AI探索者21 小时前
LangGraph StateGraph 实战:状态机聊天机器人构建指南
python
AI探索者21 小时前
LangGraph 入门:构建带记忆功能的天气查询 Agent
python
FishCoderh1 天前
Python自动化办公实战:批量重命名文件,告别手动操作
python
躺平大鹅1 天前
Python函数入门详解(定义+调用+参数)
python
曲幽1 天前
我用FastAPI接ollama大模型,差点被asyncio整崩溃(附对话窗口实战)
python·fastapi·web·async·httpx·asyncio·ollama
两万五千个小时1 天前
落地实现 Anthropic Multi-Agent Research System
人工智能·python·架构
哈里谢顿1 天前
Python 高并发服务限流终极方案:从原理到生产落地(2026 实战指南)
python
用户8356290780512 天前
无需 Office:Python 批量转换 PPT 为图片
后端·python