第100+14步 ChatGPT学习:R实现随机森林分类

基于R 4.2.2版本演示

一、写在前面

有不少大佬问做机器学习分类能不能用R语言,不想学Python咯。

答曰:可!用GPT或者Kimi转一下就得了呗。

加上最近也没啥内容写了,就帮各位搬运一下吧。

二、R代码实现随机森林分类

(1) 导入数据

我习惯用RStudio自带的导入功能:

(2) 建立随机森林模型(默认参数)

R 复制代码
# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)

# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)

# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]

# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)

# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)

# Fit Random Forest model on the training set
model <- train(X ~ ., data = trainData, method = "rf", trControl = trainControl)

# Print the best parameters found by the model
best_params <- model$bestTune
cat("The best parameters found are:\n")
print(best_params)
# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]

# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)

# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "blue") +
  geom_area(alpha = 0.2, fill = "blue") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Training ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")

ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "red") +
  geom_area(alpha = 0.2, fill = "red") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Validation ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")

# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)

# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {
  conf_mat_df <- as.data.frame(as.table(conf_mat))
  colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")
  
  p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +
    geom_tile(color = "white") +
    geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +
    scale_fill_gradient(low = "white", high = "steelblue") +
    labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))
  
  print(p)
}

# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")

# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]

a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]

# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc

# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc

# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")

cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

在R语言中,使用 caret 包训练随机森林模型时,最常见的可调参数是 mtry,但还有其他几个参数可以根据需要调整。这些参数通常是从 randomForest 包继承而来的,因为 caret 包的随机森林方法默认使用的是这个包(所以第一次要安装)。下面是一些可以调整的关键参数:

①mtry: 在每个分割中考虑的变量数量。默认情况下,对于分类问题,mtry 默认值是总变量数的平方根;对于回归问题,是总变量数的三分之一。

②ntree: 构建的树的数量。更多的树可以提高模型的稳定性和准确性,但会增加计算时间和内存使用。默认值通常是 500。

③nodesize: 每个叶节点最少包含的样本数。增加这个参数的值可以减少模型的过拟合,但可能会导致欠拟合。对于分类问题,默认值通常为 1,而回归问题则较大。

④maxnodes: 最大的节点数。这限制了树的最大大小,可以用来控制模型复杂度。

⑤importance: 是否计算变量重要性。这不会影响模型的预测能力,但会影响变量重要性分数的计算。

⑥replace: 是否进行有放回抽样。默认为 TRUE,意味着进行有放回的抽样。

⑦classwt: 类的权重,++++用于分类问题中的不平衡数据。++++

⑧cutoff: ++++分类问题中用于预测类别的概率阈值。这通常是一个类别数目的向量。++++

⑨sampsize: 用于每棵树的样本大小,如果使用有放回抽样(replace=TRUE),它决定了每个样本被抽样的次数。

结果输出(默认参数):

在默认参数中,caret包只会默默帮我们找几个合适的mtry值进行测试,其他默认值。

随机森林祖传的过拟合现象。

三、随机森林调参方法(4个值)

设置mtry值取值2、数据集列数的平方根、数据集列数的一半;ntree取值100、500和1000;nodesize取值1、5、10;maxnodes取值30、50、100:

R 复制代码
# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)
library(randomForest)  # Using randomForest for model fitting

# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)

# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]

# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)

# Define ranges for each parameter including mtry
mtry_range <- c(2, sqrt(ncol(trainData)), ncol(trainData)/2)
ntree_range <- c(100, 500, 1000)
nodesize_range <- c(1, 5, 10)
maxnodes_range <- c(30, 50, 100)

# Initialize variables to store the best model, its AUC, and parameters
best_auc <- 0
best_model <- NULL
best_params <- NULL  # Initialize best_params to store parameter values

# Nested loops to try different combinations of parameters including mtry
for (mtry in mtry_range) {
  for (ntree in ntree_range) {
    for (nodesize in nodesize_range) {
      for (maxnodes in maxnodes_range) {
        # Train the Random Forest model
        rf_model <- randomForest(X ~ ., data = trainData, mtry=mtry, ntree=ntree, nodesize=nodesize, maxnodes=maxnodes)
        
        # Predict on validation set using probabilities for ROC
        validProb <- predict(rf_model, newdata = validData, type = "prob")[,2]
        
        # Calculate AUC
        validRoc <- roc(validData$X, validProb)
        auc <- auc(validRoc)
        
        # Update the best model if the current model is better
        if (auc > best_auc) {
          best_auc <- auc
          best_model <- rf_model
          best_params <- list(mtry=mtry, ntree=ntree, nodesize=nodesize, maxnodes=maxnodes)  # Store parameters of the best model
        }
      }
    }
  }
}

# Check if a best model was found and output parameters
if (!is.null(best_params)) {
  cat("The best model parameters are:\n")
  print(best_params)
} else {
  cat("No model was found to exceed the baseline performance.\n")
}

# Use best model for predictions
trainPredict <- predict(best_model, trainData, type = "prob")[,2]
validPredict <- predict(best_model, validData, type = "prob")[,2]

# Rest of the analysis, including ROC curves and plotting
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)

# Plotting code continues unchanged
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "blue") +
  geom_area(alpha = 0.2, fill = "blue") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Training ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")

ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "red") +
  geom_area(alpha = 0.2, fill = "red") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Validation ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")

confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)

# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {
  conf_mat_df <- as.data.frame(as.table(conf_mat))
  colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")
  
  p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +
    geom_tile(color = "white") +
    geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +
    scale_fill_gradient(low = "white", high = "steelblue") +
    labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))
  
  print(p)
}

# Function to plot confusion matrix and further analysis remains the same...
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")


# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]

a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]

# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc

# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc

# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")

cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

结果输出:

以上是找到的相对最优参数组合,看看具体性能:

还不错,至少过拟合调下来了。

五、最后

数据嘛:

链接:https://pan.baidu.com/s/1rEf6JZyzA1ia5exoq5OF7g?pwd=x8xm

提取码:x8xm

相关推荐
西岸行者6 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意6 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码6 天前
嵌入式学习路线
学习
毛小茛6 天前
计算机系统概论——校验码
学习
babe小鑫6 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms6 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下6 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。6 天前
2026.2.25监控学习
学习
im_AMBER6 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode
CodeJourney_J6 天前
从“Hello World“ 开始 C++
c语言·c++·学习