使用scikit-learn中的KNN包实现对鸢尾花数据集的预测

1. 导入必要的库

首先,需要导入所需的库:

python 复制代码
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report
复制代码

2. 加载鸢尾花数据集

scikit-learn提供了方便的函数来加载鸢尾花数据集:

python 复制代码
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

3. 数据预处理

对数据进行标准化处理,以提高KNN算法的性能:

python 复制代码
# 标准化数据
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

4. 训练KNN模型

使用KNeighborsClassifier来训练KNN模型:

python 复制代码
# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn.fit(X_train, y_train)

5. 进行预测并评估模型

使用测试集进行预测,并评估模型的性能:

python 复制代码
# 进行预测
y_pred = knn.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')

# 打印分类报告
print(classification_report(y_test, y_pred, target_names=iris.target_names))

6. 使用自定义数据集

如果有一个自定义的数据集,可以按照以下步骤进行操作。假设有一个CSV文件custom_dataset.csv,其中包含特征和标签。

python 复制代码
# 加载自定义数据集
custom_data = pd.read_csv('custom_dataset.csv')

# 假设特征列为前n-1列,标签列为最后一列
X_custom = custom_data.iloc[:, :-1].values
y_custom = custom_data.iloc[:, -1].values

# 将数据集分为训练集和测试集
X_custom_train, X_custom_test, y_custom_train, y_custom_test = train_test_split(X_custom, y_custom, test_size=0.2, random_state=42)

# 标准化数据
scaler_custom = StandardScaler()
X_custom_train = scaler_custom.fit_transform(X_custom_train)
X_custom_test = scaler_custom.transform(X_custom_test)

# 创建KNN分类器并训练
knn_custom = KNeighborsClassifier(n_neighbors=3)
knn_custom.fit(X_custom_train, y_custom_train)

# 进行预测
y_custom_pred = knn_custom.predict(X_custom_test)

# 计算准确率
accuracy_custom = accuracy_score(y_custom_test, y_custom_pred)
print(f'Custom Dataset Accuracy: {accuracy_custom:.2f}')

# 如果有标签名称,可以打印分类报告
# print(classification_report(y_custom_test, y_custom_pred, target_names=[...]))
相关推荐
Mister Leon6 分钟前
机器学习Adaboost算法----SAMME算法和SAMME.R算法
算法·机器学习·r语言
zzywxc78717 分钟前
深入探讨AI在测试领域的三大核心应用:自动化测试框架、智能缺陷检测和A/B测试优化,并通过代码示例、流程图和图表详细解析其实现原理和应用场景。
运维·人工智能·低代码·架构·自动化·流程图·ai编程
zskj_zhyl23 分钟前
七彩喜智慧康养:用“适老化设计”让科技成为老人的“温柔拐杖”
大数据·人工智能·科技·机器人·生活
ARM+FPGA+AI工业主板定制专家28 分钟前
基于ARM+FPGA多通道超声信号采集与传输系统设计
linux·人工智能·fpga开发·rk3588·rk3568·codesys
wyiyiyi1 小时前
【目标检测】芯片缺陷识别中的YOLOv12模型、FP16量化、NMS调优
人工智能·yolo·目标检测·计算机视觉·数学建模·性能优化·学习方法
mit6.8242 小时前
[自动化Adapt] 回放策略 | AI模型驱动程序
运维·人工智能·自动化
爱看科技3 小时前
5G-A技术浪潮勾勒通信产业新局,微美全息加快以“5.5G+ AI”新势能深化场景应用
人工智能·5g
打马诗人5 小时前
【YOLO11】【DeepSort】【NCNN】使用YOLOv11和DeepSort进行行人目标跟踪。(基于ncnn框架,c++实现)
人工智能·算法·目标检测
倒悬于世6 小时前
基于千问2.5-VL-7B训练识别人的表情
人工智能
大知闲闲哟6 小时前
深度学习TR3周:Pytorch复现Transformer
pytorch·深度学习·transformer