朴素贝叶斯是数据挖掘分类的基础,本篇文章将介绍一下朴素贝叶斯算法
情景再现
以挑选西瓜为例,西瓜的色泽、瓜蒂、敲响声音、触感、脐部等特征都会影响到西瓜的好坏。那么我们怎么样可以挑选出一个好的西瓜呢?
分析过程
既然挑选西瓜有多个条件可供选择,那么我们可以根据已知的样本数据去获取好瓜的筛选条件,再拿测试的样本数据去匹配,最终得到好瓜的概率。然后我们再根据已知的样本数据去获取坏瓜的筛选条件,拿测试的样本数据去匹配,得到坏瓜的概率,如果好瓜的概率超过坏瓜的概率,那么我们可以认为该瓜是好瓜。
朴素贝叶斯介绍
什么是朴素贝叶斯
官话:朴素贝叶斯(naive Bayes)算法是基于贝叶斯定理与特征条件独立假设的分类方法。对于给定的训练数据集,首先基于特征条件独立假设学习输入输出的联合概率分布。然后基于此模型,对给定的输入x,利用贝叶斯定理求出后验概率最大的输出y。不同于其他分类器,朴素贝叶斯是一种基于概率理论的分类算法;
用个人的话来说:朴素贝叶斯就是基于概率大小进行分类判断的一种方式。就如同上面的西瓜分类,好瓜概率大就认定为是好瓜,否则就认定为坏瓜。
先验概率、似然概率和后验概率
先验概率:根据训练集可以经过统计初步得到的概率,就如同上文所说的好瓜或者坏瓜。标记为P(C)
似然概率:就是在先验概率基础上满足某一定条件的概率。比如当西瓜色泽为青绿色时好瓜的概率。标记为P(x|c)
后验概率:P(c|x) 即我们要求的概率,即P(好瓜|色泽=青绿)
贝叶斯的公式
公式简单介绍:就拿西瓜为例,P(色泽=青绿|好瓜) = P(色泽青绿) * P(好瓜|色泽青绿) | P(好瓜)
简单整体理解就是,先求色泽青绿的瓜中好瓜的概率然后乘以色泽青绿的瓜的数量就是色泽青绿的好瓜数量,然后除以好瓜的数量就是P(色泽=青绿|好瓜)的概率。
朴素贝叶斯的优势
- 实现简单:朴素贝叶斯基于先验概率和似然概率进行后验概率的求解,实现相对简单
- 算法较为快捷
扩展学习
求解先验和似然概率
根据大数定律,当训练集包含充足的独立同分布样本时,P(c)可通过各类样本出现的频率来进行估计。
朴素贝叶斯对条件概率分布做了条件独立性的假设。具体来说,条件独立性假设是:
这里的x(1) -> x(n) 表示各个满足的条件(比如,西瓜的色泽、瓜蒂、敲响声音、触感、脐部等特征)总而言之就是将各个条件满足的概率进行相乘后得到似然概率。
先验概率P©的极大似然估计是
其中N是样本集的总数,K是类别的总数,I表示事件,这里的意思是满足不同特征的好瓜或者坏瓜的概率。
似然概率P(x|c)的极大似然估计分为离散型和连续型两种计算方式:
离散型的计算方式如下:
其中,xi^{(j)}代表第i个样本的第j个特征;a{jl}是第j个特征可能取的第l个值。 代入原来的西瓜模型为例,有:
P(色泽= 青绿|好瓜) = I(色泽=青绿,好瓜) / I(好瓜)也就是说好瓜中色泽青绿的概率。
连续型的似然概率计算方式如下:
也就是说使用正态分布函数进行计算,求出似然概率的值。
拉普拉斯修正
我们可能会遇到这种情况,x(i) 特征的频率为0,那么基于朴素贝叶斯的公式,有似然概率为0,那么会出现计算出的后验概率为0的情况,(如下图所示) 为了避免出现这种情况,我们可以使用拉普拉斯修正算法。
拉普拉斯修正公式:
这里的N为特征类型数量(如敲响有清脆、沉闷、浊响三种,那么N = 3),也就是说P(敲响=清脆|好瓜) = (0+1)/(8+3) ≈ 0.091
代码
训练集如下:
java
package beiyesi;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Field;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class BeiyesiUtil2 {
private static List<TestItem2> list = new ArrayList<>();
static {
TestItem2 item1 = new TestItem2();
item1.setColor("青绿");
item1.setRoot("蜷缩");
item1.setSound("浊响");
item1.setWenli("清晰");
item1.setQibu("凹陷");
item1.setTouch("硬滑");
item1.setDensity(0.697);
item1.setSugerRate(0.46);
item1.setGood(true);
list.add(item1);
TestItem2 item2 = new TestItem2();
item2.setColor("乌黑");
item2.setRoot("蜷缩");
item2.setSound("沉闷");
item2.setWenli("清晰");
item2.setQibu("凹陷");
item2.setTouch("硬滑");
item2.setDensity(0.774);
item2.setSugerRate(0.376);
item2.setGood(true);
list.add(item2);
TestItem2 item3 = new TestItem2();
item3.setColor("乌黑");
item3.setRoot("蜷缩");
item3.setSound("浊响");
item3.setWenli("清晰");
item3.setQibu("凹陷");
item3.setTouch("硬滑");
item3.setDensity(0.634);
item3.setSugerRate(0.264);
item3.setGood(true);
list.add(item3);
TestItem2 item4 = new TestItem2();
item4.setColor("青绿");
item4.setRoot("蜷缩");
item4.setSound("沉闷");
item4.setWenli("清晰");
item4.setQibu("凹陷");
item4.setTouch("硬滑");
item4.setDensity(0.608);
item4.setSugerRate(0.318);
item4.setGood(true);
list.add(item4);
TestItem2 item5 = new TestItem2();
item5.setColor("浅白");
item5.setRoot("蜷缩");
item5.setSound("浊响");
item5.setWenli("清晰");
item5.setQibu("凹陷");
item5.setTouch("硬滑");
item5.setDensity(0.556);
item5.setSugerRate(0.215);
item5.setGood(true);
list.add(item5);
TestItem2 item6 = new TestItem2();
item6.setColor("青绿");
item6.setRoot("稍蜷");
item6.setSound("浊响");
item6.setWenli("清晰");
item6.setQibu("稍凹");
item6.setTouch("软粘");
item6.setDensity(0.403);
item6.setSugerRate(0.237);
item6.setGood(true);
list.add(item6);
TestItem2 item7 = new TestItem2();
item7.setColor("乌黑");
item7.setRoot("稍蜷");
item7.setSound("浊响");
item7.setWenli("稍糊");
item7.setQibu("稍凹");
item7.setTouch("软粘");
item7.setDensity(0.481);
item7.setSugerRate(0.149);
item7.setGood(true);
list.add(item7);
TestItem2 item8 = new TestItem2();
item8.setColor("乌黑");
item8.setRoot("稍蜷");
item8.setSound("浊响");
item8.setWenli("清晰");
item8.setQibu("稍凹");
item8.setTouch("硬滑");
item8.setDensity(0.437);
item8.setSugerRate(0.211);
item8.setGood(true);
list.add(item8);
TestItem2 item9 = new TestItem2();
item9.setColor("乌黑");
item9.setRoot("稍蜷");
item9.setSound("沉闷");
item9.setWenli("稍糊");
item9.setQibu("稍凹");
item9.setTouch("硬滑");
item9.setDensity(0.666);
item9.setSugerRate(0.091);
item9.setGood(false);
list.add(item9);
TestItem2 item10 = new TestItem2();
item10.setColor("青绿");
item10.setRoot("硬挺");
item10.setSound("清脆");
item10.setWenli("清晰");
item10.setQibu("平坦");
item10.setTouch("软粘");
item10.setDensity(0.243);
item10.setSugerRate(0.267);
item10.setGood(false);
list.add(item10);
TestItem2 item11 = new TestItem2();
item11.setColor("浅白");
item11.setRoot("硬挺");
item11.setSound("清脆");
item11.setWenli("模糊");
item11.setQibu("平坦");
item11.setTouch("硬滑");
item11.setDensity(0.245);
item11.setSugerRate(0.057);
item11.setGood(false);
list.add(item11);
TestItem2 item12 = new TestItem2();
item12.setColor("浅白");
item12.setRoot("蜷缩");
item12.setSound("浊响");
item12.setWenli("模糊");
item12.setQibu("平坦");
item12.setTouch("软粘");
item12.setDensity(0.343);
item12.setSugerRate(0.099);
item12.setGood(false);
list.add(item12);
TestItem2 item13 = new TestItem2();
item13.setColor("青绿");
item13.setRoot("稍蜷");
item13.setSound("浊响");
item13.setWenli("稍糊");
item13.setQibu("凹陷");
item13.setTouch("硬滑");
item13.setDensity(0.639);
item13.setSugerRate(0.161);
item13.setGood(false);
list.add(item13);
TestItem2 item14 = new TestItem2();
item14.setColor("浅白");
item14.setRoot("稍蜷");
item14.setSound("沉闷");
item14.setWenli("稍糊");
item14.setQibu("凹陷");
item14.setTouch("硬滑");
item14.setDensity(0.657);
item14.setSugerRate(0.198);
item14.setGood(false);
list.add(item14);
TestItem2 item15 = new TestItem2();
item15.setColor("乌黑");
item15.setRoot("稍蜷");
item15.setSound("浊响");
item15.setWenli("清晰");
item15.setQibu("稍凹");
item15.setTouch("软粘");
item15.setDensity(0.36);
item15.setSugerRate(0.37);
item15.setGood(false);
list.add(item15);
TestItem2 item16 = new TestItem2();
item16.setColor("浅白");
item16.setRoot("蜷缩");
item16.setSound("浊响");
item16.setWenli("模糊");
item16.setQibu("平坦");
item16.setTouch("硬滑");
item16.setDensity(0.593);
item16.setSugerRate(0.042);
item16.setGood(false);
list.add(item16);
TestItem2 item17 = new TestItem2();
item17.setColor("青绿");
item17.setRoot("蜷缩");
item17.setSound("沉闷");
item17.setWenli("稍糊");
item17.setQibu("稍凹");
item17.setTouch("硬滑");
item17.setDensity(0.719);
item17.setSugerRate(0.103);
item17.setGood(false);
list.add(item17);
}
/**
* 朴素贝叶斯分类
* @param testJson
* @return 贝叶斯分类
*/
@SuppressWarnings("unchecked")
public static void calc(TestItem2 ...itemList){
//朴素贝叶斯分类
try {
String date = new SimpleDateFormat("yyyyMMdd").format(new Date());
File f = new File("D:/decisionTree/dataset/beiyesi_datas_" + date + ".dat");
//读取/写入数据集
List<TestItem2> items = new ArrayList<>();
if (f.exists()) {
try(ObjectInputStream ios = new ObjectInputStream(new FileInputStream(f))) {
Object obj = ios.readObject();
if (obj instanceof List) {
items = (List<TestItem2>)obj;
}
for (TestItem2 item:itemList) {
System.out.println(item.toString());
}
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (ClassNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
} else {
items = list;
try(ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f))) {
oos.writeObject(list);
oos.flush();
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
System.out.println("---------start--------");
long timepre = System.currentTimeMillis();
calc(items,itemList);
long timeAfter = System.currentTimeMillis() - timepre;
System.out.println("---------end " + timeAfter + "--------");
} catch (IllegalArgumentException | IllegalAccessException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
/**
* 朴素贝叶斯分类
* @param item
* @return
* @throws IllegalAccessException
* @throws IllegalArgumentException
*/
private static void calc(List<TestItem2> items,TestItem2[] testItems)
throws IllegalArgumentException, IllegalAccessException {
//使用递归进行朴素贝叶斯分类
boolean[] isValid= new boolean[testItems.length];
for (int i = 0; i < testItems.length; i++) {
isValid[i] = calc(items,testItems[i]);
}
String date = new SimpleDateFormat("yyyyMMddHHmmss").format(new Date());
File f = new File("D:/decisionTree/dataset/result_" + date + ".dat");
//结果保存
if (!f.exists()){
try(BufferedWriter writer = new BufferedWriter(new FileWriter(f))) {
for (int i = 0; i < testItems.length; i++) {
writer.write("测试数据为:");
writer.newLine();
writer.write(testItems[i].toString());
writer.newLine();
writer.write("结果为:" + (isValid[i] ? "好瓜":"坏瓜"));
writer.newLine();
}
writer.flush();
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
} finally {
System.out.println("结果已保存");
}
} else {
System.out.println("结果表已存在");
}
}
/**
*
* @param items
* @param condition
* @return
* @throws IllegalAccessException
* @throws IllegalArgumentException
*/
private static boolean calc(
List<TestItem2> items,
TestItem2 trainItem
)
throws IllegalArgumentException, IllegalAccessException {
//利用反射机制获取某一条件的条目数量
Class<?> cls = TestItem2.class;
Field[] fs = cls.getDeclaredFields();
double[] resSim = {1.0,1.0};
//计算连续属性的贝叶斯
Map<String,Double[]> resMap = calcLianxuVar(items,trainItem);
for (Field f:fs) {
if (f.getType().getName().contains("Double")
|| f.getType().getName().contains("Integer")) {
for (int i = 0; i < resSim.length;i++)
resSim[i] *= resMap.get(f.getName())[i];
}
}
System.out.println(Arrays.toString(resSim));
int totalCount[] = {0,0}; //0-好瓜,1-坏瓜
//计算其他离散属性的项集数量
//计算离散数据当前分类的总数
for (TestItem2 item: items) {
boolean goodRes = item.isGood();
if (goodRes) {
totalCount[0]++;
} else {
totalCount[1]++;
}
}
System.out.println(Arrays.toString(totalCount));
//计算离散数据不同分类的数量
Map<String,Double[]> res = new ConcurrentHashMap<>();
for(TestItem2 item: items) {
for(Field f:fs) {
f.setAccessible(true);
Object val = f.get(item);
if (val instanceof String) {
String str = f.getName() + "_" + String.valueOf(val);
boolean goodRes = item.isGood();
if (res.get(str) == null) {
if (goodRes) {
res.put(str, new Double[]{1.0,0.0});
} else {
res.put(str, new Double[]{0.0,1.0});
}
} else {
Double[] arr = res.get(str);
if (goodRes) {
arr[0] = arr[0] + 1;
} else {
arr[1] = arr[1] + 1;
}
res.put(str,arr);
}
}
}
}
//进行当前样本的匹配度计算
for (int i = 0; i < 2; i++) {
for (Field f:fs) {
f.setAccessible(true);
Object val = f.get(trainItem);
//求分类数量
List<String> typeList = new ArrayList<>();
for (String key: res.keySet()) {
if (key.contains(f.getName()) &&
!typeList.contains(key)) {
typeList.add(key);
}
}
if (val instanceof String) {
String str = f.getName() + "_" + String.valueOf(val);
//使用拉普拉斯修正,调整值为0的数据
resSim[i] *= (res.get(str)[i] + 1)/ ((totalCount[i] + typeList.size()) * 1.0);
}
}
}
return resSim[0] >= resSim[1] ? true:false;
}
/**
* 计算连续属性的概率密度函数(正态分布)
* @param items
* @param trainItem
* @return
* @throws IllegalAccessException
* @throws IllegalArgumentException
*/
private static Map<String,Double[]> calcLianxuVar(List<TestItem2> items, TestItem2 trainItem)
throws IllegalArgumentException, IllegalAccessException {
//计算均值和方差,只计算一次
Double avg[],fangcha[],counter[];//0-好瓜,1-坏瓜
Map<String,Double[]> avgMap = new ConcurrentHashMap<>();
Map<String,Double[]> resMap = new ConcurrentHashMap<>();
//训练模型
//计算平均值
Class<?> cls = TestItem2.class;
Field[] fs = cls.getDeclaredFields();
for (Field f: fs) {
if (!(f.getType().getName().contains("Double")
|| f.getType().getName().contains("Integer"))){
continue;
}
//重新分配内存
avg = new Double[2];
counter = new Double[2];
for (int i = 0; i < 2;i++) {
avg[i] = 0.0;
counter[i] = 0.0;
}
for (TestItem2 item: items) {
f.setAccessible(true);
boolean goodRes = item.isGood();
if (goodRes) {
avg[0] += (Double)f.get(item);
counter[0]++;
} else {
avg[1] += (Double)f.get(item);
counter[1]++;
}
}
for (int i = 0;i < 2; i++) {
avg[i] = avg[i] / counter[i];
}
avgMap.put(f.getName(), avg);
}
//计算方差
for (Field f: fs) {
if (!(f.getType().getName().contains("Double")
|| f.getType().getName().contains("Integer"))){
continue;
}
//重新分配内存
fangcha = new Double[2];
counter = new Double[2];
for (int i = 0; i < 2;i++) {
fangcha[i] = 0.0;
counter[i] = 0.0;
}
for (TestItem2 item: items) {
f.setAccessible(true);
boolean goodRes = item.isGood();
if (goodRes) {
fangcha[0] +=
Math.pow(((Double)f.get(item) - avgMap.get(f.getName())[0]), 2);
counter[0]++;
} else {
fangcha[1] +=
Math.pow(((Double)f.get(item) - avgMap.get(f.getName())[1]), 2);
counter[1]++;
}
}
for (int i =0;i < 2;i++) {
fangcha[i] = fangcha[i] / ((counter[i] - 1) * 1.0);
}
resMap.put(f.getName(), fangcha);
}
Field[] fields = trainItem.getClass().getDeclaredFields();
for(Map.Entry<String, Double[]> set:resMap.entrySet()) {
for (Field f: fields) {
f.setAccessible(true);
if (f.getName().equalsIgnoreCase(set.getKey())) {
Double val = (Double) f.get(trainItem);
Double[] res = set.getValue();
for (int i = 0; i < res.length;i++) {
res[i] = 1/((Math.sqrt(2 * Math.PI) * Math.sqrt(set.getValue()[i])))
* Math.exp(-1 * Math.pow((val - avgMap.get(set.getKey())[i]),2) / (2 * set.getValue()[i] * 1.0));
}
resMap.put(f.getName(), res);
}
}
}
return resMap;
}
public static void main(String[] args) {
TestItem2 item1 = new TestItem2();
item1.setColor("青绿");
item1.setRoot("蜷缩");
item1.setSound("浊响");
item1.setWenli("清晰");
item1.setQibu("凹陷");
item1.setTouch("硬滑");
item1.setDensity(0.697);
item1.setSugerRate(0.46);
item1.setGood(true);
list.add(item1);
calc(item1);
}
}
TestItem 类
java
package beiyesi;
import java.io.Serializable;
/**
* 好瓜坏瓜
* @author zygswo
*
*/
public class TestItem2 implements Serializable{
@Override
public String toString() {
return "TestItem2 [color=" + color + ", root=" + root + ", sound=" + sound + ", wenli=" + wenli
+ ", qibu=" + qibu + ", touch=" + touch + ", density=" + density + ", sugerRate=" + sugerRate
+ ", isGood=" + isGood + "]";
}
private static final long serialVersionUID = 1L;
public String getRootStyle() {
return color;
}
public void setColor(String color) {
this.color = color;
}
public String getRoot() {
return root;
}
public void setRoot(String root) {
this.root = root;
}
public String getSound() {
return sound;
}
public void setSound(String sound) {
this.sound = sound;
}
public String getWenli() {
return wenli;
}
public void setWenli(String wenli) {
this.wenli = wenli;
}
public String getQibu() {
return qibu;
}
public void setQibu(String qibu) {
this.qibu = qibu;
}
public String getTouch() {
return touch;
}
public void setTouch(String touch) {
this.touch = touch;
}
public Double getDensity() {
return density;
}
public void setDensity(double density) {
this.density = density;
}
public Double getSugerRate() {
return sugerRate;
}
public void setSugerRate(double sugerRate) {
this.sugerRate = sugerRate;
}
public String getColor() {
return color;
}
public boolean isGood() {
return isGood;
}
public void setGood(boolean isGood) {
this.isGood = isGood;
}
/**
* 色泽
*/
private String color;
/**
* 根蒂
*/
private String root;
/**
* 敲声
*/
private String sound;
/**
* 纹理
*/
private String wenli;
/**
* 脐部
*/
private String qibu;
/**
* 触感
*/
private String touch;
/**
* 密度
*/
private Double density;
/**
* 含糖量
*/
private Double sugerRate;
/**
* 是否好瓜
*/
private boolean isGood;
}