Python Training Camp, Day 33

Training an MLP Neural Network

Key points review:

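  1. Installing PyTorch and CUDA
  2. The command-line command for viewing GPU information (run in cmd)
  3. Checking CUDA
  4. Workflow of a simple neural network
    1. Data preprocessing (normalization, conversion to tensors)
    2. Defining the model
      1. Inherit from the nn.Module class
      2. Define each layer
      3. Define the forward pass
    3. Defining the loss function and optimizer
    4. Defining the training loop
    5. Visualizing the loss curve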
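
The CUDA check in point 3 is usually followed by choosing a device and moving both the model and its inputs onto it. The following is a minimal sketch of that pattern, not part of the original script: the stand-in `nn.Linear` model and the random batch are assumptions chosen only to match the Iris shapes (4 features, 3 classes).

```python
import torch
import torch.nn as nn

# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in model and batch with the Iris shapes (4 features, 3 classes).
model = nn.Linear(4, 3).to(device)   # parameters now live on `device`
batch = torch.rand(5, 4).to(device)  # inputs must be on the same device as the model

logits = model(batch)
print(logits.shape, logits.device)
```

For a dataset as small as Iris, running on the CPU is likely just as fast, so this mainly matters for larger models.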
Python code:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import matplotlib.pyplot as plt

# Check CUDA availability and print related information
if torch.cuda.is_available():
    print("CUDA is available!")
    device_count = torch.cuda.device_count()
    print(f"Number of available CUDA devices: {device_count}")

    current_device = torch.cuda.current_device()
    print(f"Index of the current CUDA device: {current_device}")

    device_name = torch.cuda.get_device_name(current_device)
    print(f"Name of the current CUDA device: {device_name}")

    # CUDA version that this PyTorch build was compiled against
    cuda_version = torch.version.cuda
    print(f"CUDA version: {cuda_version}")
else:
    print("CUDA is not available.")

# Load and prepare the Iris dataset
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

# Min-max normalization: fit on the training set only, then apply the same scaling to the test set
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

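# Convert to PyTorch tensors (float32 features, int64 class labels)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)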

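# Define the multilayer perceptron (MLP) model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)   # input layer: 4 features -> 10 hidden units
        self.relu = nn.ReLU()         # non-linear activation
        self.fc2 = nn.Linear(10, 3)   # output layer: 10 hidden units -> 3 class logits

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out  # raw logits; nn.CrossEntropyLoss applies log-softmax itself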

# Initialize the model, loss function, and optimizer
model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)  # learning rate set explicitly

# Train the model (full-batch gradient descent on the whole training set)
num_epochs = 20000
losses = []

for epoch in range(num_epochs):
    outputs = model(X_train)  # call the module itself rather than .forward() so hooks still run
    loss = criterion(outputs, y_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Plot the training loss
plt.plot(range(num_epochs), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()
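
The script above holds out `X_test` and `y_test` but never uses them. One way to check generalization is to compute accuracy on that held-out split. The sketch below is an illustration rather than part of the original: the `accuracy` helper is a hypothetical name, and the toy identity-weight model exists only so the example runs on its own.

```python
import torch
import torch.nn as nn

def accuracy(model, X, y):
    """Fraction of rows in X whose arg-max logit matches the label in y."""
    model.eval()               # evaluation mode (no effect for a plain Linear/ReLU MLP)
    with torch.no_grad():      # gradients are not needed for evaluation
        predicted = model(X).argmax(dim=1)
    return (predicted == y).float().mean().item()

# Toy stand-in model (an assumption for illustration): identity weights, so the
# predicted class is the index of the larger input feature.
toy = nn.Linear(2, 2)
with torch.no_grad():
    toy.weight.copy_(torch.eye(2))
    toy.bias.zero_()

X_demo = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
y_demo = torch.tensor([0, 1, 1])
demo_acc = accuracy(toy, X_demo, y_demo)  # two of the three rows are classified correctly
print(f"accuracy: {demo_acc:.4f}")
```

With the script's own names, the call would be `accuracy(model, X_test, y_test)` after the training loop. The helper leaves the model in evaluation mode, so call `model.train()` before resuming training.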

Output:

(120, 4)
(120,)  
(30, 4) 
(30,)   
Epoch [100/20000], Loss: 1.0980
Epoch [200/20000], Loss: 1.0692
Epoch [300/20000], Loss: 1.0398
Epoch [400/20000], Loss: 1.0060
Epoch [500/20000], Loss: 0.9665
Epoch [600/20000], Loss: 0.9214
Epoch [700/20000], Loss: 0.8731
Epoch [800/20000], Loss: 0.8241
Epoch [900/20000], Loss: 0.7777
Epoch [1000/20000], Loss: 0.7360
Epoch [1100/20000], Loss: 0.6991
Epoch [1200/20000], Loss: 0.6666
Epoch [1300/20000], Loss: 0.6380
Epoch [1400/20000], Loss: 0.6127
Epoch [1500/20000], Loss: 0.5903
Epoch [1600/20000], Loss: 0.5703
Epoch [1700/20000], Loss: 0.5523
Epoch [1800/20000], Loss: 0.5360
Epoch [1900/20000], Loss: 0.5211
Epoch [2000/20000], Loss: 0.5074
Epoch [2100/20000], Loss: 0.4948
Epoch [2200/20000], Loss: 0.4830
Epoch [2300/20000], Loss: 0.4720
Epoch [2400/20000], Loss: 0.4616
Epoch [2500/20000], Loss: 0.4518
Epoch [2600/20000], Loss: 0.4425
Epoch [2700/20000], Loss: 0.4336
Epoch [2800/20000], Loss: 0.4251
Epoch [2900/20000], Loss: 0.4168
Epoch [3000/20000], Loss: 0.4088
Epoch [3100/20000], Loss: 0.4010
Epoch [3200/20000], Loss: 0.3935
Epoch [3300/20000], Loss: 0.3860
Epoch [3400/20000], Loss: 0.3787
Epoch [3500/20000], Loss: 0.3716
Epoch [3600/20000], Loss: 0.3645
Epoch [3700/20000], Loss: 0.3576
Epoch [3800/20000], Loss: 0.3508
Epoch [3900/20000], Loss: 0.3440
Epoch [4000/20000], Loss: 0.3374
Epoch [4100/20000], Loss: 0.3308
Epoch [4200/20000], Loss: 0.3243
Epoch [4300/20000], Loss: 0.3179
Epoch [4400/20000], Loss: 0.3115
Epoch [4500/20000], Loss: 0.3053
Epoch [4600/20000], Loss: 0.2991
Epoch [4700/20000], Loss: 0.2930
Epoch [4800/20000], Loss: 0.2870
Epoch [4900/20000], Loss: 0.2810
Epoch [5000/20000], Loss: 0.2752
Epoch [5100/20000], Loss: 0.2694
Epoch [5200/20000], Loss: 0.2638
Epoch [5300/20000], Loss: 0.2583
Epoch [5400/20000], Loss: 0.2530
Epoch [5500/20000], Loss: 0.2478
Epoch [5600/20000], Loss: 0.2428
Epoch [5700/20000], Loss: 0.2378
Epoch [5800/20000], Loss: 0.2330
Epoch [5900/20000], Loss: 0.2284
Epoch [6000/20000], Loss: 0.2238
Epoch [6100/20000], Loss: 0.2193
Epoch [6200/20000], Loss: 0.2150
Epoch [6300/20000], Loss: 0.2108
Epoch [6400/20000], Loss: 0.2067
Epoch [6500/20000], Loss: 0.2027
Epoch [6600/20000], Loss: 0.1989
Epoch [6700/20000], Loss: 0.1951
Epoch [6800/20000], Loss: 0.1914
Epoch [6900/20000], Loss: 0.1878
Epoch [7000/20000], Loss: 0.1844
Epoch [7100/20000], Loss: 0.1810
Epoch [7200/20000], Loss: 0.1778
Epoch [7300/20000], Loss: 0.1746
Epoch [7400/20000], Loss: 0.1716
Epoch [7500/20000], Loss: 0.1686
Epoch [7600/20000], Loss: 0.1658
Epoch [7700/20000], Loss: 0.1630
Epoch [7800/20000], Loss: 0.1604
Epoch [7900/20000], Loss: 0.1578
Epoch [8000/20000], Loss: 0.1553
Epoch [8100/20000], Loss: 0.1528
Epoch [8200/20000], Loss: 0.1505
Epoch [8300/20000], Loss: 0.1483
Epoch [8400/20000], Loss: 0.1461
Epoch [8500/20000], Loss: 0.1440
Epoch [8600/20000], Loss: 0.1420
Epoch [8700/20000], Loss: 0.1400
Epoch [8800/20000], Loss: 0.1381
Epoch [8900/20000], Loss: 0.1363
Epoch [9000/20000], Loss: 0.1345
Epoch [9100/20000], Loss: 0.1328
Epoch [9200/20000], Loss: 0.1312
Epoch [9300/20000], Loss: 0.1295
Epoch [9400/20000], Loss: 0.1280
Epoch [9500/20000], Loss: 0.1265
Epoch [9600/20000], Loss: 0.1250
Epoch [9700/20000], Loss: 0.1236
Epoch [9800/20000], Loss: 0.1222
Epoch [9900/20000], Loss: 0.1208
Epoch [10000/20000], Loss: 0.1195
Epoch [10100/20000], Loss: 0.1183
Epoch [10200/20000], Loss: 0.1170
Epoch [10300/20000], Loss: 0.1159
Epoch [10400/20000], Loss: 0.1147
Epoch [10500/20000], Loss: 0.1136
Epoch [10600/20000], Loss: 0.1125
Epoch [10700/20000], Loss: 0.1114
Epoch [10800/20000], Loss: 0.1104
Epoch [10900/20000], Loss: 0.1094
Epoch [11000/20000], Loss: 0.1084
Epoch [11100/20000], Loss: 0.1075
Epoch [11200/20000], Loss: 0.1065
Epoch [11300/20000], Loss: 0.1056
Epoch [11400/20000], Loss: 0.1048
Epoch [11500/20000], Loss: 0.1039
Epoch [11600/20000], Loss: 0.1031
Epoch [11700/20000], Loss: 0.1023
Epoch [11800/20000], Loss: 0.1015
Epoch [11900/20000], Loss: 0.1007
Epoch [12000/20000], Loss: 0.0999
Epoch [12100/20000], Loss: 0.0992
Epoch [12200/20000], Loss: 0.0985
Epoch [12300/20000], Loss: 0.0978
Epoch [12400/20000], Loss: 0.0971
Epoch [12500/20000], Loss: 0.0964
Epoch [12600/20000], Loss: 0.0958
Epoch [12700/20000], Loss: 0.0951
Epoch [12800/20000], Loss: 0.0945
Epoch [12900/20000], Loss: 0.0939
Epoch [13000/20000], Loss: 0.0933
Epoch [13100/20000], Loss: 0.0927
Epoch [13200/20000], Loss: 0.0922
Epoch [13300/20000], Loss: 0.0916
Epoch [13400/20000], Loss: 0.0910
Epoch [13500/20000], Loss: 0.0905
Epoch [13600/20000], Loss: 0.0900
Epoch [13700/20000], Loss: 0.0895
Epoch [13800/20000], Loss: 0.0890
Epoch [13900/20000], Loss: 0.0885
Epoch [14000/20000], Loss: 0.0880
Epoch [14100/20000], Loss: 0.0875
Epoch [14200/20000], Loss: 0.0871
Epoch [14300/20000], Loss: 0.0866
Epoch [14400/20000], Loss: 0.0862
Epoch [14500/20000], Loss: 0.0857
Epoch [14600/20000], Loss: 0.0853
Epoch [14700/20000], Loss: 0.0849
Epoch [14800/20000], Loss: 0.0845
Epoch [14900/20000], Loss: 0.0840
Epoch [15000/20000], Loss: 0.0837
Epoch [15100/20000], Loss: 0.0833
Epoch [15200/20000], Loss: 0.0829
Epoch [15300/20000], Loss: 0.0825
Epoch [15400/20000], Loss: 0.0821
Epoch [15500/20000], Loss: 0.0818
Epoch [15600/20000], Loss: 0.0814
Epoch [15700/20000], Loss: 0.0811
Epoch [15800/20000], Loss: 0.0807
Epoch [15900/20000], Loss: 0.0804
Epoch [16000/20000], Loss: 0.0800
Epoch [16100/20000], Loss: 0.0797
Epoch [16200/20000], Loss: 0.0794
Epoch [16300/20000], Loss: 0.0791
Epoch [16400/20000], Loss: 0.0788
Epoch [16500/20000], Loss: 0.0785
Epoch [16600/20000], Loss: 0.0782
Epoch [16700/20000], Loss: 0.0779
Epoch [16800/20000], Loss: 0.0776
Epoch [16900/20000], Loss: 0.0773
Epoch [17000/20000], Loss: 0.0770
Epoch [17100/20000], Loss: 0.0767
Epoch [17200/20000], Loss: 0.0765
Epoch [17300/20000], Loss: 0.0762
Epoch [17400/20000], Loss: 0.0759
Epoch [17500/20000], Loss: 0.0757
Epoch [17600/20000], Loss: 0.0754
Epoch [17700/20000], Loss: 0.0751
Epoch [17800/20000], Loss: 0.0749
Epoch [17900/20000], Loss: 0.0747
Epoch [18000/20000], Loss: 0.0744
Epoch [18100/20000], Loss: 0.0742
Epoch [18200/20000], Loss: 0.0739
Epoch [18300/20000], Loss: 0.0737
Epoch [18400/20000], Loss: 0.0735
Epoch [18500/20000], Loss: 0.0733
Epoch [18600/20000], Loss: 0.0730
Epoch [18700/20000], Loss: 0.0728
Epoch [18800/20000], Loss: 0.0726
Epoch [18900/20000], Loss: 0.0724
Epoch [19000/20000], Loss: 0.0722
Epoch [19100/20000], Loss: 0.0720
Epoch [19200/20000], Loss: 0.0718
Epoch [19300/20000], Loss: 0.0716
Epoch [19400/20000], Loss: 0.0714
Epoch [19500/20000], Loss: 0.0712
Epoch [19600/20000], Loss: 0.0710
Epoch [19700/20000], Loss: 0.0708
Epoch [19800/20000], Loss: 0.0706
Epoch [19900/20000], Loss: 0.0704
Epoch [20000/20000], Loss: 0.0702
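
A quick sanity check on the numbers above: the first logged loss, 1.0980, is close to ln 3 ≈ 1.0986, the cross-entropy of a classifier that gives all three classes equal probability. That is consistent with the model starting near chance level and improving from there. The sketch below reproduces that reference value in plain Python. The `cross_entropy` helper is a hypothetical name, and it handles one sample only, whereas `nn.CrossEntropyLoss` averages over the batch by default.

```python
import math

def cross_entropy(logits, target):
    """Negative log of the softmax probability assigned to the target class."""
    peak = max(logits)
    shifted = [z - peak for z in logits]  # subtract the max for numerical stability
    log_sum_exp = math.log(sum(math.exp(z) for z in shifted))
    return log_sum_exp - shifted[target]

# Equal logits mean a uniform prediction over the 3 Iris classes.
chance_loss = cross_entropy([0.0, 0.0, 0.0], target=1)
print(f"{chance_loss:.4f}")  # prints 1.0986
```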

@浙大疏锦行
