MATLAB基于BNT工具箱的多输入分类预测

1. 环境准备和数据加载

matlab 复制代码
% 添加BNT工具箱到路径
addpath(genpath('bnt'));

% 生成示例数据
% 假设有3个输入特征和1个分类输出
n_samples = 1000;
n_features = 3;
n_classes = 3;

% 生成特征数据
X = randn(n_samples, n_features);
% 生成类别标签(基于特征的某种关系)
true_weights = [2, -1, 0.5];
y_scores = X * true_weights' + randn(n_samples, 1)*0.5;
y = discretize(y_scores, [-inf, -0.5, 0.5, inf]);

% 划分训练集和测试集
train_ratio = 0.7;
n_train = floor(n_samples * train_ratio);
X_train = X(1:n_train, :);
y_train = y(1:n_train);
X_test = X(n_train+1:end, :);
y_test = y(n_train+1:end);

2. 构建贝叶斯网络结构

matlab 复制代码
function dag = create_classification_bnet(n_features, n_classes)
    % 创建有向无环图
    % 节点1: 类别变量
    % 节点2-n_features+1: 特征变量
    dag = zeros(n_features + 1, n_features + 1);
    
    % 类别节点指向所有特征节点
    class_node = 1;
    for i = 1:n_features
        feature_node = i + 1;
        dag(class_node, feature_node) = 1;
    end
    
    fprintf('构建了包含%d个特征的分类网络\n', n_features);
end

3. 参数学习和模型训练

matlab 复制代码
function bnet = train_bayesian_classifier(X_train, y_train, n_classes)
    [n_samples, n_features] = size(X_train);
    
    % 创建网络结构
    dag = create_classification_bnet(n_features, n_classes);
    ns = ones(1, n_features + 1);
    ns(1) = n_classes;  % 类别节点的状态数
    
    % 离散化连续特征(简单示例使用3个区间)
    for i = 1:n_features
        ns(i+1) = 3;  % 每个特征离散化为3个状态
    end
    
    % 创建贝叶斯网络
    bnet = mk_bnet(dag, ns);
    
    % 设置节点类型
    bnet.CPD{1} = tabular_CPD(bnet, 1);  % 类别先验
    
    for i = 1:n_features
        bnet.CPD{i+1} = tabular_CPD(bnet, i+1);  % 特征条件概率
    end
    
    % 准备训练数据
    data = cell(n_features + 1, n_samples);
    
    % 类别数据
    data(1, :) = num2cell(y_train');
    
    % 特征数据(离散化)
    for i = 1:n_features
        feature_data = X_train(:, i);
        % 等宽离散化
        edges = linspace(min(feature_data), max(feature_data), 4);
        disc_data = discretize(feature_data, edges);
        data(i+1, :) = num2cell(disc_data');
    end
    
    % 参数学习
    bnet = learn_params(bnet, data);
    
    fprintf('贝叶斯分类器训练完成\n');
end

4. 预测函数

matlab 复制代码
function [predictions, probabilities] = bayesian_predict(bnet, X_test, n_classes)
    [n_test, n_features] = size(X_test);
    
    % 离散化测试数据
    test_data = cell(n_features + 1, n_test);
    test_data(1, :) = {[]};  % 类别未知
    
    for i = 1:n_features
        feature_data = X_test(:, i);
        edges = linspace(min(feature_data), max(feature_data), 4);
        disc_data = discretize(feature_data, edges);
        test_data(i+1, :) = num2cell(disc_data');
    end
    
    % 创建推理引擎
    engine = jtree_inf_engine(bnet);
    
    predictions = zeros(n_test, 1);
    probabilities = zeros(n_test, n_classes);
    
    for i = 1:n_test
        % 设置证据
        evidence = test_data(:, i);
        
        % 进行推理
        [engine, loglik] = enter_evidence(engine, evidence);
        
        % 查询类别节点的后验概率
        marg = marginal_nodes(engine, 1);
        
        probabilities(i, :) = marg.T;
        [~, predictions(i)] = max(marg.T);
    end
end

5. 完整的工作流程

matlab 复制代码
function main_multiple_input_classification()
    % 主函数:多输入分类预测
    
    % 1. 准备数据
    fprintf('准备数据...\n');
    [X, y, X_train, y_train, X_test, y_test] = prepare_data();
    
    % 2. 训练模型
    fprintf('训练贝叶斯分类器...\n');
    n_classes = length(unique(y));
    bnet = train_bayesian_classifier(X_train, y_train, n_classes);
    
    % 3. 预测
    fprintf('进行预测...\n');
    [predictions, probabilities] = bayesian_predict(bnet, X_test, n_classes);
    
    % 4. 评估模型
    fprintf('评估模型性能...\n');
    evaluate_model(y_test, predictions, probabilities);
    
    % 5. 可视化结果
    visualize_results(X_test, y_test, predictions, probabilities);
end

function [X, y, X_train, y_train, X_test, y_test] = prepare_data()
    % 生成更真实的数据
    n_samples = 2000;
    n_features = 4;
    
    % 生成特征数据
    X = zeros(n_samples, n_features);
    
    % 特征1: 正态分布
    X(:,1) = randn(n_samples, 1);
    % 特征2: 均匀分布
    X(:,2) = rand(n_samples, 1) * 10;
    % 特征3: 与特征1相关
    X(:,3) = 0.5 * X(:,1) + randn(n_samples, 1) * 0.5;
    % 特征4: 分类特征
    X(:,4) = randi(3, n_samples, 1);
    
    % 基于特征的复杂关系生成类别
    scores = 2*X(:,1) - 1.5*X(:,2) + 0.8*X(:,3) + 0.5*(X(:,4)-2);
    noise = randn(n_samples, 1) * 0.3;
    y_scores = scores + noise;
    
    % 分为3类
    y = discretize(y_scores, [-inf, -2, 2, inf]);
    
    % 划分数据集
    train_ratio = 0.7;
    n_train = floor(n_samples * train_ratio);
    X_train = X(1:n_train, :);
    y_train = y(1:n_train);
    X_test = X(n_train+1:end, :);
    y_test = y(n_train+1:end);
end

function evaluate_model(y_true, y_pred, probabilities)
    % 计算准确率
    accuracy = sum(y_true == y_pred) / length(y_true);
    fprintf('准确率: %.4f\n', accuracy);
    
    % 混淆矩阵
    cm = confusionmat(y_true, y_pred);
    fprintf('混淆矩阵:\n');
    disp(cm);
    
    % 各类别准确率
    for i = 1:max(y_true)
        class_idx = (y_true == i);
        class_acc = sum(y_true(class_idx) == y_pred(class_idx)) / sum(class_idx);
        fprintf('类别 %d 准确率: %.4f\n', i, class_acc);
    end
    
    % 平均概率
    mean_prob = mean(max(probabilities, [], 2));
    fprintf('平均预测置信度: %.4f\n', mean_prob);
end

function visualize_results(X_test, y_test, predictions, probabilities)
    % 可视化结果
    figure;
    
    % 只使用前两个特征进行可视化
    if size(X_test, 2) >= 2
        subplot(2, 2, 1);
        gscatter(X_test(:,1), X_test(:,2), y_test, 'rgb', 'o');
        title('真实类别');
        xlabel('特征1'); ylabel('特征2');
        
        subplot(2, 2, 2);
        gscatter(X_test(:,1), X_test(:,2), predictions, 'rgb', 'x');
        title('预测类别');
        xlabel('特征1'); ylabel('特征2');
        
        subplot(2, 2, 3);
        correct = (y_test == predictions);
        gscatter(X_test(:,1), X_test(:,2), correct, 'br', 'o*');
        title('分类结果 (蓝色:正确, 红色:错误)');
        xlabel('特征1'); ylabel('特征2');
        
        subplot(2, 2, 4);
        confidence = max(probabilities, [], 2);
        scatter(X_test(:,1), X_test(:,2), 50, confidence, 'filled');
        colorbar;
        title('预测置信度');
        xlabel('特征1'); ylabel('特征2');
    end
end

6. 高级功能:连续特征处理

matlab 复制代码
function bnet = train_with_continuous_features(X_train, y_train, n_classes)
    % 使用连续特征的高斯节点
    
    [n_samples, n_features] = size(X_train);
    
    % 创建网络结构
    dag = create_classification_bnet(n_features, n_classes);
    
    % 节点大小:类别节点离散,特征节点连续
    discrete_nodes = 1;
    continuous_nodes = 2:(n_features+1);
    ns = ones(1, n_features + 1);
    ns(discrete_nodes) = n_classes;
    
    % 创建混合网络
    bnet = mk_bnet(dag, ns, 'discrete', discrete_nodes);
    
    % 设置CPD
    bnet.CPD{1} = tabular_CPD(bnet, 1);  % 离散类别节点
    
    for i = continuous_nodes
        % 高斯节点,均值依赖于父节点(类别)
        bnet.CPD{i} = gaussian_CPD(bnet, i, 'cov_type', 'diag');
    end
    
    % 准备数据
    data = cell(n_features + 1, n_samples);
    data(1, :) = num2cell(y_train');
    for i = 1:n_features
        data{i+1, :} = num2cell(X_train(:, i)');
    end
    
    % 参数学习
    bnet = learn_params(bnet, data);
end
相关推荐
jllllyuz39 分钟前
Matlab实现基于Matrix Pencil算法实现声源信号角度和时间估计
开发语言·算法·matlab
Dev7z8 小时前
基于Matlab传统图像处理的风景图像多风格转换与优化
图像处理·matlab·风景
hacker70710 小时前
openGauss 在K12教育场景的数据处理测评:CASE WHEN 实现高效分类
人工智能·分类·数据挖掘
t198751281 天前
基于MATLAB的指纹识别系统完整实现
开发语言·matlab
gihigo19981 天前
基于MATLAB的IEEE 14节点系统牛顿-拉夫逊潮流算法实现
开发语言·算法·matlab
云纳星辰怀自在1 天前
MATLAB: m脚本-fixdt数据类型数据范围
matlab·m脚本·fixdt
一叶知秋h1 天前
matlab实现PID参数功能的简单仿真_gif
matlab·gif·pid
技术净胜1 天前
MATLAB 基因表达数据处理与可视化全流程案例
开发语言·matlab
大数据魔法师1 天前
分类与回归算法(六)- 集成学习(随机森林、梯度提升决策树、Stacking分类)相关理论
分类·回归·集成学习
大数据魔法师1 天前
分类与回归算法(五)- 决策树分类
决策树·分类·回归