DAY 38 MLP神经网络的训练

python 复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
    
model = MLP()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

epoches = 20000
losses = []

for epoch in range(epoches):
    outputs = model.forward(X_train)
    loss = criterion(outputs, y_train)
    losses.append(loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epoches}], Loss: {loss.item():.6f}")
python 复制代码
import matplotlib.pyplot as plt

plt.figure()
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Curve")
plt.show()

@浙大疏锦行

相关推荐
Z-D-K14 小时前
S-44的周末”旅行“-周六
人工智能·机器学习·aigc·交互·agi
STDD15 小时前
Kubeflow ML 流水线 K8s 部署教程:机器学习工作流编排全攻略
机器学习·容器·kubernetes
cyyt15 小时前
深度学习周报(6.8~6.14)
人工智能·深度学习
Master_oid15 小时前
机器学习46:逻辑回归--基础篇
人工智能·机器学习·逻辑回归
chen_zn9515 小时前
OpenPi、GR00T的视觉语言模型与动作模型连接方式差异分析总结
人工智能·深度学习·具身智能·vla
云和数据.ChenGuang15 小时前
大模型厂商常用的数据库有哪些?
数据库·人工智能·pytorch·深度学习·numpy
旅僧15 小时前
Bert理论讲解
人工智能·深度学习·bert
FL162386312915 小时前
基于CNN深度学习算实现手写字母识别系统python源码+训练好的模型+说明文档
python·深度学习·cnn
老饼讲解-BP神经网络16 小时前
BP神经网络用什么训练算法(traingd、traingdm、trainlm)
人工智能·神经网络·算法
Godspeed Zhao17 小时前
Level 4自动驾驶系统设计2——功能与场景2
人工智能·机器学习·自动驾驶