机器学习-朴素贝叶斯

文章目录

一、朴素贝叶斯简介

1.含义

朴素贝叶斯(Naive Bayes)是一种基于贝叶斯定理与特征条件独立假设的分类方法。它之所以被称为"朴素",是因为它假设特征之间相互独立,即一个特征的出现与另一个特征无关,这在现实世界中往往不成立,但这一假设使得朴素贝叶斯分类器变得简单且高效。

2.公式

P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) P(A∣B)=\frac{P(B|A)P(A)}{P(B)} P(A∣B)=P(B)P(B∣A)P(A)

其中:

P(A∣B) 是后验概率,即在给定数据B下,属于类别A的概率。

P(B∣A) 是似然概率,即在类别A下观测到数据B的概率。

P(A) 是先验概率,即类别A出现的概率。

P(B) 是证据因子,对于所有类别是相同的,因此不影响分类决策。

二、代码实现

下面这段代码的主要目的是使用朴素贝叶斯分类器来对鸢尾花数据集进行分类,代码实现了使用多项式朴素贝叶斯对鸢尾花数据集进行分类的基本流程,包括数据加载、预处理、模型训练、预测和性能评估,我们将其一步步拆分,进行更详细的讲解。

1.数据加载和预处理

python 复制代码
import pandas as pd

def cm_plot(y, yp):
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y, yp)
    plt.matshow(cm, cmap=plt.cm.Blues)
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), horizontalalignment='center', verticalalignment='center')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
    return plt

data = pd.read_csv("iris.csv")
data_a = data.drop(columns=data.columns[0])

x = data_a.iloc[:, :-1]
y = data_a.iloc[:, -1]
  • 使用pandas读取iris.csv文件,该文件应包含鸢尾花数据集。
  • 数据集的第一列是不需要的列,因此删除。
  • 数据集被分为特征集x和目标集y,其中x包含除最后一列外的所有列,y包含最后一列。

2.切分数据集

python 复制代码
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, random_state=0)
  • 使用train_test_split函数将数据集切分为训练集和测试集,测试集占总数据的20%,随机种子设为0,以便拥有可重复性。

3.模型训练

python 复制代码
from sklearn.naive_bayes import MultinomialNB#导入朴素贝叶斯分类器
classifier = MultinomialNB(alpha=1)
classifier.fit(x_train,y_train)
  • 使用多项式朴素贝叶斯分类器对训练集进行训练。

4.性能评估

python 复制代码
#绘制混淆矩阵
from sklearn import metrics
train_pred = classifier.predict(x_train)
cm_plot(y_train,train_pred).show()
print(metrics.classification_report(y_train, train_pred))
score = classifier.score(x_train, y_train)
print(score)
  • 对训练集进行预测,得到预测结果train_pred。
  • 使用cm_plot函数绘制训练集的混淆矩阵,并进行可视化。
  • 打印分类报告,该报告提供了主要分类指标的文本报告,如精确度、召回率、F1分数等。
  • 打印训练集上的准确度分数,来评估训练集上的性能。

5.测试集预测

python 复制代码
test_pred = classifier.predict(x_test)
cm_plot(y_test,test_pred).show()
print(metrics.classification_report(y_test, test_pred))
  • 对测试集进行预测,并展示测试集的混淆矩阵,最后将分类报告打印出来。

6.详细代码

python 复制代码
import pandas as pd

def cm_plot(y, yp):
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y, yp)
    plt.matshow(cm, cmap=plt.cm.Blues)
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), horizontalalignment='center', verticalalignment='center')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
    return plt

data = pd.read_csv("iris.csv")
data_a = data.drop(columns=data.columns[0])

x = data_a.iloc[:, :-1]
y = data_a.iloc[:, -1]

"""切分数据集"""
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, random_state=0)

from sklearn.naive_bayes import MultinomialNB#导入朴素贝叶斯分类器
classifier = MultinomialNB(alpha=1)
classifier.fit(x_train,y_train)

"""预测训练集"""
#绘制混淆矩阵
from sklearn import metrics
train_pred = classifier.predict(x_train)
cm_plot(y_train,train_pred).show()
print(metrics.classification_report(y_train, train_pred))
score = classifier.score(x_train, y_train)
print(score)
"""测试集预测"""
test_pred = classifier.predict(x_test)
cm_plot(y_test,test_pred).show()
print(metrics.classification_report(y_test, test_pred))

三、总结

朴素贝叶斯分类器因其简单性和高效性,在文本分类、垃圾邮件检测、情感分析等领域有着广泛的应用。但同时也有着自己的优缺点。

  • 优点
    • 简单高效:由于假设特征之间相互独立,大大简化了计算。
    • 处理缺失数据:对缺失数据不敏感,可以通过忽略该特征或使用该特征的先验概率来处理。
    • 易于实现:算法实现相对简单,易于理解和应用。
  • 缺点
    • 特征独立性假设:现实中特征之间往往存在相关性,这一假设限制了朴素贝叶斯的性能。
    • 参数估计问题:如果某个特征在训练数据中未出现,则条件概率为零,这会导致整个后验概率为零,即所谓的"零概率问题"。
相关推荐
Mr.Winter`4 分钟前
运动规划实战案例 | 基于四叉树分解的路径规划(附ROS C++/Python仿真)
人工智能·机器人·自动驾驶·ros·计算机图形学·ros2·路径规划
神经星星5 分钟前
SEER只是开始?美国NIH发文禁止中国用户访问生物医学核心数据,国产数据库已就位
人工智能·机器学习·开源
Helios@5 分钟前
BN测试和训练时有什么不同, 在测试时怎么使用?
人工智能·深度学习·机器学习
HelloDam6 分钟前
leetcode28.找出字符串中第一个匹配项的下标,KMP算法保姆级教程(带动图)
java·后端·算法
EasyGBS10 分钟前
国标GB28181视频平台EasyCVR顺应智慧农业自动化趋势,打造大棚实时视频监控防线
大数据·网络·人工智能·安全·音视频
智驱力人工智能12 分钟前
打造船岸“5G+AI”智能慧眼 智驱力赋能客船数智管理
人工智能·5g·智能驾驶·视觉分析·智慧传播·智慧海防·智能巡航
不是编程家18 分钟前
优选算法第七讲:分治
算法
LuckyLay25 分钟前
LeetCode算法题(Go语言实现)_36
算法·leetcode·golang
S01d13r28 分钟前
LeetCode 解题思路 33(Hot 100)
javascript·算法·leetcode
cskywit28 分钟前
CNN注意力机制的进化史:深度解析10种注意力模块如何重塑卷积神经网络
人工智能·神经网络·cnn