MATLAB基于BNT工具箱的多输入分类预测

1. 环境准备和数据加载

matlab
% Put the BNT toolbox folder and every subfolder on the MATLAB path
addpath(genpath('bnt'));

% Build a synthetic example problem:
% three input features feeding one categorical output
n_samples = 1000;
n_features = 3;
n_classes = 3;

% Features are independent standard-normal draws
X = randn(n_samples, n_features);
% Labels: a noisy linear score of the features, cut into three bins
true_weights = [2, -1, 0.5];
y_scores = X * true_weights.' + randn(n_samples, 1)*0.5;
y = discretize(y_scores, [-inf, -0.5, 0.5, inf]);

% Hold out the tail of the sample as the test set (70 % / 30 % split)
train_ratio = 0.7;
n_train = floor(train_ratio * n_samples);
X_train = X(1:n_train, :);
y_train = y(1:n_train);
X_test = X(n_train+1:n_samples, :);
y_test = y(n_train+1:n_samples);

2. 构建贝叶斯网络结构

matlab
function dag = create_classification_bnet(n_features, n_classes)
    % CREATE_CLASSIFICATION_BNET  Adjacency matrix of a naive-Bayes style DAG.
    %
    %   dag = create_classification_bnet(n_features, n_classes)
    %
    %   Node 1 is the class variable; nodes 2 .. n_features+1 are the
    %   feature variables. The only edges run from the class node to each
    %   feature node. n_classes is accepted for interface symmetry with the
    %   callers but is not used here.
    %
    %   Returns an (n_features+1)-by-(n_features+1) matrix of 0/1 doubles
    %   where dag(i,j) == 1 means an edge i -> j.
    n_nodes = n_features + 1;
    dag = zeros(n_nodes);

    % One vectorised assignment sets every class -> feature edge
    dag(1, 2:n_nodes) = 1;

    fprintf('构建了包含%d个特征的分类网络\n', n_features);
end

3. 参数学习和模型训练

matlab
function bnet = train_bayesian_classifier(X_train, y_train, n_classes)
    % TRAIN_BAYESIAN_CLASSIFIER  Fit a discrete naive-Bayes network with BNT.
    %
    %   bnet = train_bayesian_classifier(X_train, y_train, n_classes)
    %
    %   X_train   n_samples-by-n_features matrix of continuous features.
    %   y_train   vector of class labels in 1..n_classes (row or column).
    %   n_classes number of states of the class node.
    %
    %   Each feature is cut into n_bins equal-width bins between its
    %   training minimum and maximum, and all CPDs are tabular.
    %
    %   The returned struct is the learned BNT network with one extra field:
    %     bnet.disc_edges  1-by-n_features cell; disc_edges{i} holds the
    %                      n_bins+1 bin edges used for feature i.
    %   Keeping the edges with the model lets new data be discretised with
    %   the same cut points as the training data.
    %   NOTE(review): assumes BNT routines ignore unknown fields on the bnet
    %   struct -- confirm against the BNT version in use.
    n_bins = 3;  % number of discrete states per feature
    [n_samples, n_features] = size(X_train);

    % Structure: class node (1) -> every feature node (2..n_features+1)
    dag = create_classification_bnet(n_features, n_classes);
    ns = [n_classes, repmat(n_bins, 1, n_features)];  % states per node

    bnet = mk_bnet(dag, ns);

    % Class prior and class-conditional feature tables
    for node = 1:(n_features + 1)
        bnet.CPD{node} = tabular_CPD(bnet, node);
    end

    % BNT expects a nodes-by-cases cell array
    data = cell(n_features + 1, n_samples);
    data(1, :) = num2cell(y_train(:)');  % (:) makes row/column input equivalent

    % Equal-width discretisation; remember the edges for later reuse
    disc_edges = cell(1, n_features);
    for i = 1:n_features
        feature_data = X_train(:, i);
        edges = linspace(min(feature_data), max(feature_data), n_bins + 1);
        disc_edges{i} = edges;
        data(i+1, :) = num2cell(discretize(feature_data, edges)');
    end

    % Parameter learning from fully observed data
    bnet = learn_params(bnet, data);

    % Attach the edges after learning so they are present on the returned struct
    bnet.disc_edges = disc_edges;

    fprintf('贝叶斯分类器训练完成\n');
end

4. 预测函数

matlab
function [predictions, probabilities] = bayesian_predict(bnet, X_test, n_classes)
    % BAYESIAN_PREDICT  Classify samples with a discrete BNT network.
    %
    %   [predictions, probabilities] = bayesian_predict(bnet, X_test, n_classes)
    %
    %   bnet      network whose node 1 is the class and nodes 2..end are
    %             3-state discretised features.
    %   X_test    n_test-by-n_features matrix of continuous features.
    %   n_classes number of class states.
    %
    %   predictions    n_test-by-1 index of the most probable class.
    %   probabilities  n_test-by-n_classes posterior over the class node.
    %
    %   Discretisation: if bnet has a field disc_edges (a cell with one
    %   bin-edge vector per feature), those edges are used so that a feature
    %   value maps to the same state as it did during training. The outermost
    %   edges are opened to -inf/+inf so values outside the stored range fall
    %   into the first/last bin instead of becoming NaN. Without that field
    %   the edges are derived from X_test itself (legacy behaviour, which
    %   depends on the composition of the test batch).
    %
    %   Feature values that cannot be binned (NaN) are entered as
    %   unobserved evidence.
    [n_test, n_features] = size(X_test);

    has_stored_edges = isfield(bnet, 'disc_edges') && iscell(bnet.disc_edges) ...
        && numel(bnet.disc_edges) == n_features;

    % Evidence matrix: nodes-by-cases; [] marks an unobserved node
    test_data = cell(n_features + 1, n_test);
    test_data(1, :) = {[]};  % class is what we want to infer

    for i = 1:n_features
        feature_data = X_test(:, i);
        if has_stored_edges
            edges = bnet.disc_edges{i};
            edges(1) = -inf;    % catch values below the stored minimum
            edges(end) = inf;   % catch values above the stored maximum
        else
            edges = linspace(min(feature_data), max(feature_data), 4);
        end
        bins = discretize(feature_data, edges);
        bin_cells = num2cell(bins');
        bin_cells(isnan(bins)) = {[]};  % missing value -> hidden node
        test_data(i+1, :) = bin_cells;
    end

    % One junction-tree engine is reused for every sample
    engine = jtree_inf_engine(bnet);

    predictions = zeros(n_test, 1);
    probabilities = zeros(n_test, n_classes);

    for i = 1:n_test
        evidence = test_data(:, i);

        % The log-likelihood output is not needed here
        engine = enter_evidence(engine, evidence);

        % Posterior marginal of the class node
        marg = marginal_nodes(engine, 1);
        posterior = marg.T(:)';

        probabilities(i, :) = posterior;
        [~, predictions(i)] = max(posterior);
    end
end

5. 完整的工作流程

matlab
function main_multiple_input_classification()
    % MAIN_MULTIPLE_INPUT_CLASSIFICATION  End-to-end demo pipeline.
    %
    %   Generates data, trains the Bayesian classifier, predicts on the
    %   held-out set, prints metrics and draws the diagnostic figure.

    % Step 1: data
    fprintf('准备数据...\n');
    [~, y, X_train, y_train, X_test, y_test] = prepare_data();

    % Step 2: training -- the class count is taken from the labels present
    fprintf('训练贝叶斯分类器...\n');
    n_classes = numel(unique(y));
    bnet = train_bayesian_classifier(X_train, y_train, n_classes);

    % Step 3: inference on the test split
    fprintf('进行预测...\n');
    [predictions, probabilities] = bayesian_predict(bnet, X_test, n_classes);

    % Step 4: metrics
    fprintf('评估模型性能...\n');
    evaluate_model(y_test, predictions, probabilities);

    % Step 5: plots
    visualize_results(X_test, y_test, predictions, probabilities);
end

function [X, y, X_train, y_train, X_test, y_test] = prepare_data()
    % PREPARE_DATA  Generate a synthetic 4-feature, 3-class data set.
    %
    %   [X, y, X_train, y_train, X_test, y_test] = prepare_data()
    %
    %   X is 2000-by-4, y is 2000-by-1 with labels in {1,2,3}. The first
    %   70 % of the rows form the training split, the rest the test split.
    %   Output depends on the current state of the global random stream.
    n_samples = 2000;

    % Draw each feature as its own column (order of RNG calls matters for
    % reproducibility under a fixed seed)
    f_normal   = randn(n_samples, 1);                          % standard normal
    f_uniform  = rand(n_samples, 1) * 10;                      % uniform on [0, 10]
    f_related  = 0.5 * f_normal + randn(n_samples, 1) * 0.5;   % correlated with feature 1
    f_category = randi(3, n_samples, 1);                       % categorical in {1,2,3}

    X = [f_normal, f_uniform, f_related, f_category];

    % Latent score: linear mix of the features plus Gaussian noise
    scores = 2*X(:,1) - 1.5*X(:,2) + 0.8*X(:,3) + 0.5*(X(:,4)-2);
    noise = randn(n_samples, 1) * 0.3;
    y_scores = scores + noise;

    % Cut the score into three classes
    y = discretize(y_scores, [-inf, -2, 2, inf]);

    % Sequential 70/30 split
    train_ratio = 0.7;
    n_train = floor(train_ratio * n_samples);
    train_rows = 1:n_train;
    test_rows = (n_train + 1):n_samples;

    X_train = X(train_rows, :);
    y_train = y(train_rows);
    X_test = X(test_rows, :);
    y_test = y(test_rows);
end

function evaluate_model(y_true, y_pred, probabilities)
    % EVALUATE_MODEL  Print accuracy, confusion matrix and confidence.
    %
    %   evaluate_model(y_true, y_pred, probabilities)
    %
    %   y_true         vector of true class labels (positive integers).
    %   y_pred         vector of predicted labels, same length as y_true.
    %   probabilities  n-by-n_classes matrix of predicted class probabilities.
    %
    %   Both label vectors are reshaped to columns first. Without that, a
    %   row compared against a column is implicitly expanded to an n-by-n
    %   matrix and every metric below would be computed on the wrong data.
    %
    %   Raises evaluate_model:sizeMismatch when the label vectors differ in
    %   length.
    y_true = y_true(:);
    y_pred = y_pred(:);
    if numel(y_true) ~= numel(y_pred)
        error('evaluate_model:sizeMismatch', ...
            'y_true has %d elements but y_pred has %d.', ...
            numel(y_true), numel(y_pred));
    end

    is_correct = (y_true == y_pred);

    % Overall accuracy
    accuracy = sum(is_correct) / numel(y_true);
    fprintf('准确率: %.4f\n', accuracy);

    % Confusion matrix
    cm = confusionmat(y_true, y_pred);
    fprintf('混淆矩阵:\n');
    disp(cm);

    % Per-class accuracy (recall); classes with no samples are reported
    % explicitly instead of printing NaN from a 0/0 division
    for i = 1:max(y_true)
        class_idx = (y_true == i);
        n_in_class = sum(class_idx);
        if n_in_class == 0
            fprintf('类别 %d 准确率: N/A (无样本)\n', i);
            continue;
        end
        class_acc = sum(is_correct(class_idx)) / n_in_class;
        fprintf('类别 %d 准确率: %.4f\n', i, class_acc);
    end

    % Mean of the winning-class probability across samples
    mean_prob = mean(max(probabilities, [], 2));
    fprintf('平均预测置信度: %.4f\n', mean_prob);
end

function visualize_results(X_test, y_test, predictions, probabilities)
    % VISUALIZE_RESULTS  Four-panel diagnostic figure over the first two features.
    %
    %   visualize_results(X_test, y_test, predictions, probabilities)
    %
    %   X_test         n-by-d feature matrix; only columns 1 and 2 are plotted.
    %   y_test         true labels (length n).
    %   predictions    predicted labels (length n).
    %   probabilities  n-by-n_classes predicted class probabilities.
    %
    %   Panels: true classes, predicted classes, correct vs. wrong, and
    %   prediction confidence. A figure is always opened; nothing is drawn
    %   when X_test has fewer than two columns.
    figure;

    if size(X_test, 2) < 2
        return;
    end

    f1 = X_test(:, 1);
    f2 = X_test(:, 2);

    subplot(2, 2, 1);
    gscatter(f1, f2, y_test, 'rgb', 'o');
    title('真实类别');
    xlabel('特征1'); ylabel('特征2');

    subplot(2, 2, 2);
    gscatter(f1, f2, predictions, 'rgb', 'x');
    title('预测类别');
    xlabel('特征1'); ylabel('特征2');

    % Correct vs. wrong. The two groups are drawn explicitly so the colours
    % always match the title (blue = correct, red = wrong). gscatter assigns
    % colours by sorted group value, which painted the wrong predictions
    % (group 0) blue and the correct ones (group 1) red.
    subplot(2, 2, 3);
    correct = (y_test(:) == predictions(:));
    plot(f1(correct), f2(correct), 'bo', 'DisplayName', '正确');
    hold on;
    plot(f1(~correct), f2(~correct), 'r*', 'DisplayName', '错误');
    hold off;
    legend('show');
    title('分类结果 (蓝色:正确, 红色:错误)');
    xlabel('特征1'); ylabel('特征2');

    % Confidence = probability of the winning class, shown as colour
    subplot(2, 2, 4);
    confidence = max(probabilities, [], 2);
    scatter(f1, f2, 50, confidence, 'filled');
    colorbar;
    title('预测置信度');
    xlabel('特征1'); ylabel('特征2');
end

6. 高级功能:连续特征处理

matlab
function bnet = train_with_continuous_features(X_train, y_train, n_classes)
    % TRAIN_WITH_CONTINUOUS_FEATURES  Naive-Bayes network with Gaussian features.
    %
    %   bnet = train_with_continuous_features(X_train, y_train, n_classes)
    %
    %   X_train   n_samples-by-n_features matrix of continuous features.
    %   y_train   vector of class labels in 1..n_classes (row or column).
    %   n_classes number of states of the class node.
    %
    %   Node 1 (class) is discrete with a tabular CPD; nodes 2..n_features+1
    %   are scalar continuous nodes with Gaussian CPDs conditioned on the
    %   class. No discretisation of the features is performed.
    [n_samples, n_features] = size(X_train);

    % Structure: class -> every feature
    dag = create_classification_bnet(n_features, n_classes);

    % Node sizes: n_classes states for the class, dimension 1 per feature
    discrete_nodes = 1;
    continuous_nodes = 2:(n_features+1);
    ns = ones(1, n_features + 1);
    ns(discrete_nodes) = n_classes;

    % Hybrid network: only the class node is declared discrete
    bnet = mk_bnet(dag, ns, 'discrete', discrete_nodes);

    bnet.CPD{1} = tabular_CPD(bnet, 1);  % class prior

    for i = continuous_nodes
        % Gaussian whose parameters depend on the discrete parent (class)
        bnet.CPD{i} = gaussian_CPD(bnet, i, 'cov_type', 'diag');
    end

    % BNT expects a nodes-by-cases cell array. Paren indexing is required
    % to assign a whole row of cells; brace indexing (data{i+1, :} = ...)
    % addresses many cells at once and errors.
    data = cell(n_features + 1, n_samples);
    data(1, :) = num2cell(y_train(:)');
    for i = 1:n_features
        data(i+1, :) = num2cell(X_train(:, i)');
    end

    % Parameter learning from fully observed data
    bnet = learn_params(bnet, data);
end
相关推荐
年年测试10 小时前
AI驱动的测试:用Dify工作流实现智能缺陷分析与分类
人工智能·分类·数据挖掘
机器学习之心19 小时前
MATLAB基于改进云物元的模拟机协同训练质量评价
matlab·改进云物元
ytttr87319 小时前
MATLAB实现经验模态分解(EMD)与希尔伯特变换获取能量谱
人工智能·python·matlab
t1987512820 小时前
基于多假设跟踪(MHT)算法的MATLAB实现
开发语言·matlab
机器学习之心1 天前
MATLAB多子种群混沌自适应哈里斯鹰算法优化BP神经网络回归预测
神经网络·算法·matlab
abcwoabcwo1 天前
回归、预测、分类三者关系
分类·数据挖掘·回归
π同学1 天前
基于Matlab的递推最小二乘法参数估计
matlab·最小二乘法
小喵要摸鱼1 天前
【MATLBA】使用教程
matlab
listhi5202 天前
基于空时阵列最佳旋转角度的卫星导航抗干扰信号处理的完整MATLAB仿真
开发语言·matlab·信号处理