第100+43步 ChatGPT学习:R语言实现特征选择曲线图

今天来说个机器学习分类的图,最近文章常出现的:特征选择曲线图(Feature Selection Curve)。

一、何为特征选择曲线图

特征选择曲线图主要用于展示随着纳入模型的特征数量增加,模型性能(如 AUC、准确率等)的变化情况。

二、有何用

(1)评估特征重要性:通过观察曲线,可以判断哪些特征对模型性能的提升贡献最大。例如,当纳入前几个重要特征时,模型性能显著提升,说明这些特征对模型的预测能力至关重要。

(2)确定最优特征数量:图中通常会显示一个"拐点",即增加更多特征后,模型性能的提升趋于平缓甚至下降。这个拐点可以帮助确定最优的特征数量,避免过拟合或引入冗余特征。

(3)比较训练集和验证集性能:通过同时绘制训练集和验证集的性能曲线,可以判断模型是否存在过拟合:

(a)如果训练集性能持续提升,而验证集性能在某个点后开始下降,说明模型可能过拟合。

(b)如果两条曲线趋势一致,说明模型泛化能力较好。

(4)指导特征选择:通过分析曲线,可以选择一个既能保证模型性能,又能减少特征数量的方案,从而简化模型并提高计算效率。

三、R语言代码,以Xgboost为例:

python 复制代码
# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)
library(xgboost)

# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)

# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]

# Prepare matrices for XGBoost
dtrain <- xgb.DMatrix(data = as.matrix(trainData[, -which(names(trainData) == "X")]), label = trainData$X)
dvalid <- xgb.DMatrix(data = as.matrix(validData[, -which(names(validData) == "X")]), label = validData$X)

# Define parameters for XGBoost
params <- list(booster = "gbtree", 
               objective = "binary:logistic", 
               eta = 0.1, 
               gamma = 0, 
               max_depth = 6, 
               min_child_weight = 1, 
               subsample = 0.5, 
               colsample_bytree = 0.9,
               lambda = 10,
               alpha = 5)

# Train the XGBoost model
model <- xgb.train(params = params, data = dtrain, nrounds = 250, watchlist = list(eval = dtrain), verbose = 1)

# Get feature importance
importance_matrix <- xgb.importance(feature_names = colnames(trainData[, -which(names(trainData) == "X")]), model = model)
print(importance_matrix)

# Initialize vectors to store AUC values
train_auc <- numeric(nrow(importance_matrix))
valid_auc <- numeric(nrow(importance_matrix))

# Loop through the number of features
for (i in 1:nrow(importance_matrix)) {
  # Select top i features
  selected_features <- importance_matrix$Feature[1:i]
  
  # Prepare matrices with selected features
  dtrain_selected <- xgb.DMatrix(data = as.matrix(trainData[, selected_features]), label = trainData$X)
  dvalid_selected <- xgb.DMatrix(data = as.matrix(validData[, selected_features]), label = validData$X)
  
  # Train the model with selected features
  model_selected <- xgb.train(params = params, data = dtrain_selected, nrounds = 250, watchlist = list(eval = dtrain_selected), verbose = 0)
  
  # Predict on training and validation sets
  trainPredict <- predict(model_selected, dtrain_selected)
  validPredict <- predict(model_selected, dvalid_selected)
  
  # Calculate AUC
  train_roc <- roc(response = as.numeric(trainData$X) - 1, predictor = trainPredict)
  valid_roc <- roc(response = as.numeric(validData$X) - 1, predictor = validPredict)
  
  # Store AUC values
  train_auc[i] <- auc(train_roc)
  valid_auc[i] <- auc(valid_roc)
}

# Create a dataframe for plotting
auc_data <- data.frame(
  num_features = 1:nrow(importance_matrix),
  train_auc = train_auc,
  valid_auc = valid_auc
)

# Plot AUC values for training and validation sets
ggplot(auc_data, aes(x = num_features)) +
  geom_line(aes(y = train_auc, color = "Training AUC")) +
  geom_line(aes(y = valid_auc, color = "Validation AUC")) +
  labs(title = "AUC vs Number of Features",
       x = "Number of Features",
       y = "AUC",
       color = "Legend") +
  theme_minimal() +
  scale_color_manual(values = c("Training AUC" = "blue", "Validation AUC" = "red"))

看图:

示例解读

图表显示:

当特征数量为 5 时,验证集 AUC 达到峰值(例如 0.85),之后趋于平缓。训练集 AUC 在特征数量为 5 时为 0.88,之后继续上升。

结论:

最优特征数量为 5,因为此时验证集 AUC 达到最高值。增加更多特征可能导致过拟合(训练集 AUC 继续上升,而验证集 AUC 不再提升)。

相关推荐
wdfk_prog11 小时前
[Linux]学习笔记系列 -- [kernel]workqueue
linux·笔记·学习
wdfk_prog11 小时前
[Linux]学习笔记系列 -- [kernel]usermode_helper
linux·笔记·学习
冬夜戏雪11 小时前
【学习日记】【刷题回溯、贪心、动规】
学习
一只爱做笔记的码农11 小时前
【BootstrapBlazor】移植BootstrapBlazor VS工程到Vscode工程,报error blazor106的问题
笔记·学习·c#
xixixi7777712 小时前
“C2隐藏”——命令与控制服务器的隐藏技术
网络·学习·安全·代理·隐藏·合法服务·c2隐藏
名字不相符12 小时前
攻防世界WEB难度一(个人记录)
学习·php·web·萌新
陈天伟教授12 小时前
基于学习的人工智能(4)机器学习基本框架
人工智能·学习·机器学习
7***374513 小时前
DeepSeek在文本分类中的多标签学习
学习·分类·数据挖掘
jiushun_suanli13 小时前
量子纠缠:颠覆认知的宇宙密码
经验分享·学习·量子计算
charlie11451419113 小时前
勇闯前后端Week2:后端基础——Flask API速览
笔记·后端·python·学习·flask·教程