OpenCV与机器学习:使用opencv实现KNN

前言

KNN算法,即K-Nearest Neighbor,是一种简单的机器学习算法。它的基本思想是:如果一个样本在特征空间中的K个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN算法通过计算待分类样本与已知类别样本的距离,找出距离最近的K个样本,根据这K个样本的类别,通过投票的方式确定待分类样本的类别。

在KNN算法中,距离的度量方式主要有欧式距离、曼哈顿距离、余弦相似度等。其中,欧式距离是最常用的距离度量方式,计算的是两点之间的直线距离。

KNN算法的优点包括简单直观、易于实现、对异常点不敏感等。但同时,KNN算法也存在一些局限性,如对数据分布敏感、K值选择主观等。此外,对于大规模数据集和高维数据,KNN算法的效率可能会较低。

总的来说,KNN算法是一种基础且重要的机器学习算法,在数据分类、回归分析、聚类等领域都有广泛的应用。

OpenCV实现KNN

生成随机数据

python 复制代码
import numpy as np

def generate_data(num_samples, num_features=2):
    data_size = (num_samples, num_features)
    train_data = np.random.randint(0, 100, size=data_size)
    labels_size = (num_samples, 1)
    labels = np.random.randint(0, 2, size=labels_size)
    return train_data.astype(np.float32), labels

我们构建了一个函数来生成数据集,num_samples代表的是数据量,num_features默认为2代表的是数据有两个特征。通过随机数生成数据和标签。

python 复制代码
train_data, labels = generate_data(11)
train_data

生成的数据如下:

  • 注:由于随机种子的不同生成的数据很有可能不同,以实际运行结果为准(下面的画图也一样)

下面以第0个数据为例看一下数据和标签的对应

python 复制代码
train_data[0], labels[0]

使用matplotlib绘制数据

python 复制代码
import matplotlib.pyplot as plt

plt.style.use('ggplot')
plt.plot(train_data[0, 0], train_data[0, 1], color='r', marker='^', markersize=10)
plt.xlabel('x coordinate')
plt.ylabel('y coordinate')

在代码中plt.style.use('ggplot')的作用是为了使得画图更加好看。随后随便画了数据中的一个点的效果如下:

为了方便展示数据,我们在下面构建一个绘图函数用不同的颜色和不同的形状表示不同类别的数据。

python 复制代码
def plot_data(all_blue, all_red):
    plt.figure(figsize=(10, 6))
    plt.scatter(all_blue[:, 0], all_blue[:, 1], c='b', marker='s', s=180)
    plt.scatter(all_red[:, 0], all_red[:, 1], c='r', marker='^', s=180)
    plt.xlabel('x coordinate')
    plt.ylabel('y coordinate')

现在我们有了数据和展示函数,那么如何区分数据的类别呢,我们可以根据标签(labels)来选择不同类别的数据。不过生成的标签使用其实并不方便,在这里我们使用ravel()将数据展平。效果对比如下:

python 复制代码
labels, labels.ravel()

现在我们已经有了数据的选择方式和画图方式,现在我们来画一下不同类别的数据。

python 复制代码
blue = train_data[labels.ravel()==0]
red = train_data[labels.ravel()==1]
python 复制代码
plot_data(blue, red)

OpenCV用于KNN训练

在cv2中的ml模块中有一些常见的机器学习方法,使用KNearest_create可以创建一个KNN的分类器,训练时需要传入的数据有数据、数据组织方式、标签

python 复制代码
import cv2

knn = cv2.ml.KNearest_create()
knn.train(train_data, cv2.ml.ROW_SAMPLE, labels)

在代码中使用的数据组织方式为cv2.ml.ROW_SAMPLE,即每一行就是一条数据。

生成新数据并观察

训练完之后我们再生成一个不含标签的数据

python 复制代码
newcomer, _ = generate_data(1)
newcomer

画出原始数据和新生成的不含标签的数据,一起进行对比。

python 复制代码
plot_data(blue, red)
plt.plot(newcomer[0, 0], newcomer[0, 1], 'go', markersize=14)

在下面的图中,绿色的圆形就是新的数据,这个数据不含标签。我们观察到这个数据和蓝色方块很近,所以大概率是属于蓝色方块的。即属于第0类。

预测

我们先来使用最近的一个点来进行预测,这里findNearest第二个参数是1

python 复制代码
ret, results, neighbor, dist = knn.findNearest(newcomer, 1)
results, neighbor, dist

然后使用最近的三个点来进行预测,这里findNearest第二个参数是3

python 复制代码
ret, results, neighbor, dist = knn.findNearest(newcomer, 3)
results, neighbor, dist

可以看到预测结果是正确的

与此同时我们也可以使用predict方法进行预测。但是,首先我们需要设置k,我们要根据最近的k个 数据点来判断新数据具体属于哪一类。

使用方式如下:

python 复制代码
knn.setDefaultK(1)
knn.predict(newcomer)
python 复制代码
knn.setDefaultK(3)
knn.predict(newcomer)
相关推荐
ZHOU_WUYI1 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1231 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界2 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221512 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2512 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台
浊酒南街3 小时前
Statsmodels之OLS回归
人工智能·数据挖掘·回归
畅联云平台3 小时前
美畅物联丨智能分析,安全管控:视频汇聚平台助力智慧工地建设
人工智能·物联网
加密新世界3 小时前
优化 Solana 程序
人工智能·算法·计算机视觉
hunteritself4 小时前
ChatGPT高级语音模式正在向Web网页端推出!
人工智能·gpt·chatgpt·openai·语音识别
Che_Che_4 小时前
Cross-Inlining Binary Function Similarity Detection
人工智能·网络安全·gnn·二进制相似度检测