机器学习笔记——K近邻算法、手写数字识别

KNN算法

"物以类聚,人以群分"相似的数据往往拥有相同的类别

其大概原理就是一个样本归到哪一类,当前样本需要归到频次最高的哪个类去

也就是说有一个待分类的样本,然后跟他周围的k个样本来看,k中哪一个类最多,待分类的样本就是哪一个。

那就以手写数字识别为例吧

c 复制代码
import matplotlib.pyplot as plt
import numpy as np
import os
#%%
# 读入mnist数据集
m_x = np.loadtxt('./data/mnist_x', delimiter=' ')
m_y = np.loadtxt('./data/mnist_y')
#%%
# 数据集可视化
data = np.reshape(np.array(m_x[0], dtype=int), [28, 28])
plt.figure()
plt.imshow(data, cmap='gray')
#%%
# 将数据集分为训练集和测试集
ratio = 0.8
split = int(len(m_x) * ratio)
# 打乱数据
np.random.seed(0)
idx = np.random.permutation(np.arange(len(m_x))) #随机排序
m_x = m_x[idx]
m_y = m_y[idx]
x_train, x_test = m_x[:split], m_x[split:]
y_train, y_test = m_y[:split], m_y[split:]
#%%
#定义距离函数
def distance(x,y):
    return np.sqrt(np.sum(np.square(x-y)))

#%%
#定义KNN模型
class KNN:
    def __init__(self,k,label_num):
        self.k=k
        self.label_num=label_num #类别的数量
    def fit(self,x_train,y_train):
        self.x_train=x_train
        self.y_train=y_train
    def get_knn_indices(self,x): #获得距离目标样本最近的k个点的标签,a来做self_x.train
        dis=list(map(lambda a:distance(a,x),self.x_train))
        knn_indices=np.argsort(dis) #对距离排序,在选择k个出来
        knn_indices=knn_indices[:self.k]#标签
        return knn_indices
     def get_label(self,x):#计算k个点中,样本的标签数量是多少
         knn_indices=self.get_knn_indices(x)
         label_statistic=np.zeros(shape=[self.label_num])
         for index in knn_indices:
             label=int(self.y_train[index])
             label_statistic[label]+=1
         return np.argmax(label_statistic) #找出最大的类别
     def predict(self,x_test):
         predicted_test_labels=np.zeros(shape=[len(x_test)],dtype=int)
         for i,x in enumerate(x_test): #枚举
             predicted_test_labels[i]=self.get_label(x)
         return predicted_test_labels

#%%
for k in range(1,10):
    knn=KNN(k,label_num=10)
    knn.fit(x_train,y_train)
    predicted_labels=knn.predict(x_test)
    accuracy=np.mean(predicted_labels==y_test)
    print(f'k的取值为{k},预测准确率为{accuracy*100:.lf}%')
相关推荐
嵌入式@秋刀鱼2 小时前
《第四章-筋骨淬炼》 C++修炼生涯笔记(基础篇)数组与函数
开发语言·数据结构·c++·笔记·算法·链表·visual studio code
嵌入式@秋刀鱼2 小时前
《第五章-心法进阶》 C++修炼生涯笔记(基础篇)指针与结构体⭐⭐⭐⭐⭐
c语言·开发语言·数据结构·c++·笔记·算法·visual studio code
m0_678693332 小时前
深度学习笔记26-天气预测(Tensorflow)
笔记·深度学习·tensorflow
硅谷秋水2 小时前
NORA:一个用于具身任务的小型开源通才视觉-语言-动作模型
人工智能·深度学习·机器学习·计算机视觉·语言模型·机器人
桂?2 小时前
使用离线依赖解决Android Studio编译报错(下载不了jar)——笔记
笔记·android studio·jar
EQ-雪梨蛋花汤4 小时前
【Unity笔记】Unity Animation组件使用详解:Play方法重载与动画播放控制
笔记·unity·游戏引擎
scdifsn4 小时前
动手学深度学习13.3. 目标检测和边界框-笔记&练习(PyTorch)
笔记·深度学习·目标检测·目标识别·标注边界框
霸王蟹4 小时前
前端项目Excel数据导出同时出现中英文表头错乱情况解决方案。
笔记·学习·typescript·excel·vue3·react·vite
LuH11244 小时前
【论文阅读笔记】ICLR 2025 | 解析Ref-Gaussian如何实现高质量可交互反射渲染
论文阅读·笔记·论文笔记
mwicogito5 小时前
实验复现:应用 RIR 触发器的 TrojanRoom 后门攻击实现
人工智能·python·机器学习·语音识别·后门攻击