DAY 38 MLP神经网络的训练

python 复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
    
model = MLP()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

epoches = 20000
losses = []

for epoch in range(epoches):
    outputs = model.forward(X_train)
    loss = criterion(outputs, y_train)
    losses.append(loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epoches}], Loss: {loss.item():.6f}")
python 复制代码
import matplotlib.pyplot as plt

plt.figure()
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Curve")
plt.show()

@浙大疏锦行

相关推荐
island13145 分钟前
CANN ops-nn 算子库深度解析:核心算子(如激活函数、归一化)的数值精度控制与内存高效实现
开发语言·人工智能·神经网络
brave and determined1 小时前
CANN ops-nn算子库使用教程:实现神经网络在NPU上的加速计算
人工智能·深度学习·神经网络
笔画人生1 小时前
系统级整合:`ops-transformer` 在 CANN 全栈架构中的角色与实践
深度学习·架构·transformer
觉醒大王2 小时前
AI写的青基中了
人工智能·笔记·深度学习·学习·职场和发展·学习方法
深鱼~2 小时前
深度剖析ops-transformer:LayerNorm与GEMM的融合优化
人工智能·深度学习·transformer
哈__2 小时前
CANN图优化技术:深度学习模型的编译器魔法
人工智能·深度学习
灰灰勇闯IT2 小时前
神经网络的基石——深度解析 CANN ops-nn 算子库如何赋能昇腾 AI
人工智能·深度学习·神经网络
deephub2 小时前
LLM推理时计算技术详解:四种提升大模型推理能力的方法
人工智能·深度学习·大语言模型·推理时计算
chian-ocean2 小时前
智能多模态助手实战:基于 `ops-transformer` 与开源 LLM 构建 LLaVA 风格推理引擎
深度学习·开源·transformer
慢半拍iii2 小时前
对比源码解读:ops-nn中卷积算子的硬件加速实现原理
人工智能·深度学习·ai·cann