【人工智能】【Python】在Scikit-Learn中使用KNN(K最近邻算法)

今天机器学习课上的代码,在此记录一下。

python 复制代码
# 导入包
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
python 复制代码
# 导入数据
data = pd.read_excel("a.xlsx")
X = data[["搞笑镜头", "拥抱镜头", "打斗镜头"]]
y = data["分类"]
# stratify=y 可以使得y_test的数据分布和y_train的一样
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=114514, stratify=y)
python 复制代码
# 分层操作验证
from collections import Counter
print(Counter(y_train))
print(Counter(y_test))
python 复制代码
# 数据标准化
from sklearn.preprocessing import StandardScaler
sd = StandardScaler()
X_train = sd.fit_transform(X_train)
X_test = sd.transform(X_test)
python 复制代码
# 创建模型
k_values = [1,3,5]
acc = dict()
# 寻找最优k值
for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    # cv是几折交叉验证
    scores = cross_val_score(knn, X_train, y_train, cv=3, scoring="accuracy")
    acc[k] = scores.mean()
# print(acc)
best_k = max(acc, key=acc.get)
knn = KNeighborsClassifier(n_neighbors=best_k)
python 复制代码
# 创建模型
k_values = [1,3,5]
acc = dict()
# 寻找最优k值
for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    # cv是几折交叉验证
    scores = cross_val_score(knn, X_train, y_train, cv=3, scoring="accuracy")
    acc[k] = scores.mean()
# print(acc)
best_k = max(acc, key=acc.get)
knn = KNeighborsClassifier(n_neighbors=best_k)
python 复制代码
# 训练模型
knn.fit(X_train, y_train)
print(knn.score(X_test, y_test))
# 预测
y_pred = knn.predict(X_test)
# 输出真实标签
print(y_test)
# 输出预测值
print(y_pred)
# 输出预测概率
print(knn.predict_proba(X_test))
# 类别
print(knn.classes_)

数据处理与划分

通过pandas导入结构化数据,选取"搞笑镜头"等三个特征作为输入变量,影片分类作为目标变量。采用分层抽样(stratify=y)将数据按8:2比例划分训练集和测试集,确保两个集合的类别分布比例与原数据集一致,这对于类别不平衡数据的建模尤为重要。标准化处理(StandardScaler)消除特征量纲差异,这是基于距离计算的KNN算法的必要预处理步骤。

模型调参与优化

针对KNN的核心超参数k值(最近邻数量),采用网格搜索策略测试[1,3,5]三个候选值。通过3折交叉验证(cross_val_score)在训练集上评估不同k值的平均准确率(accuracy),最终选择验证集表现最优的k值构建最终模型。这种交叉验证方法能有效避免单次数据划分带来的偶然性,提升超参数选择的可靠性。

模型评估与输出

在独立测试集上计算模型准确率(score方法)作为最终性能指标,同时输出预测结果的三类核心信息:真实标签(y_test)、预测标签(y_pred)和预测概率(predict_proba)。预测概率矩阵的列顺序与classes_属性显示的类别顺序对应,这对多分类问题的结果解析具有重要意义。

实现特点

代码体现了机器学习项目的典型工作流:数据准备→特征工程→模型训练→参数调优→性能评估。特别值得注意的是对数据分布保持(stratify)、特征标准化、交叉验证等机器学习最佳实践的完整实现,这些细节处理对模型性能有实质性影响。最终的预测概率输出也为后续的决策阈值调整等业务场景应用提供了扩展接口。

本次代码如果搞懂了,使用Scikit-Learn实现机器学习的50%基本上就学会了(老师原话),相比PyTorch的深度学习框架,Scikit-Learn集成了很多传统的机器学习算法,使用起来非常方便简洁。

相关推荐
Juchecar7 分钟前
分析:将现代开源浏览器的JavaScript引擎更换为Python的可行性与操作
前端·javascript·python
科大饭桶9 分钟前
昇腾AI自学Day2-- 深度学习基础工具与数学
人工智能·pytorch·python·深度学习·numpy
什么都想学的阿超28 分钟前
【大语言模型 02】多头注意力深度剖析:为什么需要多个头
人工智能·语言模型·自然语言处理
努力还债的学术吗喽1 小时前
2021 IEEE【论文精读】用GAN让音频隐写术骗过AI检测器 - 对抗深度学习的音频信息隐藏
人工智能·深度学习·生成对抗网络·密码学·音频·gan·隐写
明道云创始人任向晖1 小时前
20个进入实用阶段的AI应用场景(零售电商业篇)
人工智能·零售
数据智研1 小时前
【数据分享】大清河(大庆河)流域上游土地利用
人工智能
聚客AI1 小时前
🔷告别天价算力!2025性价比最高的LLM私有化训练路径
人工智能·llm·掘金·日新计划
天波信息技术分享1 小时前
AI 云电竞游戏盒子:从“盒子”到“云-端-芯”一体化竞技平台的架构实践
人工智能·游戏·架构
用户5191495848452 小时前
curl --continue-at 参数异常行为分析:文件覆盖与删除风险
人工智能·aigc
用户84913717547162 小时前
joyagent智能体学习(第1期):项目概览与架构解析
人工智能·llm·agent