机器学习之scikit-learn(简称 sklearn)

scikit-learn(简称 sklearn)是 Python 生态中一个非常流行且强大的机器学习库,支持各种机器学习算法和工具。


核心模块和功能

  1. 监督学习 (Supervised Learning)

    • 分类 (Classification):

      • 支持的算法:KNN、SVM、决策树、随机森林、Logistic回归、朴素贝叶斯等。

      • 示例:

        复制代码
        from sklearn.ensemble import RandomForestClassifier
        
        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(X_train, y_train)  # 训练模型
        predictions = clf.predict(X_test)  # 预测
    • 回归 (Regression):

      • 支持的算法:线性回归、Ridge回归、Lasso回归、SVR等。

      • 示例:

        复制代码
        from sklearn.linear_model import LinearRegression
        
        reg = LinearRegression()
        reg.fit(X_train, y_train)  # 拟合数据
        predictions = reg.predict(X_test)  # 预测
  2. 无监督学习 (Unsupervised Learning)

    • 聚类 (Clustering):

      • 支持的算法:K-Means、DBSCAN、层次聚类等。

      • 示例:

        复制代码
        from sklearn.cluster import KMeans
        
        kmeans = KMeans(n_clusters=3, random_state=42)
        kmeans.fit(X)  # 拟合数据
        labels = kmeans.labels_  # 获取聚类标签
    • 降维 (Dimensionality Reduction):

      • 支持的算法:PCA、TSNE、ICA等。

      • 示例:

        复制代码
        from sklearn.decomposition import PCA
        
        pca = PCA(n_components=2)
        X_reduced = pca.fit_transform(X)  # 降维
  3. 模型选择与优化 (Model Selection and Optimization)

    • 交叉验证 (Cross Validation):

      • 使用 cross_val_score 实现简单交叉验证。

      • 示例:

        复制代码
        from sklearn.model_selection import cross_val_score
        
        scores = cross_val_score(clf, X, y, cv=5)  # 5折交叉验证
        print(scores.mean())  # 平均准确率
    • 超参数调优 (Hyperparameter Tuning):

      • 使用 GridSearchCVRandomizedSearchCV

      • 示例:

        复制代码
        from sklearn.model_selection import GridSearchCV
        
        param_grid = {'n_estimators': [50, 100, 150], 'max_depth': [10, 20, None]}
        grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=3)
        grid_search.fit(X_train, y_train)
        print(grid_search.best_params_)  # 最优参数
  4. 预处理 (Preprocessing)

    • 标准化与归一化:

      • 使用 StandardScalerMinMaxScaler

      • 示例:

        复制代码
        from sklearn.preprocessing import StandardScaler
        
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
    • 特征选择 (Feature Selection):

      • 支持的方法:SelectKBest、递归特征消除 (RFE) 等。

      • 示例:

        复制代码
        from sklearn.feature_selection import SelectKBest, f_classif
        
        selector = SelectKBest(f_classif, k=10)
        X_new = selector.fit_transform(X, y)

常用工具

  1. 评估指标 (Metrics)

    • 分类指标:准确率、F1分数、ROC曲线等。

      复制代码
      from sklearn.metrics import accuracy_score, classification_report
      
      print(accuracy_score(y_test, y_pred))
      print(classification_report(y_test, y_pred))
    • 回归指标:均方误差 (MSE)、R²等。

      复制代码
      from sklearn.metrics import mean_squared_error, r2_score
      
      print(mean_squared_error(y_test, y_pred))
      print(r2_score(y_test, y_pred))
  2. 数据集工具

    • 自带数据集加载:如 irisdigits 等。

      复制代码
      from sklearn.datasets import load_iris
      
      data = load_iris()
      X, y = data.data, data.target
    • 数据集拆分:

      复制代码
      from sklearn.model_selection import train_test_split
      
      X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

完整工作流程示例

以一个分类任务为例,使用随机森林进行训练并评估:

复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report

# 1. 加载数据
data = load_iris()
X, y = data.data, data.target

# 2. 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 3. 模型训练
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# 4. 模型预测
y_pred = clf.predict(X_test)

# 5. 模型评估
print("Accuracy:", accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred))

适用场景

  1. 快速实现基于传统方法的机器学习任务。
  2. 教学或研究中算法的对比实验。
  3. 中小型数据集的机器学习应用。
相关推荐
Lecea_L1 分钟前
🔍 找到数组里的“节奏感”:最长等差子序列
java·算法
是Dream呀4 分钟前
ResNeXt: 通过聚合残差变换增强深度神经网络
人工智能·算法
DeepLink11 分钟前
Python小练习系列:学生信息排序(sorted + key函数)
python·求职
项目申报小狂人15 分钟前
CUDA详细安装及环境配置——环境配置指南 – CUDA+cuDNN+PyTorch 安装
人工智能·pytorch·python
林泽毅15 分钟前
SwanLab Slack通知插件:让AI训练状态同步更及时
深度学习·机器学习·强化学习
学c真好玩25 分钟前
4.1-python操作wrod/pdf 文件
开发语言·python·pdf
东方佑26 分钟前
使用Python解析PPT文件并生成JSON结构详解
python·json·powerpoint
Auroral15629 分钟前
一文搞懂python实现邮件发送的全流程
python
大霸王龙30 分钟前
LLM(语言学习模型)行为控制技术
python·深度学习·学习
我不是大佬zvj32 分钟前
PyGame开发贪吃蛇小游戏
python·pygame