Training run output:
使用设备: cpu
Epoch [100/20000], Loss: 1.0420
Epoch [200/20000], Loss: 0.9975
Epoch [200/20000], Loss: 0.9975
Epoch [300/20000], Loss: 0.9480
Epoch [400/20000], Loss: 0.8947
Epoch [400/20000], Loss: 0.8947
Epoch [500/20000], Loss: 0.8393
Epoch [600/20000], Loss: 0.7838
Epoch [600/20000], Loss: 0.7838
Epoch [700/20000], Loss: 0.7300
Epoch [800/20000], Loss: 0.6797
Epoch [800/20000], Loss: 0.6797
Epoch [900/20000], Loss: 0.6337
Epoch [1000/20000], Loss: 0.5927
Epoch [1000/20000], Loss: 0.5927
Epoch [1100/20000], Loss: 0.5566
Epoch [1200/20000], Loss: 0.5251
Epoch [1200/20000], Loss: 0.5251
Epoch [1300/20000], Loss: 0.4975
Epoch [1400/20000], Loss: 0.4733
Epoch [1400/20000], Loss: 0.4733
Epoch [1500/20000], Loss: 0.4518
Epoch [1600/20000], Loss: 0.4325
Epoch [1600/20000], Loss: 0.4325
Epoch [1700/20000], Loss: 0.4150
Epoch [1800/20000], Loss: 0.3990
Epoch [1800/20000], Loss: 0.3990
Epoch [1900/20000], Loss: 0.3840
Epoch [2000/20000], Loss: 0.3701
Epoch [2000/20000], Loss: 0.3701
Epoch [2100/20000], Loss: 0.3570
Epoch [2200/20000], Loss: 0.3447
Epoch [2200/20000], Loss: 0.3447
Epoch [2300/20000], Loss: 0.3329
Epoch [2400/20000], Loss: 0.3218
Epoch [2400/20000], Loss: 0.3218
Epoch [2500/20000], Loss: 0.3112
Epoch [2600/20000], Loss: 0.3011
Epoch [2600/20000], Loss: 0.3011
Epoch [2700/20000], Loss: 0.2914
Epoch [2800/20000], Loss: 0.2822
Epoch [2800/20000], Loss: 0.2822
Epoch [2900/20000], Loss: 0.2735
Epoch [3000/20000], Loss: 0.2651
Epoch [3000/20000], Loss: 0.2651
Epoch [3100/20000], Loss: 0.2572
Epoch [3200/20000], Loss: 0.2496
Epoch [3200/20000], Loss: 0.2496
Epoch [3300/20000], Loss: 0.2423
Epoch [3400/20000], Loss: 0.2354
Epoch [3400/20000], Loss: 0.2354
Epoch [3500/20000], Loss: 0.2288
Epoch [3600/20000], Loss: 0.2226
Epoch [3600/20000], Loss: 0.2226
Epoch [3700/20000], Loss: 0.2166
Epoch [3800/20000], Loss: 0.2109
Epoch [3800/20000], Loss: 0.2109
Epoch [3900/20000], Loss: 0.2054
Epoch [4000/20000], Loss: 0.2003
Epoch [4000/20000], Loss: 0.2003
Epoch [4100/20000], Loss: 0.1953
Epoch [4200/20000], Loss: 0.1906
Epoch [4200/20000], Loss: 0.1906
Epoch [4300/20000], Loss: 0.1861
Epoch [4400/20000], Loss: 0.1818
Epoch [4400/20000], Loss: 0.1818
Epoch [4500/20000], Loss: 0.1777
Epoch [4600/20000], Loss: 0.1738
Epoch [4600/20000], Loss: 0.1738
Epoch [4700/20000], Loss: 0.1700
Epoch [4800/20000], Loss: 0.1664
Epoch [4800/20000], Loss: 0.1664
Epoch [4900/20000], Loss: 0.1630
Epoch [5000/20000], Loss: 0.1597
Epoch [5000/20000], Loss: 0.1597
Epoch [5100/20000], Loss: 0.1566
Epoch [5200/20000], Loss: 0.1536
Epoch [5200/20000], Loss: 0.1536
Epoch [5300/20000], Loss: 0.1507
Epoch [5400/20000], Loss: 0.1479
Epoch [5400/20000], Loss: 0.1479
Epoch [5500/20000], Loss: 0.1452
Epoch [5600/20000], Loss: 0.1427
Epoch [5600/20000], Loss: 0.1427
Epoch [5700/20000], Loss: 0.1402
Epoch [5800/20000], Loss: 0.1379
Epoch [5800/20000], Loss: 0.1379
Epoch [5900/20000], Loss: 0.1356
Epoch [6000/20000], Loss: 0.1335
Epoch [6000/20000], Loss: 0.1335
Epoch [6100/20000], Loss: 0.1314
Epoch [6200/20000], Loss: 0.1294
Epoch [6200/20000], Loss: 0.1294
Epoch [6300/20000], Loss: 0.1274
Epoch [6400/20000], Loss: 0.1256
Epoch [6400/20000], Loss: 0.1256
Epoch [6500/20000], Loss: 0.1238
Epoch [6600/20000], Loss: 0.1220
Epoch [6600/20000], Loss: 0.1220
Epoch [6700/20000], Loss: 0.1204
Epoch [6800/20000], Loss: 0.1188
Epoch [6800/20000], Loss: 0.1188
Epoch [6900/20000], Loss: 0.1172
Epoch [7000/20000], Loss: 0.1157
Epoch [7000/20000], Loss: 0.1157
Epoch [7100/20000], Loss: 0.1143
Epoch [7200/20000], Loss: 0.1129
Epoch [7200/20000], Loss: 0.1129
Epoch [7300/20000], Loss: 0.1115
Epoch [7400/20000], Loss: 0.1102
Epoch [7400/20000], Loss: 0.1102
Epoch [7500/20000], Loss: 0.1089
Epoch [7600/20000], Loss: 0.1077
Epoch [7600/20000], Loss: 0.1077
Epoch [7700/20000], Loss: 0.1065
Epoch [7800/20000], Loss: 0.1054
Epoch [7800/20000], Loss: 0.1054
Epoch [7900/20000], Loss: 0.1043
Epoch [8000/20000], Loss: 0.1032
Epoch [8000/20000], Loss: 0.1032
Epoch [8100/20000], Loss: 0.1022
Epoch [8200/20000], Loss: 0.1012
Epoch [8200/20000], Loss: 0.1012
Epoch [8300/20000], Loss: 0.1002
Epoch [8400/20000], Loss: 0.0992
Epoch [8400/20000], Loss: 0.0992
Epoch [8500/20000], Loss: 0.0983
Epoch [8600/20000], Loss: 0.0974
Epoch [8600/20000], Loss: 0.0974
Epoch [8700/20000], Loss: 0.0965
Epoch [8800/20000], Loss: 0.0957
Epoch [8800/20000], Loss: 0.0957
Epoch [8900/20000], Loss: 0.0949
Epoch [9000/20000], Loss: 0.0941
Epoch [9000/20000], Loss: 0.0941
Epoch [9100/20000], Loss: 0.0933
Epoch [9200/20000], Loss: 0.0926
Epoch [9200/20000], Loss: 0.0926
Epoch [9300/20000], Loss: 0.0918
Epoch [9400/20000], Loss: 0.0911
Epoch [9400/20000], Loss: 0.0911
Epoch [9500/20000], Loss: 0.0904
Epoch [9600/20000], Loss: 0.0898
Epoch [9600/20000], Loss: 0.0898
Epoch [9700/20000], Loss: 0.0891
Epoch [9800/20000], Loss: 0.0885
Epoch [9800/20000], Loss: 0.0885
Epoch [9900/20000], Loss: 0.0878
Epoch [10000/20000], Loss: 0.0872
Epoch [10000/20000], Loss: 0.0872
Epoch [10100/20000], Loss: 0.0866
Epoch [10200/20000], Loss: 0.0861
Epoch [10200/20000], Loss: 0.0861
Epoch [10300/20000], Loss: 0.0855
Epoch [10400/20000], Loss: 0.0850
Epoch [10400/20000], Loss: 0.0850
Epoch [10500/20000], Loss: 0.0844
Epoch [10600/20000], Loss: 0.0839
Epoch [10600/20000], Loss: 0.0839
Epoch [10700/20000], Loss: 0.0834
Epoch [10800/20000], Loss: 0.0829
Epoch [10800/20000], Loss: 0.0829
Epoch [10900/20000], Loss: 0.0824
Epoch [11000/20000], Loss: 0.0819
Epoch [11000/20000], Loss: 0.0819
Epoch [11100/20000], Loss: 0.0815
Epoch [11200/20000], Loss: 0.0810
Epoch [11200/20000], Loss: 0.0810
Epoch [11300/20000], Loss: 0.0806
Epoch [11400/20000], Loss: 0.0802
Epoch [11400/20000], Loss: 0.0802
Epoch [11500/20000], Loss: 0.0797
Epoch [11600/20000], Loss: 0.0793
Epoch [11600/20000], Loss: 0.0793
Epoch [11700/20000], Loss: 0.0789
Epoch [11800/20000], Loss: 0.0785
Epoch [11800/20000], Loss: 0.0785
Epoch [11900/20000], Loss: 0.0781
Epoch [12000/20000], Loss: 0.0778
Epoch [12000/20000], Loss: 0.0778
Epoch [12100/20000], Loss: 0.0774
Epoch [12200/20000], Loss: 0.0770
Epoch [12200/20000], Loss: 0.0770
Epoch [12300/20000], Loss: 0.0767
Epoch [12400/20000], Loss: 0.0763
Epoch [12400/20000], Loss: 0.0763
Epoch [12500/20000], Loss: 0.0760
Epoch [12600/20000], Loss: 0.0756
Epoch [12600/20000], Loss: 0.0756
Epoch [12700/20000], Loss: 0.0753
Epoch [12800/20000], Loss: 0.0750
Epoch [12800/20000], Loss: 0.0750
Epoch [12900/20000], Loss: 0.0747
Epoch [13000/20000], Loss: 0.0744
Epoch [13000/20000], Loss: 0.0744
Epoch [13100/20000], Loss: 0.0741
Epoch [13200/20000], Loss: 0.0738
Epoch [13200/20000], Loss: 0.0738
Epoch [13300/20000], Loss: 0.0735
Epoch [13400/20000], Loss: 0.0732
Epoch [13400/20000], Loss: 0.0732
Epoch [13500/20000], Loss: 0.0729
Epoch [13600/20000], Loss: 0.0726
Epoch [13600/20000], Loss: 0.0726
Epoch [13700/20000], Loss: 0.0724
Epoch [13800/20000], Loss: 0.0721
Epoch [13800/20000], Loss: 0.0721
Epoch [13900/20000], Loss: 0.0719
Epoch [14000/20000], Loss: 0.0716
Epoch [14000/20000], Loss: 0.0716
Epoch [14100/20000], Loss: 0.0713
Epoch [14200/20000], Loss: 0.0711
Epoch [14200/20000], Loss: 0.0711
Epoch [14300/20000], Loss: 0.0709
Epoch [14400/20000], Loss: 0.0706
Epoch [14400/20000], Loss: 0.0706
Epoch [14500/20000], Loss: 0.0704
Epoch [14600/20000], Loss: 0.0702
Epoch [14600/20000], Loss: 0.0702
Epoch [14700/20000], Loss: 0.0699
Epoch [14800/20000], Loss: 0.0697
Epoch [14800/20000], Loss: 0.0697
Epoch [14900/20000], Loss: 0.0695
Epoch [15000/20000], Loss: 0.0693
Epoch [15000/20000], Loss: 0.0693
Epoch [15100/20000], Loss: 0.0691
Epoch [15200/20000], Loss: 0.0689
Epoch [15200/20000], Loss: 0.0689
Epoch [15300/20000], Loss: 0.0687
Epoch [15400/20000], Loss: 0.0685
Epoch [15400/20000], Loss: 0.0685
Epoch [15500/20000], Loss: 0.0683
Epoch [15600/20000], Loss: 0.0681
Epoch [15600/20000], Loss: 0.0681
Epoch [15700/20000], Loss: 0.0679
Epoch [15800/20000], Loss: 0.0677
Epoch [15800/20000], Loss: 0.0677
Epoch [15900/20000], Loss: 0.0675
Epoch [16000/20000], Loss: 0.0673
Epoch [16000/20000], Loss: 0.0673
Epoch [16100/20000], Loss: 0.0671
Epoch [16200/20000], Loss: 0.0670
Epoch [16200/20000], Loss: 0.0670
Epoch [16300/20000], Loss: 0.0668
Epoch [16400/20000], Loss: 0.0666
Epoch [16400/20000], Loss: 0.0666
Epoch [16500/20000], Loss: 0.0664
Epoch [16600/20000], Loss: 0.0663
Epoch [16600/20000], Loss: 0.0663
Epoch [16700/20000], Loss: 0.0661
Epoch [16800/20000], Loss: 0.0660
Epoch [16800/20000], Loss: 0.0660
Epoch [16900/20000], Loss: 0.0658
Epoch [17000/20000], Loss: 0.0656
Epoch [17000/20000], Loss: 0.0656
Epoch [17100/20000], Loss: 0.0655
Epoch [17200/20000], Loss: 0.0653
Epoch [17200/20000], Loss: 0.0653
Epoch [17300/20000], Loss: 0.0652
Epoch [17400/20000], Loss: 0.0650
Epoch [17400/20000], Loss: 0.0650
Epoch [17500/20000], Loss: 0.0649
Epoch [17600/20000], Loss: 0.0647
Epoch [17600/20000], Loss: 0.0647
Epoch [17700/20000], Loss: 0.0646
Epoch [17800/20000], Loss: 0.0645
Epoch [17800/20000], Loss: 0.0645
Epoch [17900/20000], Loss: 0.0643
Epoch [18000/20000], Loss: 0.0642
Epoch [18000/20000], Loss: 0.0642
Epoch [18100/20000], Loss: 0.0640
Epoch [18200/20000], Loss: 0.0639
Epoch [18200/20000], Loss: 0.0639
Epoch [18300/20000], Loss: 0.0638
Epoch [18400/20000], Loss: 0.0636
Epoch [18400/20000], Loss: 0.0636
Epoch [18500/20000], Loss: 0.0635
Epoch [18600/20000], Loss: 0.0634
Epoch [18600/20000], Loss: 0.0634
Epoch [18700/20000], Loss: 0.0633
Epoch [18800/20000], Loss: 0.0631
Epoch [18800/20000], Loss: 0.0631
Epoch [18900/20000], Loss: 0.0630
Epoch [19000/20000], Loss: 0.0629
Epoch [19000/20000], Loss: 0.0629
Epoch [19100/20000], Loss: 0.0628
Epoch [19200/20000], Loss: 0.0627
Epoch [19200/20000], Loss: 0.0627
Epoch [19300/20000], Loss: 0.0626
Epoch [19400/20000], Loss: 0.0624
Epoch [19400/20000], Loss: 0.0624
Epoch [19500/20000], Loss: 0.0623
Epoch [19600/20000], Loss: 0.0622
Epoch [19600/20000], Loss: 0.0622
Epoch [19700/20000], Loss: 0.0621
Epoch [19800/20000], Loss: 0.0620
Epoch [19800/20000], Loss: 0.0620
Epoch [19900/20000], Loss: 0.0619
Epoch [20000/20000], Loss: 0.0618
Epoch [20000/20000], Loss: 0.0618
Training time: 8.60 seconds
Python source code:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
# --- Environment and data preparation -------------------------------------
# Prefer the first CUDA device when available; every tensor below lives here.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'使用设备: {device}\n')

# Iris dataset: 150 samples, 4 features, 3 classes.
iris = load_iris()
X, y = iris.data, iris.target

# Hold out 20% of the samples for testing; fixed seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Scale features into [0, 1]. The scaler is fit on the training split only,
# so no test-set statistics leak into training.
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Convert to device tensors: float32 features, int64 class labels
# (CrossEntropyLoss expects integer targets).
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test).to(device)
class MLP_Original(nn.Module):
    """Baseline MLP for Iris: 4 -> 10 -> 3 with one ReLU hidden layer."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        """Map a (batch, 4) feature tensor to (batch, 3) class logits."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
class MLP_Larger(nn.Module):
    """Wider and deeper variant: 4 -> 20 -> 10 -> 3, ReLU after each hidden layer."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 20)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 3)

    def forward(self, x):
        """Map a (batch, 4) feature tensor to (batch, 3) class logits."""
        h = self.relu(self.fc1(x))
        h = self.relu(self.fc2(h))
        return self.fc3(h)
class MLP_Smaller(nn.Module):
    """Narrower variant: 4 -> 5 -> 3 with one ReLU hidden layer."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 5)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(5, 3)

    def forward(self, x):
        """Map a (batch, 4) feature tensor to (batch, 3) class logits."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
class MLP_Tanh(nn.Module):
    """Same shape as MLP_Original (4 -> 10 -> 3) but with Tanh activation."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.act = nn.Tanh()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        """Map a (batch, 4) feature tensor to (batch, 3) class logits."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
def train_and_evaluate(model_class, optimizer_class, lr, num_epochs=20000):
    """Train a model on the global Iris tensors and report its test accuracy.

    Runs full-batch gradient descent for `num_epochs` epochs, sampling the
    training loss every 200 epochs so the loss curve can be plotted later.

    Args:
        model_class: nn.Module subclass, instantiated with no arguments.
        optimizer_class: torch.optim optimizer class (e.g. optim.SGD).
        lr: learning rate passed to the optimizer.
        num_epochs: total number of training epochs (default 20000).

    Returns:
        (epochs, losses, accuracy): sampled epoch numbers, the corresponding
        loss values, and the final test-set accuracy in [0, 1].

    Note: relies on module-level globals X_train, y_train, X_test, y_test,
    and device.
    """
    model = model_class().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optimizer_class(model.parameters(), lr=lr)
    losses = []
    epochs = []
    start_time = time.time()
    with tqdm(total=num_epochs, desc=f'训练 {model_class.__name__}', unit='epoch') as pbar:
        for epoch in range(num_epochs):
            # Full-batch forward/backward step.
            outputs = model(X_train)
            loss = criterion(outputs, y_train)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Sample the loss curve every 200 epochs.
            if (epoch + 1) % 200 == 0:
                losses.append(loss.item())
                epochs.append(epoch + 1)
                pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
            # Advance the bar in coarse 1000-epoch steps to keep tqdm overhead low.
            if (epoch + 1) % 1000 == 0:
                pbar.update(1000)
        # Catch up the bar in case num_epochs is not a multiple of 1000.
        if pbar.n < num_epochs:
            pbar.update(num_epochs - pbar.n)
    time_all = time.time() - start_time
    # Switch to eval mode before inference. It is a no-op for these plain
    # Linear/activation models, but guards against future dropout/batch-norm.
    model.eval()
    with torch.no_grad():
        outputs = model(X_test)
        # Predicted class = argmax over logits. Replaces the deprecated
        # `outputs.data` access, which bypasses autograd bookkeeping.
        predicted = torch.argmax(outputs, dim=1)
        accuracy = (predicted == y_test).sum().item() / y_test.size(0)
    print(f'{model_class.__name__} 训练时间: {time_all:.2f}秒, 测试准确率: {accuracy:.4f}\n')
    return epochs, losses, accuracy
# Hyperparameter sweep: vary architecture, optimizer, and learning rate.
configs = [
    (MLP_Original, optim.SGD, 0.01),
    (MLP_Larger, optim.SGD, 0.01),
    (MLP_Smaller, optim.SGD, 0.01),
    (MLP_Tanh, optim.SGD, 0.01),
    (MLP_Original, optim.Adam, 0.001),
    (MLP_Original, optim.SGD, 0.1),
    (MLP_Original, optim.SGD, 0.001),
]

# Train every configuration and overlay the sampled loss curves on one figure.
plt.figure(figsize=(12, 8))
for model_cls, opt_cls, lr in configs:
    epochs, losses, accuracy = train_and_evaluate(model_cls, opt_cls, lr)
    label = f'{model_cls.__name__} {opt_cls.__name__} lr={lr} (Acc:{accuracy:.2f})'
    plt.plot(epochs, losses, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Comparison with Different Hyperparameters')
plt.legend()
plt.grid(True)
plt.show()