自定义数据集 ,使用朴素贝叶斯对其进行分类

python 复制代码
import numpy as np
import matplotlib.pyplot as plt

class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

class3_points = np.array([[3.3, 1.2],
                          [3.8, 0.9],
                          [3.3, 0.6],
                          [2.8, 1.3],
                          [3.5, 0.6],
                          [4.2, 0.3],
                          [3.1, 0.9]])

X=np.concatenate((class1_points,class2_points,class3_points),axis=0)

Y=np.concatenate((np.zeros(len(class1_points)),np.ones(len(class1_points)),np.ones(len(class1_points))+1),axis=0)

print(Y)

prior_prob=[np.sum(Y==0)/len(Y),np.sum(Y==1)/len(Y),np.sum(Y==2)/len(Y)]

class_u=[np.mean(X[Y==0],axis=0),np.mean(X[Y==1],axis=0),np.mean(X[Y==2],axis=0)]

class_cov=[np.cov(X[Y==0],rowvar=False),np.cov(X[Y==1],rowvar=False),np.cov(X[Y==2],rowvar=False)]

def pdf(x, mean, cov):
    n = len(mean)
    coff = 1 / (2 * np.pi) ** (n / 2) * np.sqrt(np.linalg.det(cov))
    exponent = np.exp(-(1 / 2) * np.dot(np.dot((x - mean).T, np.linalg.inv(cov)), (x - mean)))
    return coff * exponent

xx, yy = np.meshgrid(np.arange(0, 5, 0.05), np.arange(0, 4, 0.05))

grid_points = np.c_[xx.ravel(), yy.ravel()]

grid_label = []

for point in grid_points:
    poster_prob = []
    for i in range(3):
        likelihood = pdf(point, class_u[i], class_cov[i])
        poster_prob.append(prior_prob[i] * likelihood)
    pre_class = np.argmax(poster_prob)
    grid_label.append(pre_class)

grid_label = np.array(grid_label)

pre_grid_label = grid_label.reshape(xx.shape)

plt.scatter(class1_points[:,0],class1_points[:,1],c="blue",label="class 1")
plt.scatter(class2_points[:,0],class2_points[:,1],c="red",label="class 2")
plt.scatter(class3_points[:,0],class3_points[:,1],c="yellow",label="class 3")

plt.legend()

contour=plt.contour(xx,yy,pre_grid_label,colors='green')

plt.show()
相关推荐
冷雨夜中漫步5 小时前
Python快速入门(6)——for/if/while语句
开发语言·经验分享·笔记·python
郝学胜-神的一滴6 小时前
深入解析Python字典的继承关系:从abc模块看设计之美
网络·数据结构·python·程序人生
百锦再6 小时前
Reactive编程入门:Project Reactor 深度指南
前端·javascript·python·react.js·django·前端框架·reactjs
喵手7 小时前
Python爬虫实战:旅游数据采集实战 - 携程&去哪儿酒店机票价格监控完整方案(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集结果csv导出·旅游数据采集·携程/去哪儿酒店机票价格监控
2501_944934737 小时前
高职大数据技术专业,CDA和Python认证优先考哪个?
大数据·开发语言·python
helloworldandy8 小时前
使用Pandas进行数据分析:从数据清洗到可视化
jvm·数据库·python
肖永威9 小时前
macOS环境安装/卸载python实践笔记
笔记·python·macos
TechWJ9 小时前
PyPTO编程范式深度解读:让NPU开发像写Python一样简单
开发语言·python·cann·pypto
枷锁—sha9 小时前
【SRC】SQL注入WAF 绕过应对策略(二)
网络·数据库·python·sql·安全·网络安全
abluckyboy10 小时前
Java 实现求 n 的 n^n 次方的最后一位数字
java·python·算法