NLP in Practice (4): Building an LSTM Model with PyTorch to Predict Diabetes

Table of Contents

[1. Data Preparation](#1-data-preparation)

[2. Creating the DataLoaders](#2-creating-the-dataloaders)

[3. Building the LSTM Model](#3-building-the-lstm-model)

[4. Model Training](#4-model-training)

[5. Model Evaluation](#5-model-evaluation)

[6. Visualizing the Training Process](#6-visualizing-the-training-process)

7. Summary

8. Training Log and Download


In this post, I will walk through how to build a two-layer LSTM model in PyTorch to predict diabetes.

We will start with data loading and then cover model construction, the training process, and evaluation of the results.

1. Data Preparation

First, we need to load and prepare the data:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Load the data
data = pd.read_csv('diabetes.csv', header=None)
X = data.iloc[:, :-1].values  # features
y = data.iloc[:, -1].values   # labels

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to PyTorch tensors
X_train = torch.FloatTensor(X_train)  # shape (num_samples, 8)
X_test = torch.FloatTensor(X_test)    # shape (num_samples, 8)
y_train = torch.FloatTensor(y_train)
y_test = torch.FloatTensor(y_test)

This code does the following:

  1. Imports the required libraries

  2. Loads the diabetes dataset from a CSV file

  3. Splits the data into features (X) and labels (y)

  4. Uses train_test_split to divide the data into training and test sets (80% train, 20% test)

  5. Converts the NumPy arrays into PyTorch tensors
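
One step the pipeline above skips is feature scaling. The diabetes features (glucose, blood pressure, BMI, and so on) span very different numeric ranges, and standardizing them often speeds up convergence for neural networks. A minimal sketch using scikit-learn's StandardScaler, with synthetic stand-in data since the real CSV is not bundled here:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the 8 diabetes features (614 train / 154 test rows,
# matching an 80/20 split of 768 samples)
rng = np.random.default_rng(42)
X_train = rng.normal(loc=100.0, scale=30.0, size=(614, 8))
X_test = rng.normal(loc=100.0, scale=30.0, size=(154, 8))

# Fit the scaler on the training split only, then apply it to both splits
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

print(X_train_scaled.mean(axis=0).round(2))  # each feature close to 0
print(X_train_scaled.std(axis=0).round(2))   # each feature close to 1
```

Fitting only on the training split avoids leaking test-set statistics into training.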

2. Creating the DataLoaders

To load the data efficiently in batches, we use PyTorch's DataLoader:

# Create the DataLoaders
train_data = TensorDataset(X_train, y_train)
test_data = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)

Here we:

  • Use TensorDataset to pack features and labels together

  • Create training and test DataLoaders with a batch size of 32

  • Shuffle the training data (shuffle=True) while keeping the test data in its original order
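
Before training, it is worth sanity-checking what the loader actually yields. A small sketch with dummy tensors of the same shapes (the row count here is illustrative, not the real dataset's):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy tensors shaped like the blog's data: 100 rows, 8 features each
X = torch.randn(100, 8)
y = torch.randint(0, 2, (100,)).float()

loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

inputs, labels = next(iter(loader))
print(inputs.shape)  # torch.Size([32, 8])
print(labels.shape)  # torch.Size([32])
print(len(loader))   # 4 batches: three of 32 and a final one of 4
```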

3. Building the LSTM Model

We build a two-layer LSTM model:

class LSTMModel(nn.Module):
    def __init__(self, input_size=8, hidden_size1=64, hidden_size2=32):
        super(LSTMModel, self).__init__()
        self.lstm1 = nn.LSTM(input_size, hidden_size1, batch_first=True)
        self.dropout1 = nn.Dropout(0.3)
        self.lstm2 = nn.LSTM(hidden_size1, hidden_size2, batch_first=True)
        self.dropout2 = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_size2, 1)

    def forward(self, x):
        # Add a sequence-length dimension: (batch_size, 1, input_size)
        x = x.unsqueeze(1)  # from (batch_size, 8) to (batch_size, 1, 8)

        # First LSTM layer
        x, _ = self.lstm1(x)
        x = self.dropout1(x)

        # Second LSTM layer
        x, (hn, cn) = self.lstm2(x)
        x = self.dropout2(hn[-1])  # take the final hidden state (last time step)

        x = self.fc(x)
        return torch.sigmoid(x.squeeze())

Key features of the model:

  • Input size of 8 (the 8 features in the diabetes dataset)

  • First LSTM layer with 64 hidden units

  • Second LSTM layer with 32 hidden units

  • A dropout layer (p=0.3) after each LSTM layer to reduce overfitting

  • A final fully connected layer producing a single value, passed through a sigmoid to turn it into a probability

  • In the forward method, we add a sequence-length dimension of 1, since LSTMs expect sequence inputs
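
The unsqueeze trick is easiest to see with a standalone shape check. This sketch runs a dummy batch through a single nn.LSTM layer with the same dimensions as lstm1:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=64, batch_first=True)

x = torch.randn(32, 8)   # a batch of tabular rows
x = x.unsqueeze(1)       # -> (32, 1, 8): treat each row as a length-1 sequence
out, (hn, cn) = lstm(x)

print(out.shape)  # torch.Size([32, 1, 64]): output at every time step
print(hn.shape)   # torch.Size([1, 32, 64]): final hidden state per layer
```

For a single-direction LSTM, hn[-1] equals the output at the last time step, which is why the model can read its prediction from either tensor.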

4. Model Training

We train the model with the Adam optimizer and BCELoss (binary cross-entropy):

# Initialize the model
model = LSTMModel(input_size=8)  # 8 features
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0002)

# Training and validation history
train_losses = []
train_accs = []
val_losses = []
val_accs = []

# Train the model
epochs = 300
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Compute and record training metrics
    train_loss = running_loss / len(train_loader)
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Compute and record validation metrics
    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

The training procedure:

  1. Initialize the model, loss function, and optimizer

  2. Train for 300 epochs

  3. In each epoch:

    • Training phase: forward pass, loss computation, backpropagation, parameter update

    • Validation phase: evaluate the model on the test set

  4. Record and print the training and validation loss and accuracy
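
The log below shows the validation metrics plateauing long before epoch 300, so early stopping is a natural refinement (it is not part of the original code). A sketch of the bookkeeping, driven here by a synthetic loss curve:

```python
# Stop training once validation loss has not improved for `patience` epochs.
def early_stop_epoch(val_losses, patience=10):
    best = float('inf')
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - 1e-4:   # counts as a real improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch     # stop here
    return len(val_losses) - 1   # never triggered: ran all epochs

# Synthetic validation losses that improve and then flatten out
losses = [0.70, 0.60, 0.50, 0.45, 0.44] + [0.44] * 20
print(early_stop_epoch(losses))  # stops at epoch 14, well before epoch 24
```

In the real loop you would also keep a copy of the weights from the best epoch (e.g. copy.deepcopy(model.state_dict())) and restore them afterwards.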

5. Model Evaluation

After training, we evaluate the model's final performance on the test set:

# Evaluate the model
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    predicted = (outputs > 0.5).float()
    accuracy = (predicted == y_test).float().mean()
print(f'Test Accuracy: {accuracy:.4f}')
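
Accuracy alone can be misleading on this dataset: the log below shows that always predicting the majority class already scores about 0.69 on this split. Precision, recall, and the confusion matrix give a fuller picture. A sketch with hypothetical labels standing in for y_test and predicted:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Hypothetical ground truth and predictions (stand-ins for y_test / predicted)
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 0, 1, 0])
y_pred = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 0])

print(confusion_matrix(y_true, y_pred))  # rows: true class, cols: predicted class
print(precision_score(y_true, y_pred))   # TP / (TP + FP) = 0.75
print(recall_score(y_true, y_pred))      # TP / (TP + FN) = 0.75
```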

6. Visualizing the Training Process

Finally, we plot the training and validation accuracy and loss curves:

# Plot the training curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_accs, label='Training Accuracy')
plt.plot(val_accs, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

These plots help us:

  • Check whether the model has converged

  • Detect overfitting or underfitting

  • Decide whether the training parameters need adjusting
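
Per-epoch curves from a small dataset are noisy, as the log below shows. Smoothing with a simple moving average can make trends easier to judge; this helper is an illustrative addition, not part of the original script:

```python
# Trailing moving average: each point averages up to `window` preceding values.
def moving_average(values, window=5):
    smoothed = []
    for i in range(len(values)):
        lo = max(0, i - window + 1)
        smoothed.append(sum(values[lo:i + 1]) / (i + 1 - lo))
    return smoothed

losses = [1.0, 0.5, 1.5, 0.75, 0.25]
print(moving_average(losses, window=3))
```

You would plot moving_average(val_losses) alongside the raw curve when inspecting convergence.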

7. Summary

In this post, we covered in detail how to build and train a two-layer LSTM model in PyTorch to predict diabetes. The key points were:

  1. Data preparation and loading

  2. LSTM model architecture design

  3. Training and validation

  4. Model evaluation and visualization

Although LSTMs are typically used for time-series data, here we applied one to non-sequential tabular data, which illustrates PyTorch's flexibility. Tuning the model architecture, hyperparameters, and data preprocessing could improve performance further.

I hope this post helps you understand how to implement an LSTM model in PyTorch!

8. Training Log and Download

The training log:

Epoch 1/300, Train Loss: 0.7193, Train Acc: 0.3558, Val Loss: 0.7248, Val Acc: 0.3092
Epoch 2/300, Train Loss: 0.7165, Train Acc: 0.3558, Val Loss: 0.7203, Val Acc: 0.3092
Epoch 3/300, Train Loss: 0.7121, Train Acc: 0.3558, Val Loss: 0.7158, Val Acc: 0.3092
Epoch 4/300, Train Loss: 0.7087, Train Acc: 0.3558, Val Loss: 0.7108, Val Acc: 0.3092
Epoch 5/300, Train Loss: 0.7042, Train Acc: 0.3558, Val Loss: 0.7053, Val Acc: 0.3092
Epoch 6/300, Train Loss: 0.7003, Train Acc: 0.3624, Val Loss: 0.6989, Val Acc: 0.3092
Epoch 7/300, Train Loss: 0.6951, Train Acc: 0.4498, Val Loss: 0.6920, Val Acc: 0.5066
Epoch 8/300, Train Loss: 0.6881, Train Acc: 0.6277, Val Loss: 0.6837, Val Acc: 0.7500
Epoch 9/300, Train Loss: 0.6823, Train Acc: 0.6590, Val Loss: 0.6734, Val Acc: 0.7105
Epoch 10/300, Train Loss: 0.6737, Train Acc: 0.6557, Val Loss: 0.6622, Val Acc: 0.6908
Epoch 11/300, Train Loss: 0.6653, Train Acc: 0.6491, Val Loss: 0.6496, Val Acc: 0.6974
Epoch 12/300, Train Loss: 0.6566, Train Acc: 0.6409, Val Loss: 0.6357, Val Acc: 0.6974
Epoch 13/300, Train Loss: 0.6457, Train Acc: 0.6458, Val Loss: 0.6215, Val Acc: 0.6908
Epoch 14/300, Train Loss: 0.6379, Train Acc: 0.6425, Val Loss: 0.6075, Val Acc: 0.6908
Epoch 15/300, Train Loss: 0.6306, Train Acc: 0.6425, Val Loss: 0.5973, Val Acc: 0.6908
Epoch 16/300, Train Loss: 0.6248, Train Acc: 0.6425, Val Loss: 0.5870, Val Acc: 0.6908
Epoch 17/300, Train Loss: 0.6203, Train Acc: 0.6442, Val Loss: 0.5778, Val Acc: 0.6908
Epoch 18/300, Train Loss: 0.6123, Train Acc: 0.6442, Val Loss: 0.5709, Val Acc: 0.6974
Epoch 19/300, Train Loss: 0.6142, Train Acc: 0.6425, Val Loss: 0.5648, Val Acc: 0.6974
Epoch 20/300, Train Loss: 0.6046, Train Acc: 0.6425, Val Loss: 0.5597, Val Acc: 0.6974
Epoch 21/300, Train Loss: 0.5988, Train Acc: 0.6425, Val Loss: 0.5547, Val Acc: 0.6974
Epoch 22/300, Train Loss: 0.5989, Train Acc: 0.6442, Val Loss: 0.5497, Val Acc: 0.6974
Epoch 23/300, Train Loss: 0.5993, Train Acc: 0.6392, Val Loss: 0.5454, Val Acc: 0.6974
Epoch 24/300, Train Loss: 0.5930, Train Acc: 0.6409, Val Loss: 0.5406, Val Acc: 0.7039
Epoch 25/300, Train Loss: 0.5872, Train Acc: 0.6392, Val Loss: 0.5362, Val Acc: 0.6974
Epoch 26/300, Train Loss: 0.5859, Train Acc: 0.6425, Val Loss: 0.5327, Val Acc: 0.6974
Epoch 27/300, Train Loss: 0.5859, Train Acc: 0.6442, Val Loss: 0.5285, Val Acc: 0.7039
Epoch 28/300, Train Loss: 0.5796, Train Acc: 0.6458, Val Loss: 0.5244, Val Acc: 0.7105
Epoch 29/300, Train Loss: 0.5778, Train Acc: 0.6524, Val Loss: 0.5212, Val Acc: 0.7171
Epoch 30/300, Train Loss: 0.5727, Train Acc: 0.6573, Val Loss: 0.5170, Val Acc: 0.7303
Epoch 31/300, Train Loss: 0.5682, Train Acc: 0.6623, Val Loss: 0.5122, Val Acc: 0.7434
Epoch 32/300, Train Loss: 0.5695, Train Acc: 0.6689, Val Loss: 0.5075, Val Acc: 0.7434
Epoch 33/300, Train Loss: 0.5667, Train Acc: 0.6771, Val Loss: 0.5044, Val Acc: 0.7566
Epoch 34/300, Train Loss: 0.5592, Train Acc: 0.6870, Val Loss: 0.4993, Val Acc: 0.7566
Epoch 35/300, Train Loss: 0.5555, Train Acc: 0.6903, Val Loss: 0.4958, Val Acc: 0.7632
Epoch 36/300, Train Loss: 0.5513, Train Acc: 0.7051, Val Loss: 0.4914, Val Acc: 0.7763
Epoch 37/300, Train Loss: 0.5483, Train Acc: 0.7035, Val Loss: 0.4870, Val Acc: 0.7829
Epoch 38/300, Train Loss: 0.5484, Train Acc: 0.7068, Val Loss: 0.4828, Val Acc: 0.7829
Epoch 39/300, Train Loss: 0.5436, Train Acc: 0.7216, Val Loss: 0.4794, Val Acc: 0.7961
Epoch 40/300, Train Loss: 0.5420, Train Acc: 0.7282, Val Loss: 0.4767, Val Acc: 0.8092
Epoch 41/300, Train Loss: 0.5353, Train Acc: 0.7216, Val Loss: 0.4727, Val Acc: 0.8289
Epoch 42/300, Train Loss: 0.5284, Train Acc: 0.7463, Val Loss: 0.4680, Val Acc: 0.8289
Epoch 43/300, Train Loss: 0.5287, Train Acc: 0.7463, Val Loss: 0.4651, Val Acc: 0.8158
Epoch 44/300, Train Loss: 0.5268, Train Acc: 0.7496, Val Loss: 0.4626, Val Acc: 0.8158
Epoch 45/300, Train Loss: 0.5204, Train Acc: 0.7529, Val Loss: 0.4592, Val Acc: 0.8158
Epoch 46/300, Train Loss: 0.5176, Train Acc: 0.7512, Val Loss: 0.4553, Val Acc: 0.8158
Epoch 47/300, Train Loss: 0.5191, Train Acc: 0.7562, Val Loss: 0.4510, Val Acc: 0.8158
Epoch 48/300, Train Loss: 0.5202, Train Acc: 0.7545, Val Loss: 0.4492, Val Acc: 0.8158
Epoch 49/300, Train Loss: 0.5073, Train Acc: 0.7611, Val Loss: 0.4473, Val Acc: 0.8158
Epoch 50/300, Train Loss: 0.5062, Train Acc: 0.7661, Val Loss: 0.4447, Val Acc: 0.8224
Epoch 51/300, Train Loss: 0.5083, Train Acc: 0.7661, Val Loss: 0.4426, Val Acc: 0.8289
Epoch 52/300, Train Loss: 0.5080, Train Acc: 0.7578, Val Loss: 0.4405, Val Acc: 0.8289
Epoch 53/300, Train Loss: 0.5068, Train Acc: 0.7595, Val Loss: 0.4389, Val Acc: 0.8092
Epoch 54/300, Train Loss: 0.4990, Train Acc: 0.7595, Val Loss: 0.4359, Val Acc: 0.8092
Epoch 55/300, Train Loss: 0.5007, Train Acc: 0.7578, Val Loss: 0.4346, Val Acc: 0.8092
Epoch 56/300, Train Loss: 0.5052, Train Acc: 0.7545, Val Loss: 0.4325, Val Acc: 0.8092
Epoch 57/300, Train Loss: 0.5023, Train Acc: 0.7562, Val Loss: 0.4327, Val Acc: 0.8026
Epoch 58/300, Train Loss: 0.4969, Train Acc: 0.7578, Val Loss: 0.4329, Val Acc: 0.7961
Epoch 59/300, Train Loss: 0.4955, Train Acc: 0.7562, Val Loss: 0.4284, Val Acc: 0.8026
Epoch 60/300, Train Loss: 0.4971, Train Acc: 0.7595, Val Loss: 0.4291, Val Acc: 0.7961
Epoch 61/300, Train Loss: 0.4928, Train Acc: 0.7545, Val Loss: 0.4271, Val Acc: 0.7961
Epoch 62/300, Train Loss: 0.4902, Train Acc: 0.7578, Val Loss: 0.4258, Val Acc: 0.7961
Epoch 63/300, Train Loss: 0.4909, Train Acc: 0.7463, Val Loss: 0.4241, Val Acc: 0.7961
Epoch 64/300, Train Loss: 0.4970, Train Acc: 0.7595, Val Loss: 0.4229, Val Acc: 0.7961
Epoch 65/300, Train Loss: 0.4892, Train Acc: 0.7595, Val Loss: 0.4234, Val Acc: 0.7961
Epoch 66/300, Train Loss: 0.4914, Train Acc: 0.7545, Val Loss: 0.4234, Val Acc: 0.7961
Epoch 67/300, Train Loss: 0.4937, Train Acc: 0.7628, Val Loss: 0.4232, Val Acc: 0.7961
Epoch 68/300, Train Loss: 0.4887, Train Acc: 0.7562, Val Loss: 0.4225, Val Acc: 0.7961
Epoch 69/300, Train Loss: 0.4890, Train Acc: 0.7562, Val Loss: 0.4214, Val Acc: 0.7961
Epoch 70/300, Train Loss: 0.4868, Train Acc: 0.7479, Val Loss: 0.4208, Val Acc: 0.7961
Epoch 71/300, Train Loss: 0.4883, Train Acc: 0.7529, Val Loss: 0.4197, Val Acc: 0.7961
Epoch 72/300, Train Loss: 0.4917, Train Acc: 0.7545, Val Loss: 0.4198, Val Acc: 0.7961
Epoch 73/300, Train Loss: 0.4849, Train Acc: 0.7628, Val Loss: 0.4182, Val Acc: 0.7961
Epoch 74/300, Train Loss: 0.4903, Train Acc: 0.7529, Val Loss: 0.4190, Val Acc: 0.7961
Epoch 75/300, Train Loss: 0.4965, Train Acc: 0.7562, Val Loss: 0.4196, Val Acc: 0.7961
Epoch 76/300, Train Loss: 0.4906, Train Acc: 0.7545, Val Loss: 0.4198, Val Acc: 0.7961
Epoch 77/300, Train Loss: 0.4893, Train Acc: 0.7529, Val Loss: 0.4189, Val Acc: 0.7961
Epoch 78/300, Train Loss: 0.4907, Train Acc: 0.7562, Val Loss: 0.4173, Val Acc: 0.7961
Epoch 79/300, Train Loss: 0.4828, Train Acc: 0.7496, Val Loss: 0.4168, Val Acc: 0.7961
Epoch 80/300, Train Loss: 0.4855, Train Acc: 0.7661, Val Loss: 0.4162, Val Acc: 0.8026
Epoch 81/300, Train Loss: 0.4880, Train Acc: 0.7578, Val Loss: 0.4169, Val Acc: 0.8026
Epoch 82/300, Train Loss: 0.4967, Train Acc: 0.7545, Val Loss: 0.4180, Val Acc: 0.7895
Epoch 83/300, Train Loss: 0.4864, Train Acc: 0.7578, Val Loss: 0.4187, Val Acc: 0.7829
Epoch 84/300, Train Loss: 0.4914, Train Acc: 0.7545, Val Loss: 0.4167, Val Acc: 0.7961
Epoch 85/300, Train Loss: 0.4818, Train Acc: 0.7595, Val Loss: 0.4154, Val Acc: 0.8026
Epoch 86/300, Train Loss: 0.4943, Train Acc: 0.7562, Val Loss: 0.4159, Val Acc: 0.8026
Epoch 87/300, Train Loss: 0.4830, Train Acc: 0.7595, Val Loss: 0.4165, Val Acc: 0.7961
Epoch 88/300, Train Loss: 0.4845, Train Acc: 0.7628, Val Loss: 0.4162, Val Acc: 0.7961
Epoch 89/300, Train Loss: 0.4790, Train Acc: 0.7611, Val Loss: 0.4163, Val Acc: 0.7961
Epoch 90/300, Train Loss: 0.4856, Train Acc: 0.7512, Val Loss: 0.4170, Val Acc: 0.7895
Epoch 91/300, Train Loss: 0.4853, Train Acc: 0.7562, Val Loss: 0.4151, Val Acc: 0.7961
Epoch 92/300, Train Loss: 0.4827, Train Acc: 0.7545, Val Loss: 0.4153, Val Acc: 0.7961
Epoch 93/300, Train Loss: 0.4887, Train Acc: 0.7661, Val Loss: 0.4175, Val Acc: 0.7895
Epoch 94/300, Train Loss: 0.4933, Train Acc: 0.7479, Val Loss: 0.4171, Val Acc: 0.7895
Epoch 95/300, Train Loss: 0.4836, Train Acc: 0.7545, Val Loss: 0.4171, Val Acc: 0.7895
Epoch 96/300, Train Loss: 0.4789, Train Acc: 0.7611, Val Loss: 0.4164, Val Acc: 0.7895
Epoch 97/300, Train Loss: 0.4831, Train Acc: 0.7529, Val Loss: 0.4159, Val Acc: 0.7895
Epoch 98/300, Train Loss: 0.4867, Train Acc: 0.7595, Val Loss: 0.4149, Val Acc: 0.7895
Epoch 99/300, Train Loss: 0.4818, Train Acc: 0.7595, Val Loss: 0.4154, Val Acc: 0.7895
Epoch 100/300, Train Loss: 0.4872, Train Acc: 0.7562, Val Loss: 0.4147, Val Acc: 0.7895
Epoch 101/300, Train Loss: 0.4828, Train Acc: 0.7529, Val Loss: 0.4158, Val Acc: 0.7895
Epoch 102/300, Train Loss: 0.4853, Train Acc: 0.7578, Val Loss: 0.4163, Val Acc: 0.7895
Epoch 103/300, Train Loss: 0.4844, Train Acc: 0.7628, Val Loss: 0.4170, Val Acc: 0.7829
Epoch 104/300, Train Loss: 0.4896, Train Acc: 0.7578, Val Loss: 0.4147, Val Acc: 0.7895
Epoch 105/300, Train Loss: 0.4853, Train Acc: 0.7562, Val Loss: 0.4162, Val Acc: 0.7895
Epoch 106/300, Train Loss: 0.4846, Train Acc: 0.7529, Val Loss: 0.4152, Val Acc: 0.7895
Epoch 107/300, Train Loss: 0.4832, Train Acc: 0.7562, Val Loss: 0.4159, Val Acc: 0.7829
Epoch 108/300, Train Loss: 0.4911, Train Acc: 0.7496, Val Loss: 0.4157, Val Acc: 0.7895
Epoch 109/300, Train Loss: 0.4808, Train Acc: 0.7496, Val Loss: 0.4163, Val Acc: 0.7829
Epoch 110/300, Train Loss: 0.4901, Train Acc: 0.7496, Val Loss: 0.4169, Val Acc: 0.7829
Epoch 111/300, Train Loss: 0.4832, Train Acc: 0.7529, Val Loss: 0.4154, Val Acc: 0.7829
Epoch 112/300, Train Loss: 0.4860, Train Acc: 0.7545, Val Loss: 0.4162, Val Acc: 0.7829
Epoch 113/300, Train Loss: 0.4828, Train Acc: 0.7611, Val Loss: 0.4156, Val Acc: 0.7829
Epoch 114/300, Train Loss: 0.4889, Train Acc: 0.7496, Val Loss: 0.4161, Val Acc: 0.7829
Epoch 115/300, Train Loss: 0.4863, Train Acc: 0.7496, Val Loss: 0.4150, Val Acc: 0.7829
Epoch 116/300, Train Loss: 0.4822, Train Acc: 0.7529, Val Loss: 0.4145, Val Acc: 0.7895
Epoch 117/300, Train Loss: 0.4790, Train Acc: 0.7562, Val Loss: 0.4148, Val Acc: 0.7829
Epoch 118/300, Train Loss: 0.4818, Train Acc: 0.7578, Val Loss: 0.4140, Val Acc: 0.7895
Epoch 119/300, Train Loss: 0.4840, Train Acc: 0.7529, Val Loss: 0.4152, Val Acc: 0.7829
Epoch 120/300, Train Loss: 0.4824, Train Acc: 0.7562, Val Loss: 0.4129, Val Acc: 0.7895
Epoch 121/300, Train Loss: 0.4890, Train Acc: 0.7512, Val Loss: 0.4136, Val Acc: 0.7895
Epoch 122/300, Train Loss: 0.4800, Train Acc: 0.7578, Val Loss: 0.4153, Val Acc: 0.7829
Epoch 123/300, Train Loss: 0.4896, Train Acc: 0.7562, Val Loss: 0.4158, Val Acc: 0.7829
Epoch 124/300, Train Loss: 0.4854, Train Acc: 0.7479, Val Loss: 0.4172, Val Acc: 0.7763
Epoch 125/300, Train Loss: 0.4822, Train Acc: 0.7578, Val Loss: 0.4158, Val Acc: 0.7829
Epoch 126/300, Train Loss: 0.4803, Train Acc: 0.7595, Val Loss: 0.4135, Val Acc: 0.7895
Epoch 127/300, Train Loss: 0.4859, Train Acc: 0.7496, Val Loss: 0.4142, Val Acc: 0.7829
Epoch 128/300, Train Loss: 0.4883, Train Acc: 0.7529, Val Loss: 0.4159, Val Acc: 0.7763
Epoch 129/300, Train Loss: 0.4854, Train Acc: 0.7545, Val Loss: 0.4165, Val Acc: 0.7763
Epoch 130/300, Train Loss: 0.4857, Train Acc: 0.7545, Val Loss: 0.4152, Val Acc: 0.7829
Epoch 131/300, Train Loss: 0.4758, Train Acc: 0.7562, Val Loss: 0.4143, Val Acc: 0.7829
Epoch 132/300, Train Loss: 0.4886, Train Acc: 0.7512, Val Loss: 0.4153, Val Acc: 0.7763
Epoch 133/300, Train Loss: 0.4854, Train Acc: 0.7463, Val Loss: 0.4144, Val Acc: 0.7829
Epoch 134/300, Train Loss: 0.4834, Train Acc: 0.7595, Val Loss: 0.4149, Val Acc: 0.7763
Epoch 135/300, Train Loss: 0.4779, Train Acc: 0.7545, Val Loss: 0.4147, Val Acc: 0.7763
Epoch 136/300, Train Loss: 0.4836, Train Acc: 0.7496, Val Loss: 0.4149, Val Acc: 0.7763
Epoch 137/300, Train Loss: 0.4798, Train Acc: 0.7562, Val Loss: 0.4140, Val Acc: 0.7829
Epoch 138/300, Train Loss: 0.4856, Train Acc: 0.7529, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 139/300, Train Loss: 0.4842, Train Acc: 0.7611, Val Loss: 0.4138, Val Acc: 0.7829
Epoch 140/300, Train Loss: 0.4772, Train Acc: 0.7578, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 141/300, Train Loss: 0.4861, Train Acc: 0.7496, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 142/300, Train Loss: 0.4779, Train Acc: 0.7578, Val Loss: 0.4154, Val Acc: 0.7763
Epoch 143/300, Train Loss: 0.4779, Train Acc: 0.7512, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 144/300, Train Loss: 0.4829, Train Acc: 0.7644, Val Loss: 0.4138, Val Acc: 0.7763
Epoch 145/300, Train Loss: 0.4801, Train Acc: 0.7628, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 146/300, Train Loss: 0.4842, Train Acc: 0.7496, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 147/300, Train Loss: 0.4845, Train Acc: 0.7529, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 148/300, Train Loss: 0.4775, Train Acc: 0.7595, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 149/300, Train Loss: 0.4805, Train Acc: 0.7446, Val Loss: 0.4130, Val Acc: 0.7895
Epoch 150/300, Train Loss: 0.4838, Train Acc: 0.7562, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 151/300, Train Loss: 0.4900, Train Acc: 0.7562, Val Loss: 0.4151, Val Acc: 0.7763
Epoch 152/300, Train Loss: 0.4791, Train Acc: 0.7463, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 153/300, Train Loss: 0.4792, Train Acc: 0.7545, Val Loss: 0.4147, Val Acc: 0.7763
Epoch 154/300, Train Loss: 0.4814, Train Acc: 0.7512, Val Loss: 0.4152, Val Acc: 0.7763
Epoch 155/300, Train Loss: 0.4736, Train Acc: 0.7529, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 156/300, Train Loss: 0.4852, Train Acc: 0.7611, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 157/300, Train Loss: 0.4828, Train Acc: 0.7595, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 158/300, Train Loss: 0.4798, Train Acc: 0.7545, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 159/300, Train Loss: 0.4832, Train Acc: 0.7512, Val Loss: 0.4150, Val Acc: 0.7763
Epoch 160/300, Train Loss: 0.4789, Train Acc: 0.7512, Val Loss: 0.4150, Val Acc: 0.7763
Epoch 161/300, Train Loss: 0.4806, Train Acc: 0.7479, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 162/300, Train Loss: 0.4835, Train Acc: 0.7595, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 163/300, Train Loss: 0.4796, Train Acc: 0.7479, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 164/300, Train Loss: 0.4821, Train Acc: 0.7529, Val Loss: 0.4158, Val Acc: 0.7697
Epoch 165/300, Train Loss: 0.4828, Train Acc: 0.7545, Val Loss: 0.4133, Val Acc: 0.7829
Epoch 166/300, Train Loss: 0.4878, Train Acc: 0.7512, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 167/300, Train Loss: 0.4854, Train Acc: 0.7463, Val Loss: 0.4167, Val Acc: 0.7697
Epoch 168/300, Train Loss: 0.4875, Train Acc: 0.7479, Val Loss: 0.4152, Val Acc: 0.7763
Epoch 169/300, Train Loss: 0.4864, Train Acc: 0.7479, Val Loss: 0.4150, Val Acc: 0.7763
Epoch 170/300, Train Loss: 0.4763, Train Acc: 0.7529, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 171/300, Train Loss: 0.4843, Train Acc: 0.7446, Val Loss: 0.4154, Val Acc: 0.7763
Epoch 172/300, Train Loss: 0.4769, Train Acc: 0.7545, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 173/300, Train Loss: 0.4846, Train Acc: 0.7595, Val Loss: 0.4155, Val Acc: 0.7697
Epoch 174/300, Train Loss: 0.4831, Train Acc: 0.7512, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 175/300, Train Loss: 0.4922, Train Acc: 0.7496, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 176/300, Train Loss: 0.4826, Train Acc: 0.7479, Val Loss: 0.4161, Val Acc: 0.7697
Epoch 177/300, Train Loss: 0.4793, Train Acc: 0.7611, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 178/300, Train Loss: 0.4768, Train Acc: 0.7644, Val Loss: 0.4134, Val Acc: 0.7829
Epoch 179/300, Train Loss: 0.4837, Train Acc: 0.7562, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 180/300, Train Loss: 0.4831, Train Acc: 0.7496, Val Loss: 0.4136, Val Acc: 0.7829
Epoch 181/300, Train Loss: 0.4824, Train Acc: 0.7562, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 182/300, Train Loss: 0.4786, Train Acc: 0.7562, Val Loss: 0.4133, Val Acc: 0.7829
Epoch 183/300, Train Loss: 0.4826, Train Acc: 0.7628, Val Loss: 0.4139, Val Acc: 0.7829
Epoch 184/300, Train Loss: 0.4858, Train Acc: 0.7545, Val Loss: 0.4160, Val Acc: 0.7697
Epoch 185/300, Train Loss: 0.4847, Train Acc: 0.7529, Val Loss: 0.4142, Val Acc: 0.7829
Epoch 186/300, Train Loss: 0.4776, Train Acc: 0.7496, Val Loss: 0.4148, Val Acc: 0.7763
Epoch 187/300, Train Loss: 0.4846, Train Acc: 0.7562, Val Loss: 0.4147, Val Acc: 0.7763
Epoch 188/300, Train Loss: 0.4744, Train Acc: 0.7529, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 189/300, Train Loss: 0.4845, Train Acc: 0.7545, Val Loss: 0.4147, Val Acc: 0.7763
Epoch 190/300, Train Loss: 0.4802, Train Acc: 0.7512, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 191/300, Train Loss: 0.4831, Train Acc: 0.7496, Val Loss: 0.4149, Val Acc: 0.7697
Epoch 192/300, Train Loss: 0.4767, Train Acc: 0.7496, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 193/300, Train Loss: 0.4786, Train Acc: 0.7512, Val Loss: 0.4139, Val Acc: 0.7829
Epoch 194/300, Train Loss: 0.4810, Train Acc: 0.7545, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 195/300, Train Loss: 0.4748, Train Acc: 0.7512, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 196/300, Train Loss: 0.4750, Train Acc: 0.7479, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 197/300, Train Loss: 0.4801, Train Acc: 0.7512, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 198/300, Train Loss: 0.4819, Train Acc: 0.7545, Val Loss: 0.4129, Val Acc: 0.7829
Epoch 199/300, Train Loss: 0.4840, Train Acc: 0.7496, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 200/300, Train Loss: 0.4786, Train Acc: 0.7611, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 201/300, Train Loss: 0.4813, Train Acc: 0.7446, Val Loss: 0.4157, Val Acc: 0.7697
Epoch 202/300, Train Loss: 0.4858, Train Acc: 0.7545, Val Loss: 0.4145, Val Acc: 0.7697
Epoch 203/300, Train Loss: 0.4814, Train Acc: 0.7562, Val Loss: 0.4150, Val Acc: 0.7697
Epoch 204/300, Train Loss: 0.4797, Train Acc: 0.7611, Val Loss: 0.4134, Val Acc: 0.7829
Epoch 205/300, Train Loss: 0.4863, Train Acc: 0.7628, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 206/300, Train Loss: 0.4813, Train Acc: 0.7545, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 207/300, Train Loss: 0.4817, Train Acc: 0.7545, Val Loss: 0.4138, Val Acc: 0.7829
Epoch 208/300, Train Loss: 0.4877, Train Acc: 0.7661, Val Loss: 0.4140, Val Acc: 0.7829
Epoch 209/300, Train Loss: 0.4787, Train Acc: 0.7578, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 210/300, Train Loss: 0.4836, Train Acc: 0.7430, Val Loss: 0.4145, Val Acc: 0.7697
Epoch 211/300, Train Loss: 0.4743, Train Acc: 0.7578, Val Loss: 0.4139, Val Acc: 0.7829
Epoch 212/300, Train Loss: 0.4795, Train Acc: 0.7529, Val Loss: 0.4141, Val Acc: 0.7829
Epoch 213/300, Train Loss: 0.4821, Train Acc: 0.7512, Val Loss: 0.4139, Val Acc: 0.7829
Epoch 214/300, Train Loss: 0.4805, Train Acc: 0.7545, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 215/300, Train Loss: 0.4807, Train Acc: 0.7529, Val Loss: 0.4150, Val Acc: 0.7697
Epoch 216/300, Train Loss: 0.4793, Train Acc: 0.7578, Val Loss: 0.4134, Val Acc: 0.7829
Epoch 217/300, Train Loss: 0.4816, Train Acc: 0.7479, Val Loss: 0.4148, Val Acc: 0.7697
Epoch 218/300, Train Loss: 0.4831, Train Acc: 0.7479, Val Loss: 0.4124, Val Acc: 0.7829
Epoch 219/300, Train Loss: 0.4714, Train Acc: 0.7529, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 220/300, Train Loss: 0.4795, Train Acc: 0.7479, Val Loss: 0.4130, Val Acc: 0.7829
Epoch 221/300, Train Loss: 0.4822, Train Acc: 0.7611, Val Loss: 0.4124, Val Acc: 0.7829
Epoch 222/300, Train Loss: 0.4892, Train Acc: 0.7529, Val Loss: 0.4135, Val Acc: 0.7829
Epoch 223/300, Train Loss: 0.4810, Train Acc: 0.7595, Val Loss: 0.4133, Val Acc: 0.7829
Epoch 224/300, Train Loss: 0.4809, Train Acc: 0.7529, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 225/300, Train Loss: 0.4789, Train Acc: 0.7512, Val Loss: 0.4135, Val Acc: 0.7763
Epoch 226/300, Train Loss: 0.4805, Train Acc: 0.7545, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 227/300, Train Loss: 0.4748, Train Acc: 0.7512, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 228/300, Train Loss: 0.4811, Train Acc: 0.7529, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 229/300, Train Loss: 0.4780, Train Acc: 0.7562, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 230/300, Train Loss: 0.4851, Train Acc: 0.7595, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 231/300, Train Loss: 0.4823, Train Acc: 0.7479, Val Loss: 0.4130, Val Acc: 0.7829
Epoch 232/300, Train Loss: 0.4782, Train Acc: 0.7512, Val Loss: 0.4135, Val Acc: 0.7829
Epoch 233/300, Train Loss: 0.4785, Train Acc: 0.7512, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 234/300, Train Loss: 0.4799, Train Acc: 0.7578, Val Loss: 0.4150, Val Acc: 0.7697
Epoch 235/300, Train Loss: 0.4798, Train Acc: 0.7545, Val Loss: 0.4138, Val Acc: 0.7763
Epoch 236/300, Train Loss: 0.4818, Train Acc: 0.7529, Val Loss: 0.4151, Val Acc: 0.7697
Epoch 237/300, Train Loss: 0.4784, Train Acc: 0.7562, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 238/300, Train Loss: 0.4760, Train Acc: 0.7529, Val Loss: 0.4119, Val Acc: 0.7829
Epoch 239/300, Train Loss: 0.4781, Train Acc: 0.7529, Val Loss: 0.4118, Val Acc: 0.7829
Epoch 240/300, Train Loss: 0.4797, Train Acc: 0.7545, Val Loss: 0.4120, Val Acc: 0.7829
Epoch 241/300, Train Loss: 0.4793, Train Acc: 0.7578, Val Loss: 0.4133, Val Acc: 0.7829
Epoch 242/300, Train Loss: 0.4825, Train Acc: 0.7545, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 243/300, Train Loss: 0.4781, Train Acc: 0.7479, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 244/300, Train Loss: 0.4802, Train Acc: 0.7512, Val Loss: 0.4133, Val Acc: 0.7763
Epoch 245/300, Train Loss: 0.4830, Train Acc: 0.7479, Val Loss: 0.4124, Val Acc: 0.7829
Epoch 246/300, Train Loss: 0.4844, Train Acc: 0.7578, Val Loss: 0.4135, Val Acc: 0.7763
Epoch 247/300, Train Loss: 0.4757, Train Acc: 0.7496, Val Loss: 0.4128, Val Acc: 0.7829
Epoch 248/300, Train Loss: 0.4774, Train Acc: 0.7611, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 249/300, Train Loss: 0.4850, Train Acc: 0.7479, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 250/300, Train Loss: 0.4811, Train Acc: 0.7479, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 251/300, Train Loss: 0.4812, Train Acc: 0.7545, Val Loss: 0.4137, Val Acc: 0.7763
Epoch 252/300, Train Loss: 0.4827, Train Acc: 0.7512, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 253/300, Train Loss: 0.4768, Train Acc: 0.7578, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 254/300, Train Loss: 0.4792, Train Acc: 0.7644, Val Loss: 0.4130, Val Acc: 0.7829
Epoch 255/300, Train Loss: 0.4812, Train Acc: 0.7545, Val Loss: 0.4134, Val Acc: 0.7763
Epoch 256/300, Train Loss: 0.4768, Train Acc: 0.7529, Val Loss: 0.4131, Val Acc: 0.7763
Epoch 257/300, Train Loss: 0.4785, Train Acc: 0.7595, Val Loss: 0.4128, Val Acc: 0.7829
Epoch 258/300, Train Loss: 0.4817, Train Acc: 0.7578, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 259/300, Train Loss: 0.4809, Train Acc: 0.7512, Val Loss: 0.4139, Val Acc: 0.7763
Epoch 260/300, Train Loss: 0.4777, Train Acc: 0.7529, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 261/300, Train Loss: 0.4823, Train Acc: 0.7479, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 262/300, Train Loss: 0.4783, Train Acc: 0.7578, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 263/300, Train Loss: 0.4813, Train Acc: 0.7512, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 264/300, Train Loss: 0.4797, Train Acc: 0.7611, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 265/300, Train Loss: 0.4751, Train Acc: 0.7562, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 266/300, Train Loss: 0.4771, Train Acc: 0.7545, Val Loss: 0.4139, Val Acc: 0.7763
Epoch 267/300, Train Loss: 0.4809, Train Acc: 0.7512, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 268/300, Train Loss: 0.4726, Train Acc: 0.7562, Val Loss: 0.4129, Val Acc: 0.7829
Epoch 269/300, Train Loss: 0.4746, Train Acc: 0.7529, Val Loss: 0.4137, Val Acc: 0.7763
Epoch 270/300, Train Loss: 0.4800, Train Acc: 0.7463, Val Loss: 0.4129, Val Acc: 0.7763
Epoch 271/300, Train Loss: 0.4810, Train Acc: 0.7562, Val Loss: 0.4129, Val Acc: 0.7763
Epoch 272/300, Train Loss: 0.4780, Train Acc: 0.7479, Val Loss: 0.4130, Val Acc: 0.7829
Epoch 273/300, Train Loss: 0.4790, Train Acc: 0.7562, Val Loss: 0.4129, Val Acc: 0.7763
Epoch 274/300, Train Loss: 0.4825, Train Acc: 0.7545, Val Loss: 0.4129, Val Acc: 0.7763
Epoch 275/300, Train Loss: 0.4742, Train Acc: 0.7512, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 276/300, Train Loss: 0.4882, Train Acc: 0.7381, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 277/300, Train Loss: 0.4838, Train Acc: 0.7562, Val Loss: 0.4151, Val Acc: 0.7697
Epoch 278/300, Train Loss: 0.4779, Train Acc: 0.7529, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 279/300, Train Loss: 0.4826, Train Acc: 0.7529, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 280/300, Train Loss: 0.4847, Train Acc: 0.7529, Val Loss: 0.4153, Val Acc: 0.7697
Epoch 281/300, Train Loss: 0.4811, Train Acc: 0.7545, Val Loss: 0.4163, Val Acc: 0.7697
Epoch 282/300, Train Loss: 0.4767, Train Acc: 0.7545, Val Loss: 0.4149, Val Acc: 0.7697
Epoch 283/300, Train Loss: 0.4808, Train Acc: 0.7512, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 284/300, Train Loss: 0.4775, Train Acc: 0.7578, Val Loss: 0.4131, Val Acc: 0.7763
Epoch 285/300, Train Loss: 0.4809, Train Acc: 0.7545, Val Loss: 0.4139, Val Acc: 0.7763
Epoch 286/300, Train Loss: 0.4757, Train Acc: 0.7529, Val Loss: 0.4149, Val Acc: 0.7763
Epoch 287/300, Train Loss: 0.4794, Train Acc: 0.7545, Val Loss: 0.4127, Val Acc: 0.7763
Epoch 288/300, Train Loss: 0.4823, Train Acc: 0.7479, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 289/300, Train Loss: 0.4789, Train Acc: 0.7578, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 290/300, Train Loss: 0.4755, Train Acc: 0.7430, Val Loss: 0.4138, Val Acc: 0.7763
Epoch 291/300, Train Loss: 0.4809, Train Acc: 0.7463, Val Loss: 0.4131, Val Acc: 0.7763
Epoch 292/300, Train Loss: 0.4834, Train Acc: 0.7595, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 293/300, Train Loss: 0.4812, Train Acc: 0.7479, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 294/300, Train Loss: 0.4816, Train Acc: 0.7529, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 295/300, Train Loss: 0.4773, Train Acc: 0.7545, Val Loss: 0.4130, Val Acc: 0.7763
Epoch 296/300, Train Loss: 0.4759, Train Acc: 0.7430, Val Loss: 0.4139, Val Acc: 0.7763
Epoch 297/300, Train Loss: 0.4806, Train Acc: 0.7545, Val Loss: 0.4131, Val Acc: 0.7763
Epoch 298/300, Train Loss: 0.4826, Train Acc: 0.7578, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 299/300, Train Loss: 0.4713, Train Acc: 0.7595, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 300/300, Train Loss: 0.4777, Train Acc: 0.7545, Val Loss: 0.4131, Val Acc: 0.7763
Test Accuracy: 0.7763

Training curves:

Complete code:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Load the data
data = pd.read_csv('diabetes.csv', header=None)
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values

# Split into training and test sets (80/20)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to PyTorch tensors - note: no unsqueeze(2) is needed here
X_train = torch.FloatTensor(X_train)  # shape: (num_samples, 8)
X_test = torch.FloatTensor(X_test)    # shape: (num_samples, 8)
y_train = torch.FloatTensor(y_train)
y_test = torch.FloatTensor(y_test)

# Create DataLoaders
train_data = TensorDataset(X_train, y_train)
test_data = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)


# Define the LSTM model
class LSTMModel(nn.Module):
    def __init__(self, input_size=8, hidden_size1=64, hidden_size2=32):
        super(LSTMModel, self).__init__()
        self.lstm1 = nn.LSTM(input_size, hidden_size1, batch_first=True)
        self.dropout1 = nn.Dropout(0.3)
        self.lstm2 = nn.LSTM(hidden_size1, hidden_size2, batch_first=True)
        self.dropout2 = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_size2, 1)

    def forward(self, x):
        # Add a sequence-length dimension: (batch_size, 8) -> (batch_size, 1, 8)
        x = x.unsqueeze(1)

        # First LSTM layer
        x, _ = self.lstm1(x)
        x = self.dropout1(x)

        # Second LSTM layer
        x, (hn, cn) = self.lstm2(x)
        x = self.dropout2(hn[-1])  # use the final hidden state of the last layer

        x = self.fc(x)
        return torch.sigmoid(x.squeeze())


# Initialize the model, loss, and optimizer
model = LSTMModel(input_size=8)  # 8 input features
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0002)

# Track training and validation metrics
train_losses = []
train_accs = []
val_losses = []
val_accs = []

# Train the model
epochs = 300
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(
        f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

# Evaluate the model on the full test set
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    predicted = (outputs > 0.5).float()
    accuracy = (predicted == y_test).float().mean()
print(f'Test Accuracy: {accuracy:.4f}')

# Plot training curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_accs, label='Training Accuracy')
plt.plot(val_accs, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
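A key detail in the listing above is the `x.unsqueeze(1)` call in `forward()`: with `batch_first=True`, `nn.LSTM` expects input of shape `(batch, seq_len, input_size)`, but each tabular sample here is a flat 8-feature vector, so it is treated as a sequence of length 1. A minimal standalone sketch of the resulting shapes:

```python
import torch
import torch.nn as nn

# nn.LSTM with batch_first=True expects (batch, seq_len, input_size)
lstm = nn.LSTM(input_size=8, hidden_size=64, batch_first=True)

x = torch.randn(32, 8)   # a batch of 32 tabular samples
x = x.unsqueeze(1)       # -> (32, 1, 8): sequence-length dimension added

out, (hn, cn) = lstm(x)
print(out.shape)  # torch.Size([32, 1, 64]) - outputs for each time step
print(hn.shape)   # torch.Size([1, 32, 64]) - final hidden state per layer
```

Because the sequence length is 1, taking `hn[-1]` in the model (shape `(batch, hidden_size)`) is equivalent to taking the output of the single time step.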

Download: LSTM-based diabetes classification project resources (CSDN文库)
