基于最近邻数据进行分类

人工智能例子汇总:AI常见的算法和例子-CSDN博客

完整代码:

python 复制代码
import torch
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# 生成一个简单的数据集 (2个特征和2个分类)
# X为输入特征,y为标签
X = np.array([[1, 2], [2, 3], [3, 4], [5, 7], [6, 8], [7, 9], [8, 10], [3, 6], [4, 5], [6, 4]])
y = np.array([0, 0, 0, 1, 1, 1, 1, 0, 0, 1])

# 数据转换为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)

# 打印数据
print("Features:")
print(X_tensor)
print("Labels:")
print(y_tensor)

# 使用 sklearn KNN 分类器,调整邻居数量为 5
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X, y)

# 预测
y_pred = knn.predict(X)

# 计算准确率
accuracy = accuracy_score(y, y_pred)
print(f"Accuracy: {accuracy * 100:.2f}%")

# 可视化数据
plt.figure(figsize=(6, 4))
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='bwr', marker='o', edgecolor='k', s=100)
plt.title("KNN Classification Example")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()

# 测试:给定新的输入数据进行预测
test_data = np.array([[5, 6], [2, 3]])
test_prediction = knn.predict(test_data)

print(f"Predictions for test data {test_data} are {test_prediction}")
  • 生成数据 :创建了一个具有 2 个特征和 2 个类别标签的数据集。X 是输入特征,y 是标签。
  • 转换为 PyTorch 张量:虽然这里我们不需要在 KNN 算法中使用 PyTorch,但我们将数据转换为 PyTorch 张量,显示如何与 PyTorch 数据结构进行交互。
  • KNN 分类器 :使用 sklearn.neighbors.KNeighborsClassifier 创建并训练 KNN 模型。我们将 n_neighbors 设置为 5,即选择 5 个最近邻。
  • 预测与准确率:使用训练好的模型对所有数据进行预测,并计算准确率。
  • 可视化 :使用 matplotlib 将数据点可视化,数据点的颜色根据标签进行区分。
  • 测试预测 :我们对新的测试数据点 [5, 6][2, 3] 进行预测。
  • 结果:
python 复制代码
Features:
tensor([[ 1.,  2.],
        [ 2.,  3.],
        [ 3.,  4.],
        [ 5.,  7.],
        [ 6.,  8.],
        [ 7.,  9.],
        [ 8., 10.],
        [ 3.,  6.],
        [ 4.,  5.],
        [ 6.,  4.]])
Labels:
tensor([0, 0, 0, 1, 1, 1, 1, 0, 0, 1])
Accuracy: 90.00%
Predictions for test data [[5 6]
 [2 3]] are [1 0]
相关推荐
大模型最新论文速读16 小时前
ProFit: 屏蔽低概率 token,解决 SFT 过拟合问题
人工智能·深度学习·机器学习·语言模型·自然语言处理
cskywit16 小时前
VMamba环境本地适配配置
人工智能·深度学习·mamba
victory043117 小时前
minimind SFT失败原因排查和解决办法
人工智能·python·深度学习
逐梦苍穹17 小时前
世界模型通俗讲解:AI大脑里的“物理模拟器“
人工智能·世界模型
发哥来了17 小时前
主流AI视频生成工具商用化能力评测:五大关键维度对比分析
大数据·人工智能·音视频
跳跳糖炒酸奶17 小时前
基于深度学习的单目深度估计综述阅读(1)
人工智能·深度学习·数码相机·单目深度估计
yangpipi-17 小时前
第一章 语言模型基础
人工智能·语言模型·自然语言处理
Piar1231sdafa17 小时前
基于yolo13-C3k2-RVB的洗手步骤识别与检测系统实现_1
人工智能·算法·目标跟踪
做科研的周师兄17 小时前
【MATLAB 实战】|多波段栅格数据提取部分波段均值——批量处理(NoData 修正 + 地理信息保真)_后附完整代码
前端·算法·机器学习·matlab·均值算法·分类·数据挖掘
小北方城市网17 小时前
SpringBoot 集成 MyBatis-Plus 实战(高效 CRUD 与复杂查询):简化数据库操作
java·数据库·人工智能·spring boot·后端·安全·mybatis