[机器学习]04-基于K近邻(KNN)的鸢尾花数据集分类

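该方法使用 k 近邻(KNN)密度估计来计算概率密度:

$$\hat{p}(x) = \frac{k}{N \cdot V_k(x)}$$

  • $k$:近邻数

  • $N$:总样本数

  • $V_k(x)$:以 $x$ 为中心、到第 $k$ 个近邻的距离 $d_k(x)$ 为半径的超球体体积(本例使用两个特征,故 $V_k(x) = \pi\, d_k(x)^2$)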

程序代码:

python 复制代码
import math
import random
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator

# Per-class feature store: {label: {'Length': [...], 'Width': [...]}}.
data_dict = {}
# Populated by the train/test split later in the script.
train_data = {}
test_data = {}

# Each usable row is tab-separated: <label> \t <attribute 1> \t <attribute 2> [...]
with open('Iris数据txt版.txt', 'r') as file:
    for line in file:
        line = line.strip()
        data = line.split('\t')
        if len(data) < 3:
            # Blank or short rows carry no sample; skip them silently.
            continue
        category = data[0]
        try:
            # float() rather than eval(): eval() would execute arbitrary
            # expressions found in the file and signals junk with
            # NameError/SyntaxError, which the ValueError handler below
            # never caught. Whole-number fields such as "3" now load as 3.0.
            attribute1 = float(data[1])
            attribute2 = float(data[2])
        except ValueError:
            print(f"Invalid data in line: {line}")
            continue

        # Create the per-class lists on first sight of a label, and only
        # after both attributes parsed successfully.
        features = data_dict.setdefault(category, {'Length': [], 'Width': []})
        features['Length'].append(attribute1)
        features['Width'].append(attribute2)

# Dump what was loaded: label, sample counts, and the raw attribute lists.
for category, attributes in data_dict.items():
    print(f'种类: {category}')
    print(len(attributes["Length"]))
    print(len(attributes["Width"]))
    print(f'属性1: {attributes["Length"]}')
    print(f'属性2: {attributes["Width"]}')

# Randomly split every class into a training part and a held-out test part.
train_size = 45  # samples drawn per class for training; the rest are tested

for category, attributes in data_dict.items():
    lengths = attributes['Length']
    widths = attributes['Width']
    # random.sample raises ValueError if a class has fewer than train_size rows.
    train_indices = random.sample(range(len(lengths)), train_size)
    # Set membership keeps the complement computation linear instead of
    # scanning the sampled list once per candidate index.
    train_index_set = set(train_indices)
    test_indices = [i for i in range(len(lengths)) if i not in train_index_set]

    train_data[category] = {
        'Length': [lengths[i] for i in train_indices],
        'Width': [widths[i] for i in train_indices]
    }

    test_data[category] = {
        'Length': [lengths[i] for i in test_indices],
        'Width': [widths[i] for i in test_indices]
    }

# Equal prior for every class (one over the number of classes).
prior_rate = 1.0/len(data_dict)

# Show the split for every class that was actually loaded, instead of
# assuming the labels are exactly '1', '2' and '3'.
for category in train_data:
    print(len(train_data[category]['Length']))
    print(train_data[category])
    print(len(test_data[category]['Length']))
    print(test_data[category])

# Estimate each class's probability density with the classical k-nearest-
# neighbour rule p(x) = kn / (n * V), where V is the area of the circle
# centred on x whose radius is the distance to the kn-th nearest training
# sample of that class (two features, so V = pi * r^2).
kn = 7   # kn = sqrt(n), rounded
n = 45   # training samples per class; must match the size used by the split

# Evaluation grid with a 0.1 step. 41 and 26 points make the linspace step
# exactly 0.1, so after rounding every one-decimal value in [4, 8] x [2, 4.5]
# is present. With 40 and 25 points the step was 4/39 and 2.5/24, which left
# holes in the grid (for example 6.0 and 3.3 never appeared).
length_range = np.around(np.linspace(4, 8, 41), 1)
width_range = np.around(np.linspace(2, 4.5, 26), 1)
length_mesh, width_mesh = np.meshgrid(length_range, width_range)

# Convert each class's training features to arrays once, not per grid point.
# Order follows train_data, so densities line up with its iteration order.
train_arrays = [
    (np.array(attributes1['Length']), np.array(attributes1['Width']))
    for attributes1 in train_data.values()
]

# Maps (length, width) -> list of per-class densities.
distances_dict = {}
for length in length_range:
    for width in width_range:
        distances = []
        for lengths, widths in train_arrays:
            distances_to_points = np.sqrt((lengths - length) ** 2 + (widths - width) ** 2)
            # np.partition places the kn-th smallest distance at index kn-1;
            # no prior sort is needed.
            n_shortest_distance = np.partition(distances_to_points, kn - 1)[kn - 1]
            radius_sq = math.pow(n_shortest_distance, 2)
            # kn coincident training points give a zero radius; treat the
            # density as unbounded rather than dividing by zero.
            density = kn / (n * math.pi * radius_sq) if radius_sq > 0 else math.inf
            distances.append(density)
        distances_dict[(length, width)] = distances

for key, value in distances_dict.items():
    print(f'坐标点: {key}, 概率密度: {value}')

# Classify every held-out sample by the class with the highest estimated
# density at its coordinates and report the fraction classified correctly.
right = 0
total = 0  # named "total" so the builtin all() is not shadowed

# Per-point densities are stored in train_data's iteration order, so the
# position of the largest density maps back to a label through this list.
labels = list(train_data)

for category2, data2 in test_data.items():
    print(category2, data2)
    for i, j in zip(data2['Length'], data2['Width']):
        print(i, j)
        # Direct lookup replaces a scan over every grid point per sample.
        value = distances_dict.get((i, j))
        if value is None:
            # Samples off the grid cannot be scored; say so instead of
            # silently leaving them out of the accuracy.
            print(f"No density for test point ({i}, {j}); sample skipped")
            continue
        print(value)
        predict = labels[value.index(max(value))]
        print(category2, predict)
        total += 1
        if category2 == predict:
            right += 1

# Guard against the case where no test sample could be scored.
print("正确率:", right / total if total else float('nan'))

运行结果:

种类: 1

50

50

属性1: [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.8, 4.8, 4.3, 5.8, 5.7, 5.4, 5.1, 5.7, 5.1, 5.4, 5.1, 4.6, 5.1, 4.8, 5.0, 5.0, 5.2, 5.2, 4.7, 4.8, 5.4, 5.2, 5.5, 4.9, 5.0, 5.5, 4.9, 4.4, 5.1, 5.0, 4.5, 4.4, 5.0, 5.1, 4.8, 5.1, 4.6, 5.3, 5.0]

属性2: [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.4, 3.0, 3.0, 4.0, 4.4, 3.9, 3.5, 3.8, 3.8, 3.4, 3.7, 3.6, 3.3, 3.4, 3.0, 3.4, 3.5, 3.4, 3.2, 3.1, 3.4, 4.1, 4.2, 3.1, 3.2, 3.5, 3.6, 3.0, 3.4, 3.5, 2.3, 3.2, 3.5, 3.8, 3.0, 3.8, 3.2, 3.7, 3.3]

种类: 2

50

50

属性1: [7.0, 6.4, 6.9, 5.5, 6.5, 5.7, 6.3, 4.9, 6.6, 5.2, 5.0, 5.9, 6.0, 6.1, 5.6, 6.7, 5.6, 5.8, 6.2, 5.6, 5.9, 6.1, 6.3, 6.1, 6.4, 6.6, 6.8, 6.7, 6.0, 5.7, 5.5, 5.5, 5.8, 6.0, 5.4, 6.0, 6.7, 6.3, 5.6, 5.5, 5.5, 6.1, 5.8, 5.0, 5.6, 5.7, 5.7, 6.2, 5.1, 5.7]

属性2: [3.2, 3.2, 3.1, 2.3, 2.8, 2.8, 3.3, 2.4, 2.9, 2.7, 2.0, 3.0, 2.2, 2.9, 2.9, 3.1, 3.0, 2.7, 2.2, 2.5, 3.2, 2.8, 2.5, 2.8, 2.9, 3.0, 2.8, 3.0, 2.9, 2.6, 2.4, 2.4, 2.7, 2.7, 3.0, 3.4, 3.1, 2.3, 3, 2.5, 2.6, 3.0, 2.6, 2.3, 2.7, 3.0, 2.9, 2.9, 2.5, 2.8]

种类: 3

50

50

属性1: [6.3, 5.8, 7.1, 6.3, 6.5, 7.6, 4.9, 7.3, 6.7, 7.2, 6.5, 6.4, 6.8, 5.7, 5.8, 6.4, 6.5, 7.7, 7.7, 6.0, 6.9, 5.6, 7.7, 6.3, 6.7, 7.2, 6.2, 6.1, 6.4, 7.2, 7.4, 7.9, 6.4, 6.3, 6.1, 7.7, 6.3, 6.4, 6.0, 6.9, 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9]

属性2: [3.3, 2.7, 3.0, 2.9, 3.0, 3.0, 2.5, 2.9, 2.5, 3.6, 3.2, 2.7, 3.0, 2.5, 2.8, 3.2, 3.0, 3.8, 2.6, 2.2, 3.2, 2.8, 2.8, 2.7, 3.3, 3.2, 2.8, 3.0, 2.8, 3.0, 2.8, 3.8, 2.8, 2.8, 2.6, 3.0, 3.4, 3.1, 3.0, 3.1, 3.1, 3.1, 2.7, 3.2, 3.3, 3, 2.5, 3, 3.4, 3]

1 {'Length': [4.6, 5.8, 5.1, 5.5, 5.0], 'Width': [3.4, 4.0, 3.8, 3.5, 3.3]}

4.6 3.4

0.6189358898018181, 0.03614224174025197, 0.020545589702964818

1 1

5.8 4.0

0.14563197407101539, 0.04542648732490388, 0.04542648732490388

1 1

5.1 3.8

0.5501652353793919, 0.04058595998700429, 0.0301919946244788

1 1

5.5 3.5

0.38088362449342483, 0.12076797849791507, 0.07617672489868499

1 1

5.0 3.3

2 {'Length': [6.0, 5.6, 5.5, 5.7, 5.7], 'Width': [2.2, 3.0, 2.6, 3.0, 2.8]}

6.0 2.2

5.6 3.0

0.12076797849791528, 0.5501652353793887, 0.12378717796036282

2 2

5.5 2.6

0.076176724898685, 0.6189358898018155, 0.1207679784979152

2 2

5.7 3.0

0.09522090612335607, 0.5501652353793919, 0.17074093511774208

2 2

5.7 2.8

0.0728159870355077, 1.2378717796036336, 0.24757435592072652

2 2

3 {'Length': [5.8, 6.4, 6.9, 5.8, 6.5], 'Width': [2.7, 3.1, 3.1, 2.7, 3]}

5.8 2.7

0.05501652353793919, 1.2378717796036283, 0.2912639481420303

3 2

6.4 3.1

0.030946794490090756, 0.5501652353793919, 0.5501652353793919

3 2

6.9 3.1

0.016957147665803148, 0.4951487118414499, 0.9902974236829026

3 3

5.8 2.7

0.05501652353793919, 1.2378717796036283, 0.2912639481420303

3 2

6.5 3

0.02552312947636352, 0.9902974236829026, 0.9902974236829043

3 3

正确率: 0.7692307692307693

相关推荐
魔力之心13 分钟前
actuary notes[1]
人工智能·概率
Fine姐24 分钟前
数据挖掘2.3-2.5:梯度,梯度下降以及凸性
人工智能·数据挖掘
瓦香钵钵鸡1 小时前
机器学习通关秘籍|Day 03:决策树、随机森林与线性回归
决策树·随机森林·机器学习·线性回归·最小二乘法·损失函数·信息熵
2501_924730611 小时前
智慧城管复杂人流场景下识别准确率↑32%:陌讯多模态感知引擎实战解析
大数据·人工智能·算法·计算机视觉·目标跟踪·视觉检测·边缘计算
CONDIMENTTTT1 小时前
[机器学习]05-基于Fisher线性判别的鸢尾花数据集分类
人工智能·分类·数据挖掘
Kingfar_11 小时前
智能移动终端导航APP用户体验研究案例分享
人工智能·算法·人机交互·ux·用户界面·用户体验
dlraba8022 小时前
机器学习-----SVM(支持向量机)算法简介
算法·机器学习·支持向量机
攻城狮7号2 小时前
小米开源大模型 MiDashengLM-7B:不仅是“听懂”,更能“理解”声音
人工智能·midashenglm-7b·小米开源大模型·声音理解大模型
程序边界2 小时前
AI鉴伪技术:守护数字时代的真实性防线
人工智能
bryant_meng2 小时前
【DeepID】《Deep Learning Face Representation from Predicting 10,000 Classes》
人工智能·深度学习·人脸识别·verification·identification