python代码实现KNN对鸢尾花的分类

一、KNN模型-KNeighborsClassifier()

1.1 导入sklearn第三方库

python 复制代码
from sklearn import datasets #sklearn的数据集
from sklearn.neighbors import KNeighborsClassifier #sklearn模块的KNN类

我们使用一个叫作鸢尾花数据集的数据,这个数据集里面有 150 条数据,共有 3 个类别,即Setosa 鸢尾花、Versicolour鸢尾花和 Virginica 鸢尾花,每个类别有 50 条数据,每条数据有 4个维度,分别记录了鸢尾花的花萼长度、花萼宽度、花瓣长度和花瓣宽度。

1.2 加载数据集

python 复制代码
iris=datasets.load_iris()
iris_x=iris.data
iris_y=iris.target

只用到数据维度iris.data,数据标签iris.target。

可以直接在python代码中输入iris.data或者iris.target查看数据

1.3 划分训练集和测试集

python 复制代码
randomarr= np.random.permutation(len(iris_x))
#用前140个作为训练集
#randomarr[:-10]表示从数组的开头到倒数第十个元素之前(不包括倒数第十个元素)的所有元素
iris_x_train = iris_x[randomarr[:-10]]
iris_y_train = iris_y[randomarr[:-10]]

#用后10个作为测试集
#randomarr[-10:]表示从数组 randomarr 的倒数第十个元素开始到最后一个元素的子数组,记住索引值
iris_x_test = iris_x[randomarr[-10:]]
iris_y_test = iris_y[randomarr[-10:]]

len(randomarr[:-10]) #输出140

numpy.random.permutation函数用于对数组进行随机排列(即置换)。它的作用是返回一个新的打乱了顺序的数组,而不会修改原始数组

如果输入是一个整数n,那么函数会返回一个包含0到n-1的整数的随机排列。 如果输入是一个数组,那么函数会返回该数组的一个随机排列。

1.4 创建模型,训练模型

python 复制代码
# 创建一个 K 近邻分类器,默认邻居数为 5
knn = KNeighborsClassifier()
# 使用训练数据拟合模型
knn.fit(iris_x_train, iris_y_train)

1.5 应用测试数据集分类

python 复制代码
# 使用模型进行预测
iris_y_predict = knn.predict(iris_x_test)
print('iris_y_predict = ')
print(iris_y_predict)
#输出原始测试数据集的正确标签,以方便对比
print('iris_y_test = ')
print(iris_y_test)
#输出准确率计算结果
print('Accuracy:',score)
iris_y_predict = 
[2 1 2 0 1 1 0 1 2 2]
iris_y_test = 
[2 1 2 0 1 2 0 1 2 2]
Accuracy: 0.9

二、KNN模型-knn.kneighbors()

python 复制代码
neighborpoint=knn.kneighbors([iris_x_test[-1]],5)
#这里是引用[iris_x_test[-1]]#是测试数据集中的最后一个样本第一个参数就是测试数据集里面倒数最后一个样本,5表示返回5个邻居
neighborpoint
(array([[0.26457513, 0.52915026, 0.54772256, 0.54772256, 0.60827625]]),
 array([[  4,  30, 111,  55,  74]], dtype=int64))

KNeighborsClassifier(n_neighbors=5) 是用来进行分类预测的模型,而knn.kneighbors 是用来寻找最近邻居的方法。返回的是最近邻居的索引和对应的距离

knn.kneighbors 方法接受一个数据点或一组数据点作为输入,并返回这些数据点的每个最近邻居的索引和对应的距离。

X: 一个包含数据点特征的数组,用于寻找最近邻居。

n_neighbors(可选):一个整数,表示要返回的最近邻居的数量。如果未指定,则默认为 k,即 KNeighborsClassifier 初始化时设置的邻居数。

python 复制代码
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
# 创建一些示例训练数据
X_train = np.array([[1, 2],[2, 3],[3, 4],[4, 5]])
y_train = np.array([0, 0, 1, 1])  # 样本的类别标签

# 创建一个新的测试数据点
X_test = np.array([[2, 3]])
knn0 = KNeighborsClassifier(n_neighbors=2)
knn0.fit(X_train, y_train)
point=knn0.kneighbors(X_test,2)
point
(array([[0.        , 1.41421356]]), array([[1, 0]], dtype=int64))

三、KNN模型-knn.predict_proba()

python 复制代码
probility=knn.predict_proba(iris_x_test)
predict=knn.predict(iris_x_test)
probility,predict
(array([[0. , 0. , 1. ],
        [0. , 1. , 0. ],
        [0. , 0. , 1. ],
        [1. , 0. , 0. ],
        [0. , 1. , 0. ],
        [0. , 1. , 0. ],
        [1. , 0. , 0. ],
        [0. , 0.8, 0.2],
        [0. , 0. , 1. ],
        [0. , 0. , 1. ]]),
 array([2, 1, 2, 0, 1, 1, 0, 1, 2, 2]))
python 复制代码
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
# 创建一些示例训练数据
X_train = np.array([[1, 2],[2, 3], [3, 4], [4, 5]])
y_train = np.array([0, 1, 1, 2])  # 样本的类别标签

# 创建一个新的测试数据点
X_test = np.array([[1.1, 1.1]])

# 创建并训练 K 近邻分类器
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)

# 预测测试数据点属于每个类别的概率
probabilities = knn.predict_proba(X_test)

print("测试数据点属于每个类别的概率:")
print(probabilities)
测试数据点属于每个类别的概率:
[[0.33333333 0.66666667 0.        ]]

假设我们有一个三分类问题,有以下训练数据:

类别 0:包含一个样本:[1, 2]

类别 1:包含二个样本:[2, 3],[3, 4]

类别 2:包含一个样本:[4, 5]

现在,我们使用一个 K 近邻分类器,设定 k=3,并且对一个新的测试样本 [1.1, 1.1] 进行分类
距离计算

测试样本 [1.1, 1.1] 与训练样本 [1, 2] 之间的距离:0.905

测试样本 [1.1, 1.1] 与训练样本 [2, 3] 之间的距离:1.825

测试样本 [1.1, 1.1] 与训练样本 [3, 4] 之间的距离:2.98

测试样本 [1.1, 1.1] 与训练样本 [4, 5] 之间的距离:4.242

确定最近的K个邻居

根据上述距离计算,最近的3(分几类就找几个邻居)个邻居分别是 [1, 2],[2, 3] 和 [3, 4]。

这三个邻居分别来自于类别0、类别1和类别1。因此,我们预测测试样本 [1.1, 1.1] 属于哪个类别?

类别0:出现1次

类别1:出现2次

类别2:出现0次

根据最近的3个邻居,测试样本 [1.1, 1.1] 最有可能属于类别1。

相关推荐
drebander几秒前
使用 Java Stream 优雅实现List 转化为Map<key,Map<key,value>>
java·python·list
tangliang_cn9 分钟前
java入门 自定义springboot starter
java·开发语言·spring boot
莫叫石榴姐10 分钟前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
程序猿阿伟10 分钟前
《智能指针频繁创建销毁:程序性能的“隐形杀手”》
java·开发语言·前端
新知图书21 分钟前
Rust编程与项目实战-模块std::thread(之一)
开发语言·后端·rust
威威猫的栗子23 分钟前
Python Turtle召唤童年:喜羊羊与灰太狼之懒羊羊绘画
开发语言·python
力透键背23 分钟前
display: none和visibility: hidden的区别
开发语言·前端·javascript
bluefox197924 分钟前
使用 Oracle.DataAccess.Client 驱动 和 OleDB 调用Oracle 函数的区别
开发语言·c#
ö Constancy1 小时前
c++ 笔记
开发语言·c++
ChaseDreamRunner1 小时前
迁移学习理论与应用
人工智能·机器学习·迁移学习