深度学习项目--基于RNN的阿尔茨海默病诊断研究(pytorch实现)

前言

  • 其实这个项目比较适合机器学习做,用XGBoost会更好,这个项目更适合RNN 学习案例,测试集准确率达到百分之84.2,效果还是算过得去,但是用其他模型会更好,机器学习的方法后面会更新
  • RNN讲解: 深度学习基础--一文搞懂RNN
  • 欢迎收藏 + 关注,本人将会持续更新

文章目录

1、导入数据

python 复制代码
import pandas as pd  
import numpy as np 
import matplotlib.pyplot as plt  
import seaborn as sns 
import torch  
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"]  # 显示中文
plt.rcParams['axes.unicode_minus'] = False		# 显示负号

data_df = pd.read_csv("alzheimers_disease_data.csv")

data_df.head()

| | PatientID | Age | Gender | Ethnicity | EducationLevel | BMI | Smoking | AlcoholConsumption | PhysicalActivity | DietQuality | ... | MemoryComplaints | BehavioralProblems | ADL | Confusion | Disorientation | PersonalityChanges | DifficultyCompletingTasks | Forgetfulness | Diagnosis | DoctorInCharge |
| 0 | 4751 | 73 | 0 | 0 | 2 | 22.927749 | 0 | 13.297218 | 6.327112 | 1.347214 | ... | 0 | 0 | 1.725883 | 0 | 0 | 0 | 1 | 0 | 0 | XXXConfid |
| 1 | 4752 | 89 | 0 | 0 | 0 | 26.827681 | 0 | 4.542524 | 7.619885 | 0.518767 | ... | 0 | 0 | 2.592424 | 0 | 0 | 0 | 0 | 1 | 0 | XXXConfid |
| 2 | 4753 | 73 | 0 | 3 | 1 | 17.795882 | 0 | 19.555085 | 7.844988 | 1.826335 | ... | 0 | 0 | 7.119548 | 0 | 1 | 0 | 1 | 0 | 0 | XXXConfid |
| 3 | 4754 | 74 | 1 | 0 | 1 | 33.800817 | 1 | 12.209266 | 8.428001 | 7.435604 | ... | 0 | 1 | 6.481226 | 0 | 0 | 0 | 0 | 0 | 0 | XXXConfid |

4 4755 89 0 0 0 20.716974 0 18.454356 6.310461 0.795498 ... 0 0 0.014691 0 0 1 1 0 0 XXXConfid

5 rows × 35 columns

该数据集是2149名被诊断患有阿尔茨海默病或有阿尔茨海默病风险的患者的健康记录的综合集合。数据集中的每个患者都有一个唯一的ID号,范围从4751到6900。该数据集涵盖了广泛的信息,这些信息对于理解与阿尔茨海默病相关的各种因素至关重要。它包括人口统计细节、生活习惯、病史、临床测量、认知和功能评估、症状和诊断信息。

python 复制代码
# 标签中文化
data_df.rename(columns={ "Age": "年龄", "Gender": "性别", "Ethnicity": "种族", "EducationLevel": "教育水平", "BMI": "身体质量指数(BMI)", "Smoking": "吸烟状况", "AlcoholConsumption": "酒精摄入量", "PhysicalActivity": "体育活动时间", "DietQuality": "饮食质量评分", "SleepQuality": "睡眠质量评分", "FamilyHistoryAlzheimers": "家族阿尔茨海默病史", "CardiovascularDisease": "心血管疾病", "Diabetes": "糖尿病", "Depression": "抑郁症史", "HeadInjury": "头部受伤", "Hypertension": "高血压", "SystolicBP": "收缩压", "DiastolicBP": "舒张压", "CholesterolTotal": "胆固醇总量", "CholesterolLDL": "低密度脂蛋白胆固醇(LDL)", "CholesterolHDL": "高密度脂蛋白胆固醇(HDL)", "CholesterolTriglycerides": "甘油三酯", "MMSE": "简易精神状态检查(MMSE)得分", "FunctionalAssessment": "功能评估得分", "MemoryComplaints": "记忆抱怨", "BehavioralProblems": "行为问题", "ADL": "日常生活活动(ADL)得分", "Confusion": "混乱与定向障碍", "Disorientation": "迷失方向", "PersonalityChanges": "人格变化", "DifficultyCompletingTasks": "完成任务困难", "Forgetfulness": "健忘", "Diagnosis": "诊断状态", "DoctorInCharge": "主诊医生" },inplace=True)

data_df.columns
Index(['PatientID', '年龄', '性别', '种族', '教育水平', '身体质量指数(BMI)', '吸烟状况', '酒精摄入量',
       '体育活动时间', '饮食质量评分', '睡眠质量评分', '家族阿尔茨海默病史', '心血管疾病', '糖尿病', '抑郁症史',
       '头部受伤', '高血压', '收缩压', '舒张压', '胆固醇总量', '低密度脂蛋白胆固醇(LDL)',
       '高密度脂蛋白胆固醇(HDL)', '甘油三酯', '简易精神状态检查(MMSE)得分', '功能评估得分', '记忆抱怨', '行为问题',
       '日常生活活动(ADL)得分', '混乱与定向障碍', '迷失方向', '人格变化', '完成任务困难', '健忘', '诊断状态',
       '主诊医生'],
      dtype='object')

2、数据处理

python 复制代码
data_df.isnull().sum()
PatientID           0
年龄                  0
性别                  0
种族                  0
教育水平                0
身体质量指数(BMI)         0
吸烟状况                0
酒精摄入量               0
体育活动时间              0
饮食质量评分              0
睡眠质量评分              0
家族阿尔茨海默病史           0
心血管疾病               0
糖尿病                 0
抑郁症史                0
头部受伤                0
高血压                 0
收缩压                 0
舒张压                 0
胆固醇总量               0
低密度脂蛋白胆固醇(LDL)      0
高密度脂蛋白胆固醇(HDL)      0
甘油三酯                0
简易精神状态检查(MMSE)得分    0
功能评估得分              0
记忆抱怨                0
行为问题                0
日常生活活动(ADL)得分       0
混乱与定向障碍             0
迷失方向                0
人格变化                0
完成任务困难              0
健忘                  0
诊断状态                0
主诊医生                0
dtype: int64
python 复制代码
from sklearn.preprocessing import LabelEncoder

# 创建 LabelEncoder 实例
label_encoder = LabelEncoder()

# 对非数值型列进行标签编码
data_df['主诊医生'] = label_encoder.fit_transform(data_df['主诊医生'])

data_df.head()

| | PatientID | 年龄 | 性别 | 种族 | 教育水平 | 身体质量指数(BMI) | 吸烟状况 | 酒精摄入量 | 体育活动时间 | 饮食质量评分 | ... | 记忆抱怨 | 行为问题 | 日常生活活动(ADL)得分 | 混乱与定向障碍 | 迷失方向 | 人格变化 | 完成任务困难 | 健忘 | 诊断状态 | 主诊医生 |
| 0 | 4751 | 73 | 0 | 0 | 2 | 22.927749 | 0 | 13.297218 | 6.327112 | 1.347214 | ... | 0 | 0 | 1.725883 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
| 1 | 4752 | 89 | 0 | 0 | 0 | 26.827681 | 0 | 4.542524 | 7.619885 | 0.518767 | ... | 0 | 0 | 2.592424 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 2 | 4753 | 73 | 0 | 3 | 1 | 17.795882 | 0 | 19.555085 | 7.844988 | 1.826335 | ... | 0 | 0 | 7.119548 | 0 | 1 | 0 | 1 | 0 | 0 | 0 |
| 3 | 4754 | 74 | 1 | 0 | 1 | 33.800817 | 1 | 12.209266 | 8.428001 | 7.435604 | ... | 0 | 1 | 6.481226 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |

4 4755 89 0 0 0 20.716974 0 18.454356 6.310461 0.795498 ... 0 0 0.014691 0 0 1 1 0 0 0

5 rows × 35 columns

1、患病占比

python 复制代码
# 计算是否患病, 人数
counts = data_df["诊断状态"].value_counts()

# 计算百分比
sizes = counts / counts.sum() * 100

# 绘制环形图
fig, ax = plt.subplots()
wedges, texts, autotexts = ax.pie(sizes, labels=sizes.index, autopct='%1.2ff%%', startangle=90, wedgeprops=dict(width=0.3))

plt.title("患病占比(1患病,0没有患病)")

plt.show()


患病人数居多

2、相关性分析

python 复制代码
plt.figure(figsize=(40, 35))
sns.heatmap(data_df.corr(), annot=True, fmt=".2f")
plt.show()


其中,与患病相关性比较强的有:MMSE得分、功能评估得分、记忆抱怨、行为问题等相关性比较强,其中,MMSE得分、功能评估得分为负相关,记忆抱怨、行为问题为正相关。

3、年龄与患病探究

python 复制代码
data_df['年龄'].min(), data_df['年龄'].max()
(60, 90)
python 复制代码
# 计算每一个年龄段患病人数 
age_bins = range(60, 91)
grouped = data_df.groupby('年龄').agg({'诊断状态': ['sum', 'size']})  # 分组、聚合函数: sum求和,size总大小
grouped.columns = ['患病', '总人数']
grouped['不患病'] = grouped['总人数'] - grouped['患病']  # 计算不患病的人数

# 设置绘图风格
sns.set(style="whitegrid")

plt.figure(figsize=(12, 5))

# 获取x轴标签(即年龄)
x = grouped.index.astype(str)  # 将年龄转换为字符串格式便于显示

# 画图
plt.bar(x, grouped["不患病"], 0.35, label="不患病", color='skyblue')
plt.bar(x, grouped["患病"], 0.35, label="患病", color='salmon')

# 设置标题
plt.title("患病年龄分布", fontproperties='Microsoft YaHei')
plt.xlabel("年龄", fontproperties='Microsoft YaHei')
plt.ylabel("人数", fontproperties='Microsoft YaHei')

# 如果需要对图例也应用相同的字体
plt.legend(prop={'family': 'Microsoft YaHei'})

# 展示
plt.tight_layout()
plt.show()


通过发现,由于原本数据中不患病的多,所以不患病的在图像中显示多,通过观察发现患病与年龄有关,尤其是年龄大,80岁的,患病与不患病比例高

提示:这里写代码的时候,不知道为什么,不指定字体,就显示不了字体。

3、特征选择

模型采用:决策树特征训练,可以很好的对特征重要性进行排序。

特征选择:采用REF,特征选择方法:

RFE(Recursive Feature Elimination,递归特征消除)和 SelectFromModel 都是 Scikit-learn 中用于特征选择的方法,但它们的工作机制和使用场景有所不同。

SelectFromModel

  • 工作原理SelectFromModel 是一种基于模型的特征选择方法。它通过一个基础评估器来判断每个特征的重要性,并根据给定的阈值选择那些重要性得分超过该阈值的特征。默认情况下,它会使用基础评估器提供的 feature_importances_ 或者 coef_ 属性来衡量特征的重要性。
  • 适用场景:当您希望基于某个预训练模型的特征重要性来进行特征选择时特别有用。它允许您设置一个全局阈值来控制特征选择的标准,但不直接支持指定想要选择的特征数量。
  • 优点:简单易用,适合快速进行特征筛选。
  • 缺点:不如 RFE 精细,不能直接控制最终选择的特征数量。

RFE (Recursive Feature Elimination)

  • 工作原理:RFE 采用了一种递归的方式进行特征选择。首先,它会训练一个模型,并根据模型对每个特征的重要性评分进行排序。然后,它会移除最不重要的特征,并重复这个过程,直到留下指定数量的特征为止。
  • 适用场景 :当您确切知道想要选择多少个特征时非常有用。它提供了比 SelectFromModel 更细致的控制,因为您可以直接指定要保留的特征数量。
  • 优点:可以精确控制最终选择的特征数量,并且在每一轮迭代中都能考虑到所有剩余特征的整体贡献。
  • 缺点:计算成本相对较高,因为它需要多次训练模型,特别是当数据集很大或模型复杂度很高时。

总结

  • 如果您的目标是基于某个预定义的重要性阈值来简化模型,那么 SelectFromModel 可能是更合适的选择。
  • 如果您希望直接控制最终选择的特征数量,并愿意接受更高的计算成本以获得更精细的控制,那么 RFE 可能更适合您的需求。

两种方法都有其独特的优势和适用场景,选择哪一种取决于您的具体应用需求、数据特性以及性能考虑。

python 复制代码
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

data = data_df.copy()

X = data_df.iloc[:, 1:-2]
y = data_df.iloc[:, -2]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 标准化
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

# 模型创建
tree = DecisionTreeClassifier()
tree.fit(X_train, y_train)
pred = tree.predict(X_test)

reporter = classification_report(y_test, pred)
print(reporter)
              precision    recall  f1-score   support

           0       0.91      0.92      0.91       277
           1       0.85      0.83      0.84       153

    accuracy                           0.89       430
   macro avg       0.88      0.88      0.88       430
weighted avg       0.89      0.89      0.89       430

效果不错,进行特征选择

python 复制代码
# 不知道为啥,这样也需要在设置
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"]  # 显示中文
plt.rcParams['axes.unicode_minus'] = False		# 显示负号

# 特征展示
feature_importances = tree.feature_importances_
features_rf = pd.DataFrame({'特征': X.columns, '重要度': feature_importances})
features_rf.sort_values(by='重要度', ascending=False, inplace=True)
plt.figure(figsize=(20, 10))
sns.barplot(x='重要度', y='特征', data=features_rf)
plt.xlabel('重要度')
plt.ylabel('特征')
plt.title('随机森林特征图')
plt.show()


从这个可以看出,有些特征没有效果,如性别,高血压等。

下面进行特征选择,选取20个特征。

python 复制代码
from sklearn.feature_selection import RFE

# 使用 RFE 来选择特征
rfe_selector = RFE(estimator=tree, n_features_to_select=20)  # 选择前20个特征
rfe_selector.fit(X, y)  
X_new = rfe_selector.transform(X)
feature_names = np.array(X.columns) 
selected_feature_names = feature_names[rfe_selector.support_]
print(selected_feature_names)
['年龄' '种族' '教育水平' '身体质量指数(BMI)' '酒精摄入量' '体育活动时间' '饮食质量评分' '睡眠质量评分' '心血管疾病'
 '收缩压' '舒张压' '胆固醇总量' '低密度脂蛋白胆固醇(LDL)' '高密度脂蛋白胆固醇(HDL)' '甘油三酯'
 '简易精神状态检查(MMSE)得分' '功能评估得分' '记忆抱怨' '行为问题' '日常生活活动(ADL)得分']

4、构建数据集

1、数据集划分与标准化

python 复制代码
feature_selection = ['年龄', '种族','教育水平','身体质量指数(BMI)', '酒精摄入量', '体育活动时间', '饮食质量评分', '睡眠质量评分', '心血管疾病',
 '收缩压', '舒张压', '胆固醇总量', '低密度脂蛋白胆固醇(LDL)', '高密度脂蛋白胆固醇(HDL)', '甘油三酯',
 '简易精神状态检查(MMSE)得分', '功能评估得分', '记忆抱怨', '行为问题', '日常生活活动(ADL)得分']

X = data_df[feature_selection]

# 标准化, 标准化其实对应连续性数据,分类数据不适合,由于特征中只有种族是分类数据,这里我偷个"小懒"
sc = StandardScaler()
X = sc.fit_transform(X)

X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.long)

# 再次进行特征选择
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train.shape, y_train.shape
(torch.Size([1719, 20]), torch.Size([1719]))

2、构建加载

python 复制代码
batch_size = 32

train_dl = DataLoader(
    TensorDataset(X_train, y_train),
    batch_size=batch_size,
    shuffle=True
)

test_dl = DataLoader(
    TensorDataset(X_test, y_test),
    batch_size=batch_size,
    shuffle=False
)

5、构建模型

python 复制代码
class Rnn_Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        # 调用rnn
        self.rnn = nn.RNN(input_size=20, hidden_size=200, num_layers=1, batch_first=True)
        
        self.fc1 = nn.Linear(200, 50)
        self.fc2 = nn.Linear(50, 2)
        
    def forward(self, x):
        x, hidden1 = self.rnn(x)
        x = self.fc1(x)
        x = self.fc2(x)
        
        return x
    
# 数据不大,cpu即可
device = "cpu"

model = Rnn_Model().to(device)
model
Rnn_Model(
  (rnn): RNN(20, 200, batch_first=True)
  (fc1): Linear(in_features=200, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=2, bias=True)
)
python 复制代码
model(torch.randn(32, 20)).shape
torch.Size([32, 2])

6、模型训练

1、构建训练集

python 复制代码
def train(data, model, loss_fn, opt):
    size = len(data.dataset)
    batch_num = len(data)
    
    train_loss, train_acc = 0.0, 0.0
    
    for X, y in data:
        X, y = X.to(device), y.to(device)
        
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # 反向传播
        opt.zero_grad()  # 梯度清零
        loss.backward()  # 求导
        opt.step()       # 设置梯度
        
        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    train_loss /= batch_num
    train_acc /= size 
    
    return train_acc, train_loss 

2、构建训练集

python 复制代码
def test(data, model, loss_fn):
    size = len(data.dataset)
    batch_num = len(data)
    
    test_loss, test_acc = 0.0, 0.0 
    
    with torch.no_grad():
        for X, y in data: 
            X, y = X.to(device), y.to(device)
            
            pred = model(X)
            loss = loss_fn(pred, y)
            
            test_loss += loss.item()
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
            
    test_loss /= batch_num
    test_acc /= size
    
    return test_acc, test_loss 

3、设置超参数

超参数,这里第一步设置了:

  • 1e-3,但是不稳定;
  • 1e-4,效果不错.
python 复制代码
loss_fn = nn.CrossEntropyLoss()  # 损失函数     
learn_lr = 1e-4            # 超参数
optimizer = torch.optim.Adam(model.parameters(), lr=learn_lr)   # 优化器

7、模型训练

python 复制代码
train_acc = []
train_loss = []
test_acc = []
test_loss = []

epoches = 50

for i in range(epoches):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 输出
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(i + 1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
    
print("Done")
Epoch: 1, Train_acc:64.9%, Train_loss:0.658, Test_acc:66.0%, Test_loss:0.617
Epoch: 2, Train_acc:66.9%, Train_loss:0.585, Test_acc:70.9%, Test_loss:0.564
Epoch: 3, Train_acc:75.1%, Train_loss:0.531, Test_acc:75.1%, Test_loss:0.511
Epoch: 4, Train_acc:79.8%, Train_loss:0.476, Test_acc:80.9%, Test_loss:0.463
Epoch: 5, Train_acc:82.7%, Train_loss:0.432, Test_acc:81.9%, Test_loss:0.429
Epoch: 6, Train_acc:83.9%, Train_loss:0.399, Test_acc:82.6%, Test_loss:0.413
Epoch: 7, Train_acc:84.4%, Train_loss:0.388, Test_acc:83.3%, Test_loss:0.405
Epoch: 8, Train_acc:85.0%, Train_loss:0.380, Test_acc:82.8%, Test_loss:0.401
Epoch: 9, Train_acc:84.7%, Train_loss:0.381, Test_acc:83.0%, Test_loss:0.398
Epoch:10, Train_acc:84.3%, Train_loss:0.374, Test_acc:84.0%, Test_loss:0.398
Epoch:11, Train_acc:84.9%, Train_loss:0.373, Test_acc:83.5%, Test_loss:0.395
Epoch:12, Train_acc:84.3%, Train_loss:0.374, Test_acc:83.7%, Test_loss:0.400
Epoch:13, Train_acc:84.4%, Train_loss:0.375, Test_acc:83.7%, Test_loss:0.398
Epoch:14, Train_acc:84.6%, Train_loss:0.370, Test_acc:83.5%, Test_loss:0.399
Epoch:15, Train_acc:85.0%, Train_loss:0.370, Test_acc:83.3%, Test_loss:0.400
Epoch:16, Train_acc:84.9%, Train_loss:0.371, Test_acc:83.5%, Test_loss:0.402
Epoch:17, Train_acc:84.8%, Train_loss:0.373, Test_acc:83.3%, Test_loss:0.396
Epoch:18, Train_acc:85.0%, Train_loss:0.369, Test_acc:83.5%, Test_loss:0.397
Epoch:19, Train_acc:84.9%, Train_loss:0.372, Test_acc:83.7%, Test_loss:0.397
Epoch:20, Train_acc:85.3%, Train_loss:0.371, Test_acc:83.3%, Test_loss:0.394
Epoch:21, Train_acc:84.8%, Train_loss:0.372, Test_acc:83.5%, Test_loss:0.396
Epoch:22, Train_acc:84.6%, Train_loss:0.373, Test_acc:83.7%, Test_loss:0.396
Epoch:23, Train_acc:84.8%, Train_loss:0.370, Test_acc:84.0%, Test_loss:0.397
Epoch:24, Train_acc:84.3%, Train_loss:0.373, Test_acc:84.0%, Test_loss:0.401
Epoch:25, Train_acc:84.8%, Train_loss:0.370, Test_acc:84.0%, Test_loss:0.398
Epoch:26, Train_acc:84.9%, Train_loss:0.370, Test_acc:83.5%, Test_loss:0.398
Epoch:27, Train_acc:84.2%, Train_loss:0.373, Test_acc:82.8%, Test_loss:0.398
Epoch:28, Train_acc:85.6%, Train_loss:0.367, Test_acc:82.8%, Test_loss:0.399
Epoch:29, Train_acc:84.6%, Train_loss:0.370, Test_acc:83.7%, Test_loss:0.400
Epoch:30, Train_acc:84.4%, Train_loss:0.374, Test_acc:84.0%, Test_loss:0.399
Epoch:31, Train_acc:84.6%, Train_loss:0.370, Test_acc:83.0%, Test_loss:0.399
Epoch:32, Train_acc:85.2%, Train_loss:0.370, Test_acc:83.7%, Test_loss:0.396
Epoch:33, Train_acc:84.8%, Train_loss:0.372, Test_acc:84.0%, Test_loss:0.395
Epoch:34, Train_acc:84.9%, Train_loss:0.371, Test_acc:83.0%, Test_loss:0.396
Epoch:35, Train_acc:84.5%, Train_loss:0.371, Test_acc:83.3%, Test_loss:0.395
Epoch:36, Train_acc:85.0%, Train_loss:0.371, Test_acc:83.5%, Test_loss:0.396
Epoch:37, Train_acc:85.2%, Train_loss:0.369, Test_acc:84.2%, Test_loss:0.396
Epoch:38, Train_acc:84.6%, Train_loss:0.376, Test_acc:84.0%, Test_loss:0.395
Epoch:39, Train_acc:85.2%, Train_loss:0.370, Test_acc:84.2%, Test_loss:0.396
Epoch:40, Train_acc:84.9%, Train_loss:0.371, Test_acc:84.2%, Test_loss:0.396
Epoch:41, Train_acc:84.4%, Train_loss:0.372, Test_acc:84.0%, Test_loss:0.394
Epoch:42, Train_acc:84.9%, Train_loss:0.370, Test_acc:84.2%, Test_loss:0.393
Epoch:43, Train_acc:84.8%, Train_loss:0.370, Test_acc:84.4%, Test_loss:0.395
Epoch:44, Train_acc:84.4%, Train_loss:0.372, Test_acc:84.0%, Test_loss:0.394
Epoch:45, Train_acc:85.3%, Train_loss:0.371, Test_acc:85.3%, Test_loss:0.396
Epoch:46, Train_acc:84.5%, Train_loss:0.371, Test_acc:83.5%, Test_loss:0.395
Epoch:47, Train_acc:84.5%, Train_loss:0.369, Test_acc:83.5%, Test_loss:0.396
Epoch:48, Train_acc:84.9%, Train_loss:0.371, Test_acc:83.3%, Test_loss:0.397
Epoch:49, Train_acc:85.1%, Train_loss:0.371, Test_acc:83.3%, Test_loss:0.396
Epoch:50, Train_acc:85.0%, Train_loss:0.369, Test_acc:82.6%, Test_loss:0.398
Done

8、结果评估

1、结果图

python 复制代码
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息

epochs_range = range(epoches)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training= Loss')
plt.show()


2、混淆矩阵

混淆矩阵(Confusion Matrix)是机器学习和数据科学中用于评估分类模型性能的一种表格。它通过展示模型预测结果与实际标签之间的对比,帮助我们理解模型的准确度以及其在不同类别上的表现。

对于一个二分类问题,混淆矩阵通常是一个2x2的表格,包含以下四个指标:

  • 真正例 (True Positive, TP):模型正确预测为正类的样本数。
  • 假正例 (False Positive, FP):模型错误地将负类预测为正类的样本数。
  • 假负例 (False Negative, FN):模型错误地将正类预测为负类的样本数。
  • 真负例 (True Negative, TN):模型正确预测为负类的样本数。

而对于多分类问题,混淆矩阵会相应地扩展到NxN的大小(N为类别数量),每一行代表实际类别,每一列代表预测类别。

python 复制代码
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay 

pred = model(X_test.to(device)).argmax(1).cpu().numpy()

# 计算混淆矩阵
cm = confusion_matrix(y_test, pred)

# 计算
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
# 标题
plt.title("混淆矩阵")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")

plt.tight_layout()  # 自适应
plt.show()

相关推荐
轻口味几秒前
【每日学点HarmonyOS Next知识】对话框与导航冲突、富文本、字体大小、列表刷新、Scroll包裹文本
pytorch·深度学习·harmonyos·harmonyos next
AAA小肥杨6 分钟前
2025人工智能AI新突破:PINN内嵌物理神经网络火了
人工智能·深度学习·神经网络·ai·大模型部署
王国强200928 分钟前
现代循环神经网络4-双向循环神经网络
深度学习
梓羽玩Python42 分钟前
太牛了!OWL:Manus 最强开源复现,开源框架GAIA基准测试中排第一!
人工智能·python
二川bro1 小时前
TensorFlow.js 全面解析:在浏览器中构建机器学习应用
javascript·机器学习·tensorflow
闲人编程1 小时前
经典网络复现与迁移学习
pytorch·深度学习·神经网络
dearxue1 小时前
「新」AI Coding(Agent) 的一点总结和看法
机器学习·aigc
詹天佐1 小时前
ICCE 数字车钥匙介绍
人工智能·算法
坚果的博客1 小时前
uniapp版本加密货币行情应用
人工智能·华为·uni-app·harmonyos