【人工智能】【Python】在Scikit-Learn中使用KNN(K最近邻算法)

今天机器学习课上的代码,在此记录一下。

python 复制代码
# 导入包
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
python 复制代码
# 导入数据
data = pd.read_excel("a.xlsx")
X = data[["搞笑镜头", "拥抱镜头", "打斗镜头"]]
y = data["分类"]
# stratify=y 可以使得y_test的数据分布和y_train的一样
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=114514, stratify=y)
python 复制代码
# 分层操作验证
from collections import Counter
print(Counter(y_train))
print(Counter(y_test))
python 复制代码
# 数据标准化
from sklearn.preprocessing import StandardScaler
sd = StandardScaler()
X_train = sd.fit_transform(X_train)
X_test = sd.transform(X_test)
python 复制代码
# 创建模型
k_values = [1,3,5]
acc = dict()
# 寻找最优k值
for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    # cv是几折交叉验证
    scores = cross_val_score(knn, X_train, y_train, cv=3, scoring="accuracy")
    acc[k] = scores.mean()
# print(acc)
best_k = max(acc, key=acc.get)
knn = KNeighborsClassifier(n_neighbors=best_k)
python 复制代码
# 创建模型
k_values = [1,3,5]
acc = dict()
# 寻找最优k值
for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    # cv是几折交叉验证
    scores = cross_val_score(knn, X_train, y_train, cv=3, scoring="accuracy")
    acc[k] = scores.mean()
# print(acc)
best_k = max(acc, key=acc.get)
knn = KNeighborsClassifier(n_neighbors=best_k)
python 复制代码
# 训练模型
knn.fit(X_train, y_train)
print(knn.score(X_test, y_test))
# 预测
y_pred = knn.predict(X_test)
# 输出真实标签
print(y_test)
# 输出预测值
print(y_pred)
# 输出预测概率
print(knn.predict_proba(X_test))
# 类别
print(knn.classes_)

数据处理与划分

通过pandas导入结构化数据,选取"搞笑镜头"等三个特征作为输入变量,影片分类作为目标变量。采用分层抽样(stratify=y)将数据按8:2比例划分训练集和测试集,确保两个集合的类别分布比例与原数据集一致,这对于类别不平衡数据的建模尤为重要。标准化处理(StandardScaler)消除特征量纲差异,这是基于距离计算的KNN算法的必要预处理步骤。

模型调参与优化

针对KNN的核心超参数k值(最近邻数量),采用网格搜索策略测试[1,3,5]三个候选值。通过3折交叉验证(cross_val_score)在训练集上评估不同k值的平均准确率(accuracy),最终选择验证集表现最优的k值构建最终模型。这种交叉验证方法能有效避免单次数据划分带来的偶然性,提升超参数选择的可靠性。

模型评估与输出

在独立测试集上计算模型准确率(score方法)作为最终性能指标,同时输出预测结果的三类核心信息:真实标签(y_test)、预测标签(y_pred)和预测概率(predict_proba)。预测概率矩阵的列顺序与classes_属性显示的类别顺序对应,这对多分类问题的结果解析具有重要意义。

实现特点

代码体现了机器学习项目的典型工作流:数据准备→特征工程→模型训练→参数调优→性能评估。特别值得注意的是对数据分布保持(stratify)、特征标准化、交叉验证等机器学习最佳实践的完整实现,这些细节处理对模型性能有实质性影响。最终的预测概率输出也为后续的决策阈值调整等业务场景应用提供了扩展接口。

本次代码如果搞懂了,使用Scikit-Learn实现机器学习的50%基本上就学会了(老师原话),相比PyTorch的深度学习框架,Scikit-Learn集成了很多传统的机器学习算法,使用起来非常方便简洁。

相关推荐
北京_宏哥1 分钟前
🔥PC端自动化测试实战教程-2-pywinauto 启动PC端应用程序 - 上篇(详细教程)
前端·windows·python
_x_w10 分钟前
【12】数据结构之基于线性表的排序算法
开发语言·数据结构·笔记·python·算法·链表·排序算法
硅谷秋水16 分钟前
OpenDriveVLA:通过大型视觉-语言-动作模型实现端到端自动驾驶
人工智能·机器学习·计算机视觉·语言模型·机器人·自动驾驶
蚝油菜花20 分钟前
【内附榜单】评估AI大模型的代码修复能力!Multi-SWE-bench:字节开源代码修复能力评估基准,覆盖7大主流编程语言
人工智能·开源
北极的树22 分钟前
Vibe coding 最后一公里: 打造一套通用的AI任务拆分和管理系统
人工智能
蚝油菜花22 分钟前
1天消化完Spring全家桶文档!DevDocs:一键深度解析开发文档,自动发现子URL并建立图谱
人工智能·开源
grrrr_123 分钟前
【ctfplus】python靶场记录-任意文件读取+tornado模板注入+yaml反序列化(新手向)
python·web安全·tornado
蚝油菜花24 分钟前
让AI绘画进入「指哪画哪」的精准时代!EasyControl:Tiamat AI 联合上海科大开源图像生成控制框架
人工智能·开源
飞哥数智坊25 分钟前
ADK开源:谷歌持续发力Agent,请快速跟进
人工智能·trae
声网26 分钟前
Orpheus 语音模型支持中文预训练和微调,模拟细微语音特征;谷歌版 MCP 来了,A2A 协议让不同厂商 Agent 协作
人工智能