DAY 38 MLP神经网络的训练

python 复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
    
model = MLP()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

epoches = 20000
losses = []

for epoch in range(epoches):
    outputs = model.forward(X_train)
    loss = criterion(outputs, y_train)
    losses.append(loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epoches}], Loss: {loss.item():.6f}")
python 复制代码
import matplotlib.pyplot as plt

plt.figure()
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Curve")
plt.show()

@浙大疏锦行

相关推荐
deephub8 小时前
DeepSeek-R1 与 OpenAI o3 的启示:Test-Time Compute 技术不再迷信参数堆叠
人工智能·python·深度学习·大语言模型
yzx9910138 小时前
从“识别猫”到诊断疾病:卷积神经网络如何改变我们的视觉世界
人工智能·神经网络·cnn
Lululaurel8 小时前
AI编程文本挖掘提示词实战
人工智能·python·机器学习·ai·ai编程·提示词
一瞬祈望8 小时前
⭐ 深度学习入门体系(第 3 篇):反向传播到底怎么工作的?
人工智能·深度学习
Felaim9 小时前
Sparse4D 时序输入和 Feature Queue 详解
人工智能·深度学习·自动驾驶
LaughingZhu9 小时前
Product Hunt 每日热榜 | 2025-12-13
人工智能·经验分享·神经网络·搜索引擎·产品运营
学好statistics和DS9 小时前
机器学习中所有可以调整的超参数(考试/自己调参用)
人工智能·机器学习
老马啸西风9 小时前
成熟企业级技术平台 MVE-010-IGA(Identity Governance & Administration,身份治理与管理)平台
人工智能·深度学习·算法·职场和发展
老马啸西风10 小时前
成熟企业级技术平台 MVE-010-app 管理平台
人工智能·深度学习·算法·职场和发展