机器学习-朴素贝叶斯

文章目录

一、朴素贝叶斯简介

1.含义

朴素贝叶斯(Naive Bayes)是一种基于贝叶斯定理与特征条件独立假设的分类方法。它之所以被称为"朴素",是因为它假设特征之间相互独立,即一个特征的出现与另一个特征无关,这在现实世界中往往不成立,但这一假设使得朴素贝叶斯分类器变得简单且高效。

2.公式

P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) P(A∣B)=\frac{P(B|A)P(A)}{P(B)} P(A∣B)=P(B)P(B∣A)P(A)

其中:

P(A∣B) 是后验概率,即在给定数据B下,属于类别A的概率。

P(B∣A) 是似然概率,即在类别A下观测到数据B的概率。

P(A) 是先验概率,即类别A出现的概率。

P(B) 是证据因子,对于所有类别是相同的,因此不影响分类决策。

二、代码实现

下面这段代码的主要目的是使用朴素贝叶斯分类器来对鸢尾花数据集进行分类,代码实现了使用多项式朴素贝叶斯对鸢尾花数据集进行分类的基本流程,包括数据加载、预处理、模型训练、预测和性能评估,我们将其一步步拆分,进行更详细的讲解。

1.数据加载和预处理

python 复制代码
import pandas as pd

def cm_plot(y, yp):
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y, yp)
    plt.matshow(cm, cmap=plt.cm.Blues)
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), horizontalalignment='center', verticalalignment='center')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
    return plt

data = pd.read_csv("iris.csv")
data_a = data.drop(columns=data.columns[0])

x = data_a.iloc[:, :-1]
y = data_a.iloc[:, -1]
  • 使用pandas读取iris.csv文件,该文件应包含鸢尾花数据集。
  • 数据集的第一列是不需要的列,因此删除。
  • 数据集被分为特征集x和目标集y,其中x包含除最后一列外的所有列,y包含最后一列。

2.切分数据集

python 复制代码
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, random_state=0)
  • 使用train_test_split函数将数据集切分为训练集和测试集,测试集占总数据的20%,随机种子设为0,以便拥有可重复性。

3.模型训练

python 复制代码
from sklearn.naive_bayes import MultinomialNB#导入朴素贝叶斯分类器
classifier = MultinomialNB(alpha=1)
classifier.fit(x_train,y_train)
  • 使用多项式朴素贝叶斯分类器对训练集进行训练。

4.性能评估

python 复制代码
#绘制混淆矩阵
from sklearn import metrics
train_pred = classifier.predict(x_train)
cm_plot(y_train,train_pred).show()
print(metrics.classification_report(y_train, train_pred))
score = classifier.score(x_train, y_train)
print(score)
  • 对训练集进行预测,得到预测结果train_pred。
  • 使用cm_plot函数绘制训练集的混淆矩阵,并进行可视化。
  • 打印分类报告,该报告提供了主要分类指标的文本报告,如精确度、召回率、F1分数等。
  • 打印训练集上的准确度分数,来评估训练集上的性能。

5.测试集预测

python 复制代码
test_pred = classifier.predict(x_test)
cm_plot(y_test,test_pred).show()
print(metrics.classification_report(y_test, test_pred))
  • 对测试集进行预测,并展示测试集的混淆矩阵,最后将分类报告打印出来。

6.详细代码

python 复制代码
import pandas as pd

def cm_plot(y, yp):
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y, yp)
    plt.matshow(cm, cmap=plt.cm.Blues)
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), horizontalalignment='center', verticalalignment='center')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
    return plt

data = pd.read_csv("iris.csv")
data_a = data.drop(columns=data.columns[0])

x = data_a.iloc[:, :-1]
y = data_a.iloc[:, -1]

"""切分数据集"""
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.2, random_state=0)

from sklearn.naive_bayes import MultinomialNB#导入朴素贝叶斯分类器
classifier = MultinomialNB(alpha=1)
classifier.fit(x_train,y_train)

"""预测训练集"""
#绘制混淆矩阵
from sklearn import metrics
train_pred = classifier.predict(x_train)
cm_plot(y_train,train_pred).show()
print(metrics.classification_report(y_train, train_pred))
score = classifier.score(x_train, y_train)
print(score)
"""测试集预测"""
test_pred = classifier.predict(x_test)
cm_plot(y_test,test_pred).show()
print(metrics.classification_report(y_test, test_pred))

三、总结

朴素贝叶斯分类器因其简单性和高效性,在文本分类、垃圾邮件检测、情感分析等领域有着广泛的应用。但同时也有着自己的优缺点。

  • 优点
    • 简单高效:由于假设特征之间相互独立,大大简化了计算。
    • 处理缺失数据:对缺失数据不敏感,可以通过忽略该特征或使用该特征的先验概率来处理。
    • 易于实现:算法实现相对简单,易于理解和应用。
  • 缺点
    • 特征独立性假设:现实中特征之间往往存在相关性,这一假设限制了朴素贝叶斯的性能。
    • 参数估计问题:如果某个特征在训练数据中未出现,则条件概率为零,这会导致整个后验概率为零,即所谓的"零概率问题"。
相关推荐
qzhqbb3 分钟前
基于统计方法的语言模型
人工智能·语言模型·easyui
冷眼看人间恩怨28 分钟前
【话题讨论】AI大模型重塑软件开发:定义、应用、优势与挑战
人工智能·ai编程·软件开发
2401_8830410829 分钟前
新锐品牌电商代运营公司都有哪些?
大数据·人工智能
pianmian135 分钟前
python数据结构基础(7)
数据结构·算法
AI极客菌1 小时前
Controlnet作者新作IC-light V2:基于FLUX训练,支持处理风格化图像,细节远高于SD1.5。
人工智能·计算机视觉·ai作画·stable diffusion·aigc·flux·人工智能作画
阿_旭2 小时前
一文读懂| 自注意力与交叉注意力机制在计算机视觉中作用与基本原理
人工智能·深度学习·计算机视觉·cross-attention·self-attention
王哈哈^_^2 小时前
【数据集】【YOLO】【目标检测】交通事故识别数据集 8939 张,YOLO道路事故目标检测实战训练教程!
前端·人工智能·深度学习·yolo·目标检测·计算机视觉·pyqt
Power20246662 小时前
NLP论文速读|LongReward:基于AI反馈来提升长上下文大语言模型
人工智能·深度学习·机器学习·自然语言处理·nlp
数据猎手小k3 小时前
AIDOVECL数据集:包含超过15000张AI生成的车辆图像数据集,目的解决旨在解决眼水平分类和定位问题。
人工智能·分类·数据挖掘
好奇龙猫3 小时前
【学习AI-相关路程-mnist手写数字分类-win-硬件:windows-自我学习AI-实验步骤-全连接神经网络(BPnetwork)-操作流程(3) 】
人工智能·算法