第100+43步 ChatGPT学习:R语言实现特征选择曲线图

今天来说个机器学习分类的图,最近文章常出现的:特征选择曲线图(Feature Selection Curve)。

一、何为特征选择曲线图

特征选择曲线图主要用于展示随着纳入模型的特征数量增加,模型性能(如 AUC、准确率等)的变化情况。

二、有何用

(1)评估特征重要性:通过观察曲线,可以判断哪些特征对模型性能的提升贡献最大。例如,当纳入前几个重要特征时,模型性能显著提升,说明这些特征对模型的预测能力至关重要。

(2)确定最优特征数量:图中通常会显示一个"拐点",即增加更多特征后,模型性能的提升趋于平缓甚至下降。这个拐点可以帮助确定最优的特征数量,避免过拟合或引入冗余特征。

(3)比较训练集和验证集性能:通过同时绘制训练集和验证集的性能曲线,可以判断模型是否存在过拟合:

(a)如果训练集性能持续提升,而验证集性能在某个点后开始下降,说明模型可能过拟合。

(b)如果两条曲线趋势一致,说明模型泛化能力较好。

(4)指导特征选择:通过分析曲线,可以选择一个既能保证模型性能,又能减少特征数量的方案,从而简化模型并提高计算效率。

三、R语言代码,以Xgboost为例:

python 复制代码
# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)
library(xgboost)

# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)

# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]

# Prepare matrices for XGBoost
dtrain <- xgb.DMatrix(data = as.matrix(trainData[, -which(names(trainData) == "X")]), label = trainData$X)
dvalid <- xgb.DMatrix(data = as.matrix(validData[, -which(names(validData) == "X")]), label = validData$X)

# Define parameters for XGBoost
params <- list(booster = "gbtree", 
               objective = "binary:logistic", 
               eta = 0.1, 
               gamma = 0, 
               max_depth = 6, 
               min_child_weight = 1, 
               subsample = 0.5, 
               colsample_bytree = 0.9,
               lambda = 10,
               alpha = 5)

# Train the XGBoost model
model <- xgb.train(params = params, data = dtrain, nrounds = 250, watchlist = list(eval = dtrain), verbose = 1)

# Get feature importance
importance_matrix <- xgb.importance(feature_names = colnames(trainData[, -which(names(trainData) == "X")]), model = model)
print(importance_matrix)

# Initialize vectors to store AUC values
train_auc <- numeric(nrow(importance_matrix))
valid_auc <- numeric(nrow(importance_matrix))

# Loop through the number of features
for (i in 1:nrow(importance_matrix)) {
  # Select top i features
  selected_features <- importance_matrix$Feature[1:i]
  
  # Prepare matrices with selected features
  dtrain_selected <- xgb.DMatrix(data = as.matrix(trainData[, selected_features]), label = trainData$X)
  dvalid_selected <- xgb.DMatrix(data = as.matrix(validData[, selected_features]), label = validData$X)
  
  # Train the model with selected features
  model_selected <- xgb.train(params = params, data = dtrain_selected, nrounds = 250, watchlist = list(eval = dtrain_selected), verbose = 0)
  
  # Predict on training and validation sets
  trainPredict <- predict(model_selected, dtrain_selected)
  validPredict <- predict(model_selected, dvalid_selected)
  
  # Calculate AUC
  train_roc <- roc(response = as.numeric(trainData$X) - 1, predictor = trainPredict)
  valid_roc <- roc(response = as.numeric(validData$X) - 1, predictor = validPredict)
  
  # Store AUC values
  train_auc[i] <- auc(train_roc)
  valid_auc[i] <- auc(valid_roc)
}

# Create a dataframe for plotting
auc_data <- data.frame(
  num_features = 1:nrow(importance_matrix),
  train_auc = train_auc,
  valid_auc = valid_auc
)

# Plot AUC values for training and validation sets
ggplot(auc_data, aes(x = num_features)) +
  geom_line(aes(y = train_auc, color = "Training AUC")) +
  geom_line(aes(y = valid_auc, color = "Validation AUC")) +
  labs(title = "AUC vs Number of Features",
       x = "Number of Features",
       y = "AUC",
       color = "Legend") +
  theme_minimal() +
  scale_color_manual(values = c("Training AUC" = "blue", "Validation AUC" = "red"))

看图:

示例解读

图表显示:

当特征数量为 5 时,验证集 AUC 达到峰值(例如 0.85),之后趋于平缓。训练集 AUC 在特征数量为 5 时为 0.88,之后继续上升。

结论:

最优特征数量为 5,因为此时验证集 AUC 达到最高值。增加更多特征可能导致过拟合(训练集 AUC 继续上升,而验证集 AUC 不再提升)。

相关推荐
深藏blue4712 小时前
GPT-5.3 Instant 重磅上线!2026最新 ChatGPT 告别说教,国内使用与 Plus 升级教程
gpt·chatgpt·openai
西岸行者7 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意7 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码7 天前
嵌入式学习路线
学习
毛小茛7 天前
计算机系统概论——校验码
学习
babe小鑫7 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms7 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下7 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。7 天前
2026.2.25监控学习
学习
im_AMBER7 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode