基于MATLAB实现多变量高斯过程回归(GPR)

一、数据生成与预处理

matlab 复制代码
%% 生成多变量输入输出数据
rng(0); % 固定随机种子
n_samples = 200;

% 输入特征(3维)
X = [linspace(-5,5,n_samples)', randn(n_samples,2)*2];

% 输出目标(2维非线性函数)
Y(:,1) = sin(X(:,1)).*exp(-0.1*X(:,1).^2) + 0.5*X(:,2);
Y(:,2) = cos(X(:,1)).*exp(-0.1*X(:,1).^2) + 0.3*X(:,3) + 0.1*randn(n_samples,1);

%% 数据可视化
figure;
subplot(2,1,1);
scatter3(X(:,1),X(:,2),X(:,3),50,Y(:,1),'filled');
xlabel('X1'); ylabel('X2'); zlabel('X3');
title('输入数据分布(3D)');

subplot(2,1,2);
plot(Y(:,1),'r',Y(:,2),'b');
legend('输出1','输出2');
title('输出目标分布');

二、模型训练与预测

matlab 复制代码
%% 数据划分(70%训练,30%测试)
cv = cvpartition(size(X,1),'HoldOut',0.3);
X_train = X(training(cv),:);
Y_train = Y(training(cv),:);
X_test = X(test(cv),:);
Y_test = Y(test(cv),:);

%% 数据归一化(参考)
[yp_train, ps_input] = mapminmax(X_train',0,1);
yp_test = mapminmax('apply',X_test',ps_input);
[yt_train, ps_output] = mapminmax(Y_train',0,1);
yt_test = mapminmax('apply',Y_test',ps_output);

%% 训练GPR模型(多输出)
kernel = {'ardsquaredexponential','ardsquaredexponential'}; % 自动相关尺度核
gpr_model = fitrgp(yp_train', yt_train', 'Kernel', kernel, ...
    'Basis', 'constant', 'FitMethod', 'exact', ...
    'PredictMethod', 'exact', 'Standardize', 1);

%% 预测
[ypred,ypredci] = predict(gpr_model, yp_test');
ypred = mapminmax('reverse',ypred',ps_output)';
ypredci = mapminmax('reverse',ypredci',ps_output)';

三、性能评估(参考)

matlab 复制代码
%% 计算评价指标
R2 = 1 - sum((Y_test - ypred).^2)/sum((Y_test - mean(Y_test)).^2);
MAE = mean(abs(Y_test - ypred));
MSE = mean((Y_test - ypred).^2);
RMSE = sqrt(MSE);

%% 输出结果
fprintf('===== 模型评估 =====
');
fprintf('R²: %.4f\n', R2);
fprintf('MAE: %.4f\n', MAE);
fprintf('RMSE: %.4f\n', RMSE);

四、可视化分析

matlab 复制代码
%% 三维预测结果可视化
figure;
scatter3(X_test(:,1),X_test(:,2),X_test(:,3),50,Y_test(:,1),'filled');
hold on;
plot3(X_test(:,1),X_test(:,2),X_test(:,3),...
    'o', 'MarkerSize', 8, 'MarkerFaceColor', 'r');
xlabel('X1'); ylabel('X2'); zlabel('X3');
title('输出1预测结果(红点为真实值,曲面为预测值)');

%% 置信区间可视化(以输出2为例)
figure;
plot(Y_test(:,2),'b', 'LineWidth',1.5);
hold on;
plot(ypred(:,2),'r--', 'LineWidth',1.5);
fill([1:length(ypred)], [ypred(:,2); flipud(ypredci(:,2))], ...
    [0.8 0.8 1], 'FaceAlpha',0.3, 'EdgeColor','none');
xlabel('样本序号'); ylabel('输出值');
legend('真实值','预测均值','95%置信区间');
title('输出2预测置信区间');

五、关键参数说明

  1. 核函数选择 ardsquaredexponential:自动相关尺度平方指数核(推荐多变量场景) matern52:Matérn 5/2核(适合有导数连续性的数据)

  2. 超参数优化

    可通过最大似然估计自动优化:

    matlab 复制代码
    gpr_model = fitrgp(yp_train', yt_train', ...
        'Basis', 'constant', 'FitMethod', 'exact', ...
        'OptimizeHyperparameters', 'all', ...
        'HyperparameterOptimizationOptions', struct('AcquisitionFunctionName','expected-improvement-plus'));

参考代码 多变量高斯过程回归示例 www.youwenfan.com/contentcsp/97722.html

六、扩展应用场景

  1. 时序预测

    将时间序列数据转换为监督学习格式:

    matlab 复制代码
    X_t = [X(1:end-1,:) X(2:end,:)]; % 输入为时间步特征
    Y_t = X(2:end,:); % 输出为下一时刻状态
  2. 不确定性量化

    通过预测方差分析不确定性:

    matlab 复制代码
    [~, S] = predict(gpr_model, yp_test');
    figure;
    semilogy(Y_test(:,1), 'b', 'LineWidth',1.5);
    hold on;
    plot(ypred(:,1), 'r--', 'LineWidth',1.5);
    errorbar(1:length(ypred), ypred(:,1), 2*sqrt(S(:,1)), 'r.');

七、注意事项

  1. 数据规模限制 样本数建议不超过10,000(内存限制) 特征维度建议不超过20(计算复杂度O(n³))
  2. 计算加速 使用FIT_METHOD='sd'进行稀疏近似 设置ActiveSetSize=100控制活跃集大小
相关推荐
Fcy6481 小时前
C++ set&&map的模拟实现
开发语言·c++·stl
你怎么知道我是队长7 小时前
C语言---枚举变量
c语言·开发语言
李慕婉学姐7 小时前
【开题答辩过程】以《基于JAVA的校园即时配送系统的设计与实现》为例,不知道这个选题怎么做的,不知道这个选题怎么开题答辩的可以进来看看
java·开发语言·数据库
吃茄子的猫7 小时前
quecpython中&的具体含义和使用场景
开发语言·python
云栖梦泽7 小时前
易语言中小微企业Windows桌面端IoT监控与控制
开发语言
数据大魔方8 小时前
【期货量化实战】日内动量策略:顺势而为的短线交易法(Python源码)
开发语言·数据库·python·mysql·算法·github·程序员创富
Edward.W9 小时前
Python uv:新一代Python包管理工具,彻底改变开发体验
开发语言·python·uv
小熊officer9 小时前
Python字符串
开发语言·数据库·python
月疯9 小时前
各种信号的模拟(ECG信号、质谱图、EEG信号),方便U-net训练
开发语言·python
荒诞硬汉9 小时前
JavaBean相关补充
java·开发语言