基于MATLAB实现多变量高斯过程回归(GPR)

一、数据生成与预处理

matlab 复制代码
%% 生成多变量输入输出数据
rng(0); % 固定随机种子
n_samples = 200;

% 输入特征(3维)
X = [linspace(-5,5,n_samples)', randn(n_samples,2)*2];

% 输出目标(2维非线性函数)
Y(:,1) = sin(X(:,1)).*exp(-0.1*X(:,1).^2) + 0.5*X(:,2);
Y(:,2) = cos(X(:,1)).*exp(-0.1*X(:,1).^2) + 0.3*X(:,3) + 0.1*randn(n_samples,1);

%% 数据可视化
figure;
subplot(2,1,1);
scatter3(X(:,1),X(:,2),X(:,3),50,Y(:,1),'filled');
xlabel('X1'); ylabel('X2'); zlabel('X3');
title('输入数据分布(3D)');

subplot(2,1,2);
plot(Y(:,1),'r',Y(:,2),'b');
legend('输出1','输出2');
title('输出目标分布');

二、模型训练与预测

matlab 复制代码
%% 数据划分(70%训练,30%测试)
cv = cvpartition(size(X,1),'HoldOut',0.3);
X_train = X(training(cv),:);
Y_train = Y(training(cv),:);
X_test = X(test(cv),:);
Y_test = Y(test(cv),:);

%% 数据归一化(参考)
[yp_train, ps_input] = mapminmax(X_train',0,1);
yp_test = mapminmax('apply',X_test',ps_input);
[yt_train, ps_output] = mapminmax(Y_train',0,1);
yt_test = mapminmax('apply',Y_test',ps_output);

%% 训练GPR模型(多输出)
kernel = {'ardsquaredexponential','ardsquaredexponential'}; % 自动相关尺度核
gpr_model = fitrgp(yp_train', yt_train', 'Kernel', kernel, ...
    'Basis', 'constant', 'FitMethod', 'exact', ...
    'PredictMethod', 'exact', 'Standardize', 1);

%% 预测
[ypred,ypredci] = predict(gpr_model, yp_test');
ypred = mapminmax('reverse',ypred',ps_output)';
ypredci = mapminmax('reverse',ypredci',ps_output)';

三、性能评估(参考)

matlab 复制代码
%% 计算评价指标
R2 = 1 - sum((Y_test - ypred).^2)/sum((Y_test - mean(Y_test)).^2);
MAE = mean(abs(Y_test - ypred));
MSE = mean((Y_test - ypred).^2);
RMSE = sqrt(MSE);

%% 输出结果
fprintf('===== 模型评估 =====
');
fprintf('R²: %.4f\n', R2);
fprintf('MAE: %.4f\n', MAE);
fprintf('RMSE: %.4f\n', RMSE);

四、可视化分析

matlab 复制代码
%% 三维预测结果可视化
figure;
scatter3(X_test(:,1),X_test(:,2),X_test(:,3),50,Y_test(:,1),'filled');
hold on;
plot3(X_test(:,1),X_test(:,2),X_test(:,3),...
    'o', 'MarkerSize', 8, 'MarkerFaceColor', 'r');
xlabel('X1'); ylabel('X2'); zlabel('X3');
title('输出1预测结果(红点为真实值,曲面为预测值)');

%% 置信区间可视化(以输出2为例)
figure;
plot(Y_test(:,2),'b', 'LineWidth',1.5);
hold on;
plot(ypred(:,2),'r--', 'LineWidth',1.5);
fill([1:length(ypred)], [ypred(:,2); flipud(ypredci(:,2))], ...
    [0.8 0.8 1], 'FaceAlpha',0.3, 'EdgeColor','none');
xlabel('样本序号'); ylabel('输出值');
legend('真实值','预测均值','95%置信区间');
title('输出2预测置信区间');

五、关键参数说明

  1. 核函数选择 ardsquaredexponential:自动相关尺度平方指数核(推荐多变量场景) matern52:Matérn 5/2核(适合有导数连续性的数据)

  2. 超参数优化

    可通过最大似然估计自动优化:

    matlab 复制代码
    gpr_model = fitrgp(yp_train', yt_train', ...
        'Basis', 'constant', 'FitMethod', 'exact', ...
        'OptimizeHyperparameters', 'all', ...
        'HyperparameterOptimizationOptions', struct('AcquisitionFunctionName','expected-improvement-plus'));

参考代码 多变量高斯过程回归示例 www.youwenfan.com/contentcsp/97722.html

六、扩展应用场景

  1. 时序预测

    将时间序列数据转换为监督学习格式:

    matlab 复制代码
    X_t = [X(1:end-1,:) X(2:end,:)]; % 输入为时间步特征
    Y_t = X(2:end,:); % 输出为下一时刻状态
  2. 不确定性量化

    通过预测方差分析不确定性:

    matlab 复制代码
    [~, S] = predict(gpr_model, yp_test');
    figure;
    semilogy(Y_test(:,1), 'b', 'LineWidth',1.5);
    hold on;
    plot(ypred(:,1), 'r--', 'LineWidth',1.5);
    errorbar(1:length(ypred), ypred(:,1), 2*sqrt(S(:,1)), 'r.');

七、注意事项

  1. 数据规模限制 样本数建议不超过10,000(内存限制) 特征维度建议不超过20(计算复杂度O(n³))
  2. 计算加速 使用FIT_METHOD='sd'进行稀疏近似 设置ActiveSetSize=100控制活跃集大小
相关推荐
白露与泡影2 小时前
2026版Java架构师面试题及答案整理汇总
java·开发语言
cici158742 小时前
大规模MIMO系统中Alamouti预编码的QPSK复用性能MATLAB仿真
算法·matlab·预编码算法
一个天蝎座 白勺 程序猿2 小时前
KingbaseES查询逻辑优化深度解析:从子查询到语义优化的全链路实践
开发语言·数据库·kingbasees·金仓数据库
skywalker_113 小时前
Java中异常
java·开发语言·异常
2501_940315263 小时前
航电oj:首字母变大写
开发语言·c++·算法
没有天赋那就反复3 小时前
JAVA 静态方法
java·开发语言
Thomas_YXQ3 小时前
Unity3D在ios平台下内存的优化详解
开发语言·macos·ios·性能优化·cocoa
咸甜适中3 小时前
rust的docx-rs库,自定义docx模版批量生成docx文档(逐行注释)
开发语言·rust·docx·docx-rs
浒畔居3 小时前
泛型编程与STL设计思想
开发语言·c++·算法
Fcy6483 小时前
C++ 异常详解
开发语言·c++·异常