基于LSTM的医疗成本预测

声明:


前言:

本节不止学了LSTM网络的运用,还学习了数据可视化处理的常用方法

数据集介绍:本文采用的数据集为公开的医用费用预测数据集,在采集患者年龄与年度医疗费用的基础上,还采集了以下20种特征:

age

gender

bmi

smoker

diabetes

hypertension

heart_disease

asthma

physical_activity_level

daily_steps

sleep_hours

stress_level

doctor_visits_per_year

hospital_admissions

medication_count

insurance_type

insurance_coverage_pct

city_type

previous_year_cost

annual_medical_cost


有一个叫Seaborn的库,集成了很多可以直接用的常用的高级绘图


热力图:

复制代码
sns.heatmap(numeric_cols.corr(), annot=True, cmap='coolwarm')

计算numeric_cols数据之间的相关系数,然后采用冷暖颜色表示的格子内有数字的热力图

热力图能够很清楚的看出各个数据之间的相关性,例如:hospital_admissions 和annual_medical_cost有着很强的相关性(住院一般都容易花很多钱),而insurance_coverage_pct 和annual_medical_cost则表示出负相关(保险覆盖率高花钱少)

热力图可以通过颜色很清楚的看出最强相关的和最不相关的东西,送入神经网络训练也是一个很好的样本


箱线图 boxplot:

复制代码
sns.boxplot(data = df.loc[:, ['annual_medical_cost']], ax=ax[0], whis=3)
ax[0].set_title('多个数值变量')

sns.boxplot(x = df['hospital_admissions'], ax=ax[1], whis=3)
ax[1].set_title('一个数值变量多个分组')

sns.boxplot(x = 'hospital_admissions', y = 'annual_medical_cost', hue = 'smoker', data = df, palette='Set1', width=0.5, ax=ax[2], whis=3)
ax[2].set_title('一个数值变量多个分组')

含义:
x='hospital_admissions':按住院次数分组。
y='annual_medical_cost':看每组的年度医疗费用分布。
hue='smoker':再按是否吸烟分颜色。
data=df:数据来自 df。
palette='Set1':颜色方案。
width=0.5:箱子宽度。
ax=ax[2]:画在第 3 个子图上。
whis=3:离群值判断范围放宽到 3 倍 IQR。

箱子中间那条线是代表中位值箱体上沿和下沿分别是75%和25%的数据点位。

上须:75%点位+1.5*(75%点位值-25%点位值),表示理论最大值,上须以外的叫离群点,实际值,下须同理


小提琴图 violinplot:

复制代码
sns.violinplot(data = df.loc[:, ['annual_medical_cost']], ax=ax[0])
ax[0].set_title('多个数值变量')

sns.violinplot(x = df['hospital_admissions'], y = df['annual_medical_cost'], ax=ax[1])
ax[1].set_title('一个数值变量多个分组')

sns.violinplot(x = 'hospital_admissions', y = 'annual_medical_cost', hue = 'smoker', data = df, palette='Set1', width=0.5, ax=ax[2])
ax[2].set_title('一个数值变量多个分组')

小提琴图和箱线图类似,但它多了一层信息:分布密度。

怎么看:

  • 图越宽:这个数值附近的数据越多。
  • 图越窄:这个数值附近的数据少。
  • 形状如果很胖,说明数据集中在某些区间。
  • 如果上下拉得很长,说明波动范围大。

箱线图更适合看中位数和离群点,小提琴图更适合看"数据集中在哪里"。


散点图 scatterplot:

复制代码
sns.scatterplot(
    data = df[:100],
    x = 'previous_year_cost',
    y = 'annual_medical_cost',
    size = 'age',
    sizes = (20, 200),
    alpha = 0.6,
    color = '#2E86AB',
    ax = ax[0]
)

ax[0].set_title('(气泡大小=年龄)')
ax[0].set_xlabel('上一年费用')
ax[0].set_ylabel('年度医疗费用')
ax[0].legend(title='年龄', loc='upper left', bbox_to_anchor=(1.05, 1))

sns.scatterplot(
    data = df[:100],
    x = 'bmi',
    y = 'annual_medical_cost',
    size = 'age',
    sizes = (20, 200),
    alpha = 0.7,
    hue = 'gender',
    palette = 'Set2',
    ax = ax[1]
)
ax[1].set_title('按bmi分组(气泡大小=年龄, 颜色=性别)')
ax[1].set_xlabel('BMI')
ax[1].set_ylabel('年度医疗费用')
ax[1].legend(title='性别', loc='upper left', bbox_to_anchor=(1.05, 1))

sns.scatterplot(
    data = df[:100],
    x = 'sleep_hours',
    y = 'annual_medical_cost',
    hue = 'smoker',
    size = 'age',
    sizes = (40, 300),
    alpha = 0.7,
    palette = 'Set1',
    ax = ax[2]
)

ax[2].set_title('按睡眠时间分组(颜色=是否吸烟, 气泡大小=年龄)')
ax[2].set_xlabel('睡眠时间')
ax[2].set_ylabel('年度医疗费用')
ax[2].legend(title='是否吸烟', loc='upper left', bbox_to_anchor=(1.05, 1))

散点图用来看两个连续变量之间的关系。

怎么看:

  • 点整体从左下到右上:正相关。
  • 点整体从左上到右下:负相关。
  • 点乱成一团:关系不明显。
  • 颜色 hue:看不同类别是否形成不同模式。
  • 大小 size:看第三个变量是否有影响。

训练结果:


总结:

LSTM的预测用的是R2而不是acc

关键区别

  • R2:衡量连续数值预测的拟合程度,越接近 1 越好。
  • accuracy:衡量离散类别预测是否正确,不能直接用于预测数值型医疗费用。

而且loss计算由于是直接计算的,没有归一化,所以特别大。

整体和RNN的训练还是很相似的,这节的重点在于熟悉一下这几个常用的绘图方式。