【R】Calibration curves: how they are plotted


① Get the output of predict(). The "prob.Case" column is the predicted risk probability, and the "truth" column is the observed outcome class.

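② Bin prob.Case (think of it as grouping; 10 groups is typical). There are two common binning schemes. One sorts prob.Case from largest to smallest and splits the samples into 10 groups of equal size. The other splits the probability range 0 to 1 into 10 equal-width intervals.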

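③ For each of the 10 bins, take the mean of prob.Case within the bin as the predicted probability;

④ For each of the 10 bins, take the proportion of actual cases (truth = 1, i.e. truth = "Case") among the bin's samples as the observed probability;

⑤ Use the 10 (predicted, observed) pairs as x and y coordinates to get 10 points;

⑥ Connecting these points gives the Apparent line of the calibration curve.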
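As an illustration, steps ② to ⑤ can be sketched in base R. This is a minimal sketch under assumptions: the data frame `pred` is simulated (hypothetical) to mimic the predict() output with `prob.Case` and `truth` columns, the outcomes are drawn so that the model is calibrated on average, and the rank-based split is similar to, but not guaranteed identical to, dplyr's `ntile()`.

```r
# Hypothetical data mimicking the predict() output: prob.Case and truth
set.seed(42)
n <- 1000
prob.Case <- runif(n)
# Draw outcomes so that P(Case) equals prob.Case (a calibrated model)
truth <- ifelse(runif(n) < prob.Case, "Case", "Control")
pred <- data.frame(prob.Case = prob.Case, truth = truth)

# Step 2: equal-frequency bins, 10 groups of n / 10 samples each
ranks <- rank(pred$prob.Case, ties.method = "first")
pred$bucket <- ceiling(ranks * 10 / nrow(pred))

# Step 3: mean predicted probability in each bin
predicted_prob <- tapply(pred$prob.Case, pred$bucket, mean)
# Step 4: observed proportion of cases in each bin
actual_prob <- tapply(pred$truth == "Case", pred$bucket, mean)

# Step 5: the 10 (x, y) points of the calibration curve
cal <- data.frame(bucket = as.integer(names(predicted_prob)),
                  predicted_prob = as.numeric(predicted_prob),
                  actual_prob = as.numeric(actual_prob))
print(cal)
```

Step ⑥ then only connects these points, for example with `plot(cal$predicted_prob, cal$actual_prob, type = "b")` plus `abline(0, 1, lty = 2)` as the perfect-calibration reference.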

The function from the R package (printed source)

```r
function (df, outcome, positive, prediction, model, n_bins = 10, 
  show_loess = FALSE, plot_title = "", ...) 
{
  # Binned points and a LOESS curve are mutually exclusive
  if ((n_bins > 0 && show_loess == TRUE) || (n_bins == 0 && 
    show_loess == FALSE)) {
    stop("You must either set n_bins > 0 and show_loess to FALSE or set n_bins to 0 and show_loess to TRUE. Both cannot be displayed.")
  }
  how_many_models = df[[model]] %>% unique() %>% length()
  # Recode the outcome: positive class -> 1, everything else -> 0
  df[[outcome]] = ifelse(positive == df[[outcome]], 1, 0)
  if (n_bins > 0) {
    # Per model: equal-frequency bins (ntile), bin mean prediction, observed rate,
    # and a 95% Wald interval squished into [0, 1]
    df <- df %>% dplyr::group_by(!!rlang::parse_expr(model)) %>% 
      dplyr::mutate(bin = dplyr::ntile(!!rlang::parse_expr(prediction), 
        n_bins)) %>% dplyr::group_by(!!rlang::parse_expr(model), 
      bin) %>% dplyr::mutate(n = dplyr::n(), bin_pred = mean(!!rlang::parse_expr(prediction), 
      na.rm = TRUE), bin_prob = mean(as.numeric(as.character(!!rlang::parse_expr(outcome))), 
      na.rm = TRUE), se = sqrt((bin_prob * (1 - bin_prob))/n), 
      ul = bin_prob + 1.96 * se, ll = bin_prob - 1.96 * 
        se) %>% dplyr::mutate_at(dplyr::vars(ul, ll), 
      . %>% scales::oob_squish(range = c(0, 1))) %>% dplyr::ungroup()
  }
  # Base plot: both axes on [0, 1] plus the dashed 45-degree reference line
  g1 = ggplot2::ggplot(df) + ggplot2::scale_y_continuous(limits = c(0, 
    1), breaks = seq(0, 1, by = 0.1)) + ggplot2::scale_x_continuous(limits = c(0, 
    1), breaks = seq(0, 1, by = 0.1)) + ggplot2::geom_abline(linetype = "dashed")
  if (show_loess == TRUE) {
    # LOESS-smoothed calibration curve fitted on the raw 0/1 outcomes
    g1 = g1 + ggplot2::stat_smooth(ggplot2::aes(x = !!rlang::parse_expr(prediction), 
      y = as.numeric(!!rlang::parse_expr(outcome)), color = !!rlang::parse_expr(model), 
      fill = !!rlang::parse_expr(model)), se = TRUE, method = "loess")
  }
  else {
    # Binned points joined by lines, with the interval drawn as a ribbon
    g1 = g1 + ggplot2::aes(x = bin_pred, y = bin_prob, color = !!rlang::parse_expr(model), 
      fill = !!rlang::parse_expr(model)) + ggplot2::geom_ribbon(ggplot2::aes(ymin = ll, 
      ymax = ul, ), alpha = 1/how_many_models) + ggplot2::geom_point(size = 2) + 
      ggplot2::geom_line(size = 1, alpha = 1/how_many_models)
  }
  g1 = g1 + ggplot2::xlab("Predicted Probability") + ggplot2::ylab("Observed Risk") + 
    ggplot2::scale_color_brewer(name = "Models", palette = "Set1") + 
    ggplot2::scale_fill_brewer(name = "Models", palette = "Set1") + 
    ggplot2::theme_minimal() + ggplot2::theme(aspect.ratio = 1) + 
    ggplot2::ggtitle(plot_title)
  # Density strip of the predicted probabilities, shown under the main plot
  g2 <- ggplot2::ggplot(df, ggplot2::aes(x = !!rlang::parse_expr(prediction))) + 
    ggplot2::geom_density(alpha = 1/how_many_models, ggplot2::aes(fill = !!rlang::parse_expr(model), 
      color = !!rlang::parse_expr(model))) + ggplot2::scale_x_continuous(limits = c(0, 
    1), breaks = seq(0, 1, by = 0.1)) + ggplot2::coord_fixed() + 
    ggplot2::xlab("") + ggplot2::ylab("") + ggplot2::scale_color_brewer(palette = "Set1") + 
    ggplot2::scale_fill_brewer(palette = "Set1") + ggplot2::theme_minimal() + 
    ggeasy::easy_remove_y_axis() + ggeasy::easy_remove_legend(fill, 
    color) + ggplot2::theme_void() + ggplot2::theme(aspect.ratio = 0.1)
  # Note: `layout` is built here but never used
  layout = c(patchwork::area(t = 1, b = 10, l = 1, r = 10), 
    patchwork::area(t = 11, b = 12, l = 1, r = 10))
  g1/g2   # patchwork's `/` stacks g1 above g2
}
```
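The ribbon in the package function is a per-bin 95% Wald interval for a proportion, `bin_prob ± 1.96 * sqrt(bin_prob * (1 - bin_prob) / n)`, clipped to [0, 1]. A minimal base-R sketch of that computation for a single model follows; the data are simulated (hypothetical), and `pmin()`/`pmax()` are used here in place of `scales::oob_squish()` to clip the limits.

```r
# Hypothetical simulated data for one model
set.seed(1)
n <- 500
prob <- runif(n)
outcome <- as.integer(runif(n) < prob)   # 1 = positive class, 0 = otherwise

# Equal-frequency bins: 10 bins of 50 samples each
bin <- ceiling(rank(prob, ties.method = "first") * 10 / n)
bin_n    <- as.numeric(tapply(outcome, bin, length))
bin_pred <- as.numeric(tapply(prob, bin, mean))
bin_prob <- as.numeric(tapply(outcome, bin, mean))

# Wald standard error of a proportion and normal-approximation 95% limits
se <- sqrt(bin_prob * (1 - bin_prob) / bin_n)
ul <- pmin(pmax(bin_prob + 1.96 * se, 0), 1)   # clip into [0, 1]
ll <- pmin(pmax(bin_prob - 1.96 * se, 0), 1)

ci <- data.frame(bin = 1:10, n = bin_n, bin_pred, bin_prob, ll, ul)
print(round(ci, 3))
```

With several models, the package does the same thing inside `group_by(model)`, so each model gets its own bins and intervals. Note that the Wald interval collapses to zero width when a bin's observed rate is exactly 0 or 1.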

A function I wrote myself

```r
library(dplyr)
library(ggplot2)

# Compute the 10 calibration points for one model's prediction table.
# `data` needs prob.Case (predicted risk), truth ("Case"/"Control") and type (model name).
get_cal <- function(data) {
  # ② equal-frequency binning: 10 groups with (nearly) equal numbers of samples
  data <- data %>% mutate(bucket = ntile(prob.Case, 10))
  # ③ mean predicted probability in each bin
  bucket_means <- data %>% group_by(bucket) %>%
    summarise(predicted_prob = mean(prob.Case))
  # ④ observed proportion of cases in each bin
  actual_probs <- data %>% group_by(bucket) %>%
    summarise(actual_prob = mean(truth == "Case"))
  calibration_data <- left_join(bucket_means, actual_probs, by = "bucket")
  # Tag the points with the model name so several models can share one plot
  calibration_data$type <- data$type[1]
  calibration_data
}

# prediction_all_* are the per-model prediction tables from the earlier workflow
cal_rf   <- get_cal(data = prediction_all_rf)
cal_kknn <- get_cal(data = prediction_all_kknn)
cal_SVM  <- get_cal(data = prediction_all_SVM)
cal_xgb  <- get_cal(data = prediction_all_xgb)

calibration_data <- bind_rows(cal_rf, cal_kknn, cal_SVM, cal_xgb)

# ⑤⑥ plot each model's (predicted, observed) points and connect them: the Apparent line
ggplot(calibration_data,
       aes(x = predicted_prob, y = actual_prob, group = type, colour = type)) +
  geom_point() +
  geom_line() +
  geom_abline(linetype = "dashed") +   # perfect-calibration reference line
  scale_x_continuous(limits = c(0, 1), breaks = seq(0, 1, by = 0.1)) +
  scale_y_continuous(limits = c(0, 1), breaks = seq(0, 1, by = 0.1)) +
  labs(title = "Calibration Curve",
       x = "Predicted Probability", y = "Actual Probability") +
  theme_minimal()
```
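`get_cal()` uses `ntile()`, i.e. equal-frequency bins. An equal-width alternative puts the bin boundaries at 0, 0.1, ..., 1 regardless of how the predictions are distributed. The helper below, `get_cal_width()`, is hypothetical and written in base R; it assumes the same `prob.Case` and `truth` columns and drops empty intervals so that every plotted point has at least one sample behind it.

```r
# Hypothetical helper: equal-width binning on [0, 0.1), [0.1, 0.2), ..., [0.9, 1]
get_cal_width <- function(data, n_bins = 10) {
  breaks <- seq(0, 1, length.out = n_bins + 1)
  # right = FALSE gives left-closed intervals; include.lowest keeps 1 in the last one
  data$bucket <- cut(data$prob.Case, breaks = breaks,
                     include.lowest = TRUE, right = FALSE)
  data <- droplevels(data)   # remove intervals that contain no samples
  data.frame(
    bucket = levels(data$bucket),
    n = as.integer(table(data$bucket)),
    predicted_prob = as.numeric(tapply(data$prob.Case, data$bucket, mean)),
    actual_prob = as.numeric(tapply(data$truth == "Case", data$bucket, mean))
  )
}

# Usage on simulated, right-skewed predictions (hypothetical data)
set.seed(7)
p <- rbeta(300, 2, 5)
demo <- data.frame(prob.Case = p,
                   truth = ifelse(runif(300) < p, "Case", "Control"))
cal_width <- get_cal_width(demo)
print(cal_width)
```

The trade-off: equal-width bins put every model's points at comparable x positions, but bins at the tails can hold very few samples (check the `n` column), whereas equal-frequency bins keep the sample size per point constant.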