DAY 38 MLP神经网络的训练

python 复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
    
model = MLP()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

epoches = 20000
losses = []

for epoch in range(epoches):
    outputs = model.forward(X_train)
    loss = criterion(outputs, y_train)
    losses.append(loss.item())
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f"Epoch [{epoch + 1}/{epoches}], Loss: {loss.item():.6f}")
python 复制代码
import matplotlib.pyplot as plt

plt.figure()
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss Curve")
plt.show()

@浙大疏锦行

相关推荐
wm104341 分钟前
机器学习第二讲 KNN算法
人工智能·算法·机器学习
小途软件2 小时前
用于机器人电池电量预测的Sarsa强化学习混合集成方法
java·人工智能·pytorch·python·深度学习·语言模型
哥布林学者2 小时前
吴恩达深度学习课程五:自然语言处理 第一周:循环神经网络 (五)门控循环单元 GRU
深度学习·ai
薛不痒3 小时前
深度学习之优化模型(数据预处理,数据增强,调整学习率)
深度学习·学习
Yeats_Liao4 小时前
MindSpore开发之路(二十四):MindSpore Hub:快速复用预训练模型
人工智能·分布式·神经网络·机器学习·个人开发
格林威4 小时前
传送带上运动模糊图像复原:提升动态成像清晰度的 6 个核心方案,附 OpenCV+Halcon 实战代码!
人工智能·opencv·机器学习·计算机视觉·ai·halcon·工业相机
Aurora-Borealis.5 小时前
Day27 机器学习流水线
人工智能·机器学习
棒棒的皮皮5 小时前
【深度学习】YOLO模型速度优化Checklist
人工智能·深度学习·yolo·计算机视觉
AI街潜水的八角6 小时前
基于Pytorch深度学习神经网络MNIST手写数字识别系统源码(带界面和手写画板)
pytorch·深度学习·神经网络
黑符石7 小时前
【论文研读】Madgwick 姿态滤波算法报告总结
人工智能·算法·机器学习·imu·惯性动捕·madgwick·姿态滤波