MATLAB SGM(半全局匹配)算法实现,用于立体视觉中的稠密视差图计算。
SGM算法MATLAB实现
matlab
%% SGM (Semi-Global Matching) stereo pipeline - MATLAB implementation
% Reads a stereo pair, computes a Census-based matching cost, aggregates it
% with SGM, selects disparities by winner-takes-all and optionally refines,
% validates (left-right check) and post-processes the result.
% Uses Image Processing Toolbox functions (rgb2gray, im2double, padarray,
% medfilt2, imbilatfilt). Local functions are defined at the end of the file.
clear; close all; clc;

%% 1. Read the stereo pair
% left.png and right.png must be in the current folder (or give full paths).
I_left = imread('left.png');
I_right = imread('right.png');
% Convert each image to grayscale on its own, so a colour/gray mix also works.
if size(I_left, 3) == 3
    I_left_gray = rgb2gray(I_left);
else
    I_left_gray = I_left;
end
if size(I_right, 3) == 3
    I_right_gray = rgb2gray(I_right);
else
    I_right_gray = I_right;
end
% Work in double precision.
I_left_gray = im2double(I_left_gray);
I_right_gray = im2double(I_right_gray);
% The cost computation indexes both images with the same coordinates.
if ~isequal(size(I_left_gray), size(I_right_gray))
    error('左右图像尺寸必须一致');
end

%% 2. Algorithm parameters
params = struct();
params.max_disparity = 64;      % largest disparity searched (0..max_disparity)
params.window_size = 3;         % side of the cost-summation window (odd)
% NOTE(review): the matching cost is a sum of Census Hamming distances over
% the window (up to (census_size^2-1)*window_size^2 = 216 with the values
% below), so penalties of 0.1 / 0.3 regularise almost nothing. Consider
% scaling P1/P2 to the cost range while keeping P1 < P2.
params.P1 = 0.1;                % penalty for a +/-1 disparity change
params.P2 = 0.3;                % penalty for any larger disparity jump
params.num_directions = 8;      % number of aggregation paths (usually 4 or 8)
params.census_size = 5;         % Census window side (odd)
params.enable_subpixel = true;  % parabolic sub-pixel refinement
params.enable_lr_check = true;  % left-right consistency check
params.lr_threshold = 1;        % max allowed |d_left - d_right|

%% 3. Census transform (per-pixel descriptors)
fprintf('进行Census变换...\n');
[C_left, C_right] = census_transform(I_left_gray, I_right_gray, params.census_size);
fprintf('Census变换完成\n');

%% 4. Matching cost (windowed Hamming distance)
fprintf('计算初始匹配代价...\n');
cost_volume = compute_matching_cost(C_left, C_right, params.max_disparity, params.window_size);
fprintf('匹配代价计算完成\n');

%% 5. Cost aggregation (SGM core)
fprintf('进行SGM代价聚合...\n');
aggregated_cost = sgm_cost_aggregation(cost_volume, params);
fprintf('代价聚合完成\n');

%% 6. Disparity selection (winner-takes-all)
fprintf('计算初始视差图...\n');
disparity_map = compute_disparity(aggregated_cost);
fprintf('视差计算完成\n');

%% 7. Sub-pixel refinement (optional)
if params.enable_subpixel
    fprintf('进行亚像素优化...\n');
    disparity_map = subpixel_refinement(disparity_map, aggregated_cost);
    fprintf('亚像素优化完成\n');
end

%% 8. Left-right consistency check (optional)
if params.enable_lr_check
    fprintf('进行左右一致性检查...\n');
    % Right-reference disparity map. A right-image pixel at column j matches
    % the left image at column j+d, whereas compute_matching_cost searches
    % column j-d of its second argument. Mirroring both Census maps
    % horizontally turns the +d search into a -d search; the resulting
    % disparity map is mirrored back so it is indexed like the right image.
    cost_volume_right = compute_matching_cost(flip(C_right, 2), flip(C_left, 2), ...
        params.max_disparity, params.window_size);
    aggregated_cost_right = sgm_cost_aggregation(cost_volume_right, params);
    disparity_map_right = flip(compute_disparity(aggregated_cost_right), 2);
    % Consistency check
    disparity_map = left_right_check(disparity_map, disparity_map_right, params.lr_threshold);
    fprintf('一致性检查完成\n');
end

%% 9. Post-processing
fprintf('进行视差图后处理...\n');
disparity_map = disparity_postprocessing(disparity_map, params.max_disparity);
fprintf('后处理完成\n');

%% 10. Display results
figure('Position', [100, 100, 1200, 500]);
% Input images
subplot(2, 3, 1);
imshow(I_left_gray);
title('左图像');
subplot(2, 3, 2);
imshow(I_right_gray);
title('右图像');
% First Census bit plane of the left image
subplot(2, 3, 3);
imshow(C_left(:,:,1), []);
title('Census特征图(第一通道)');
% Cost-volume slice along the middle image row
subplot(2, 3, 4);
row_to_show = round(size(cost_volume, 1)/2);
cost_slice = squeeze(cost_volume(row_to_show, :, :));
cost_slice(~isfinite(cost_slice)) = NaN;   % invalid (border) costs are Inf
imagesc(cost_slice');
xlabel('图像列'); ylabel('视差'); title('匹配代价切片');
colorbar; axis xy;
% Disparity map
subplot(2, 3, 5);
imagesc(disparity_map, [0 params.max_disparity]);
colormap('jet'); colorbar;
axis image; title('SGM视差图');
% Disparity surface (optional)
subplot(2, 3, 6);
show_point_cloud = true;
if show_point_cloud
    % Plot disparity as a height field (not a metric reconstruction).
    [h, w] = size(disparity_map);
    [X, Y] = meshgrid(1:w, 1:h);
    Z = disparity_map;
    % Subsample to keep the surface plot responsive
    step = 4;
    X_small = X(1:step:end, 1:step:end);
    Y_small = Y(1:step:end, 1:step:end);
    Z_small = Z(1:step:end, 1:step:end);
    surf(X_small, Y_small, Z_small, 'EdgeColor', 'none', 'FaceAlpha', 0.7);
    colormap('jet'); colorbar;
    axis tight; view(-30, 30);
    title('3D视差表面');
    xlabel('X'); ylabel('Y'); zlabel('视差');
end
fprintf('\nSGM算法完成!\n');
fprintf('视差范围:[%.2f, %.2f]\n', min(disparity_map(:)), max(disparity_map(:)));
%% 关键函数实现
% ==============================================
% 函数1: Census变换
% ==============================================
function [C_left, C_right] = census_transform(I_left, I_right, window_size)
% Census transform of two grayscale images.
%   I_left, I_right : 2-D grayscale images (the size of I_left is used for
%                     both outputs).
%   window_size     : side of the square Census window (odd).
% Returns H-by-W-by-(window_size^2-1) uint8 arrays of 0/1 bits. Bit k is 1
% when the k-th window neighbour (row-major order, centre skipped) is >=
% the centre pixel. Borders use replicate padding.
%
% Vectorised: each neighbour offset is handled for the whole image at once
% by comparing a shifted view of the padded image against the centre view,
% instead of growing per-pixel arrays inside a double loop.
[h, w] = size(I_left);
half_win = floor(window_size/2);
num_bits = window_size^2 - 1;

% Pad the borders so every pixel has a full window.
I_left_pad = padarray(I_left, [half_win, half_win], 'replicate');
I_right_pad = padarray(I_right, [half_win, half_win], 'replicate');

% Centre values for every pixel (the unshifted view of the padded image).
center_left = I_left_pad(half_win+(1:h), half_win+(1:w));
center_right = I_right_pad(half_win+(1:h), half_win+(1:w));

C_left = zeros(h, w, num_bits, 'uint8');
C_right = zeros(h, w, num_bits, 'uint8');
bit = 0;
for m = 1:window_size          % window row (outer)  -> same bit order as before
    for n = 1:window_size      % window column (inner)
        if m == half_win+1 && n == half_win+1
            continue;          % the centre pixel is not compared with itself
        end
        bit = bit + 1;
        C_left(:, :, bit) = I_left_pad(m:m+h-1, n:n+w-1) >= center_left;
        C_right(:, :, bit) = I_right_pad(m:m+h-1, n:n+w-1) >= center_right;
    end
end
end
% ==============================================
% 函数2: 计算匹配代价
% ==============================================
function cost_volume = compute_matching_cost(C_left, C_right, max_disparity, window_size)
% Build the matching-cost volume from two Census bit-plane maps.
%   C_left, C_right : H-by-W-by-K Census bit planes; C_left is the reference.
%   max_disparity   : largest disparity d tested (d = 0..max_disparity).
%   window_size     : window side; the window actually summed is
%                     2*floor(window_size/2)+1 pixels wide.
% Returns an H-by-W-by-(max_disparity+1) single array in which
% cost_volume(i,j,d+1) is the number of differing Census bits between the
% left window centred at (i,j) and the right window centred at (i,j-d).
% Entries whose window would leave either image stay Inf.
%
% Vectorised per disparity: one per-pixel Hamming image for each d, then a
% box sum with conv2, instead of re-comparing whole windows per pixel.
[h, w, ~] = size(C_left);
half_win = floor(window_size/2);
box = ones(2*half_win + 1);
cost_volume = inf(h, w, max_disparity+1, 'single');
rows = half_win+1 : h-half_win;
for d = 0:max_disparity
    % Valid centres: the window fits in the image and j-d-half_win >= 1.
    cols = d+half_win+1 : w-half_win;
    if ~isempty(rows) && ~isempty(cols)
        % hamming(:, j) = number of differing bits between left column j and
        % right column j-d (columns 1..d have no partner and stay 0; they are
        % never inside a valid window).
        hamming = zeros(h, w);
        hamming(:, d+1:w) = sum(C_left(:, d+1:w, :) ~= C_right(:, 1:w-d, :), 3);
        % Centred box sum; exact for integer-valued input.
        window_sum = conv2(hamming, box, 'same');
        cost_volume(rows, cols, d+1) = window_sum(rows, cols);
    end
    % Progress report
    if mod(d+1, 16) == 0 || d == max_disparity
        fprintf(' 处理进度: %.1f%%\n', 100*(d+1)/(max_disparity+1));
    end
end
end
% ==============================================
% 函数3: SGM代价聚合
% ==============================================
function aggregated_cost = sgm_cost_aggregation(cost_volume, params)
% SGM cost aggregation: run the 1-D path recursion along several directions
% and average the per-direction results.
%   cost_volume : H-by-W-by-D matching-cost volume.
%   params      : struct; fields P1, P2 and num_directions are used.
% Returns an H-by-W-by-D single array (mean over the directions used).
%
% Each row of `directions` is [dx, dy]: the path reaches pixel (i,j) from
% (i-dy, j-dx). The first four rows are the two horizontal and two vertical
% passes, so num_directions = 4 gives the classic symmetric 4-path SGM;
% 8 adds the four diagonals.
[h, w, num_disp] = size(cost_volume);
aggregated_cost = zeros(h, w, num_disp, 'single');
directions = [
     1,  0;   % left -> right
    -1,  0;   % right -> left
     0,  1;   % top -> bottom
     0, -1;   % bottom -> top
     1,  1;   % towards bottom-right
    -1, -1;   % towards top-left
    -1,  1;   % towards bottom-left
     1, -1;   % towards top-right
];
num_directions = min(params.num_directions, size(directions, 1));
if num_directions < 1
    % Averaging over zero paths would silently produce NaN.
    error('sgm_cost_aggregation:numDirections', 'params.num_directions must be at least 1');
end
for dir_idx = 1:num_directions
    fprintf(' 聚合方向 %d/%d\n', dir_idx, num_directions);
    dx = directions(dir_idx, 1);
    dy = directions(dir_idx, 2);
    % Aggregate along the current direction and accumulate.
    dir_cost = aggregate_direction(cost_volume, dx, dy, params.P1, params.P2);
    aggregated_cost = aggregated_cost + dir_cost;
end
aggregated_cost = aggregated_cost / num_directions;
end
% ==============================================
% 函数4: 单方向代价聚合
% ==============================================
function dir_cost = aggregate_direction(cost_volume, dx, dy, P1, P2)
% Aggregate the cost volume along one scan direction (dynamic programming).
%   cost_volume : H-by-W-by-D matching costs.
%   dx, dy      : step of the path; pixel (r,c) is reached from (r-dy, c-dx).
%   P1, P2      : smoothness penalties forwarded to compute_path_cost.
% Returns an H-by-W-by-D single array of path costs. Pixels whose
% predecessor lies outside the image start a new path with their raw cost.

[h, w, num_disp] = size(cost_volume);
dir_cost = zeros(h, w, num_disp, 'single');

% Rows must be visited so the predecessor row is already done: top-down
% when dy >= 0, bottom-up otherwise. Columns likewise depend on dx.
if dy >= 0
    row_order = 1:h;
else
    row_order = h:-1:1;
end
if dx >= 0
    col_order = 1:w;
else
    col_order = w:-1:1;
end

for r = row_order
    pr = r - dy;
    for c = col_order
        pc = c - dx;
        starts_path = pr < 1 || pr > h || pc < 1 || pc > w;
        if starts_path
            dir_cost(r, c, :) = cost_volume(r, c, :);
        else
            prev_cost = reshape(dir_cost(pr, pc, :), [], 1);
            here_cost = reshape(cost_volume(r, c, :), [], 1);
            dir_cost(r, c, :) = compute_path_cost(prev_cost, here_cost, P1, P2);
        end
    end
end
end
% ==============================================
% 函数5: 计算路径代价
% ==============================================
function L_current = compute_path_cost(L_prev, C_current, P1, P2)
% One step of the SGM path recursion (Hirschmueller 2008):
%   L(p,d) = C(p,d) + min( L(p-r,d),
%                          L(p-r,d-1) + P1, L(p-r,d+1) + P1,
%                          min_k L(p-r,k) + P2 ) - min_k L(p-r,k)
%   L_prev    : path costs of the previous pixel, one per disparity.
%   C_current : matching costs of the current pixel, one per disparity.
%   P1, P2    : penalties for a +/-1 disparity change and for larger jumps.
% Returns a column vector (single) of path costs for the current pixel.
%
% Inf costs (invalid matches) are allowed. If the previous pixel has no
% finite cost at all, the path restarts here with the raw costs, instead
% of producing Inf - Inf = NaN that would poison the rest of the path.
L_prev = single(L_prev(:));
C_current = single(C_current(:));
min_L_prev = min(L_prev);
if isempty(min_L_prev) || ~isfinite(min_L_prev)
    L_current = C_current;
    return;
end
% Neighbouring disparities of the previous pixel (Inf beyond the range).
L_prev_dm1 = [inf; L_prev(1:end-1)];   % L_prev(d-1)
L_prev_dp1 = [L_prev(2:end); inf];     % L_prev(d+1)
best_prev = min(min(L_prev, min(L_prev_dm1, L_prev_dp1) + P1), min_L_prev + P2);
% Subtracting min_L_prev keeps the values bounded along long paths.
L_current = C_current + best_prev - min_L_prev;
end
% ==============================================
% 函数6: 计算视差图
% ==============================================
function disparity_map = compute_disparity(aggregated_cost)
% Winner-takes-all disparity selection.
%   aggregated_cost : H-by-W-by-D cost volume; plane k corresponds to
%                     disparity k-1.
% Returns an H-by-W single map holding, per pixel, the 0-based index of the
% smallest cost along the third dimension (the first one on ties).
[~, winner] = min(aggregated_cost, [], 3);
disparity_map = single(winner - 1);
end
% ==============================================
% 函数7: 亚像素优化
% ==============================================
function disparity_refined = subpixel_refinement(disparity_map, aggregated_cost)
% Parabolic sub-pixel refinement of an integer (WTA) disparity map.
%   disparity_map   : H-by-W map of 0-based integer disparities.
%   aggregated_cost : H-by-W-by-D cost volume the disparities were picked from.
% For each pixel a parabola is fitted through the costs at d-1, d, d+1 and
% the disparity is moved to its vertex:
%   delta = (C(d-1) - C(d+1)) / (2*(C(d-1) - 2*C(d) + C(d+1)))
% A pixel is left unchanged when d is the first or last disparity plane,
% when any of the three costs is not finite (Inf marks invalid matches and
% would otherwise yield NaN), or when the three costs do not form an
% upward-opening parabola (no well-defined minimum).
num_disp = size(aggregated_cost, 3);
disparity_refined = disparity_map;
[h, w] = size(disparity_map);
for i = 1:h
    for j = 1:w
        d = round(disparity_map(i, j)) + 1;   % 1-based plane index
        if ~(d >= 2 && d <= num_disp - 1)     % edge planes; also rejects NaN
            continue;
        end
        C_minus1 = aggregated_cost(i, j, d-1);
        C_0 = aggregated_cost(i, j, d);
        C_plus1 = aggregated_cost(i, j, d+1);
        curvature = C_minus1 - 2*C_0 + C_plus1;
        if ~isfinite(curvature) || curvature <= 0
            continue;
        end
        delta = (C_minus1 - C_plus1) / (2 * curvature);
        disparity_refined(i, j) = disparity_map(i, j) + delta;
    end
end
end
% ==============================================
% 函数8: 左右一致性检查
% ==============================================
function disparity_filtered = left_right_check(disparity_left, disparity_right, threshold)
% Left-right consistency check.
%   disparity_left  : H-by-W disparities referenced to the left image.
%   disparity_right : disparities referenced to the right image.
%   threshold       : largest tolerated |d_left - d_right|.
% For left pixel (i,j) with disparity d, the partner is right pixel
% (i, round(j-d)). The left disparity is set to 0 (invalid) when the partner
% falls outside the image or when the two disparities differ by more than
% threshold; otherwise it is kept unchanged.
[h, w] = size(disparity_left);
[col_idx, row_idx] = meshgrid(1:w, 1:h);
partner_col = round(col_idx - disparity_left);

inside = partner_col >= 1 & partner_col <= w;
reject = ~inside;                       % out-of-image partners are invalid
partner_lin = sub2ind(size(disparity_right), row_idx(inside), double(partner_col(inside)));
reject(inside) = abs(disparity_left(inside) - disparity_right(partner_lin)) > threshold;

disparity_filtered = disparity_left;
disparity_filtered(reject) = 0;
end
% ==============================================
% 函数9: 视差图后处理
% ==============================================
function disparity_processed = disparity_postprocessing(disparity_map, max_disparity)
% Denoise a disparity map, fill small holes and apply edge-preserving
% smoothing.
%   disparity_map : H-by-W disparity map; 0 marks invalid pixels (as written
%                   by left_right_check).
%   max_disparity : largest disparity searched; used to scale the map to
%                   roughly [0, 1] for the bilateral filter.
% Uses medfilt2 and imbilatfilt (Image Processing Toolbox).

% 1. 3x3 median filter. 'symmetric' padding stops the image border from
%    being pulled towards 0; the default zero padding turns corner pixels
%    into 0, i.e. "invalid".
disparity_processed = medfilt2(disparity_map, [3, 3], 'symmetric');

% 2. Fill remaining holes (value 0) with the median of the valid 3x3
%    neighbours. Updates happen in place, so a filled pixel can feed later ones.
[h, w] = size(disparity_processed);
for i = 2:h-1
    for j = 2:w-1
        if disparity_processed(i, j) == 0
            neighbors = disparity_processed(i-1:i+1, j-1:j+1);
            valid_neighbors = neighbors(neighbors > 0);
            if ~isempty(valid_neighbors)
                disparity_processed(i, j) = median(valid_neighbors);
            end
        end
    end
end

% 3. Bilateral filter. imbilatfilt's default DegreeOfSmoothing is derived
%    from the data class range ([0, 1] for single/double). Raw disparities
%    span 0..max_disparity, so steps of >= 1 would dwarf the range kernel
%    and the filter would barely smooth anything. Normalise, filter, rescale.
scale = max(max_disparity, 1);
disparity_processed = imbilatfilt(disparity_processed / scale) * scale;
end
算法关键参数说明
| 参数 | 推荐值 | 作用 |
|---|---|---|
| max_disparity | 32-128 | 视差搜索范围,越大计算越慢 |
| window_size | 3-9 | 匹配窗口大小,奇数 |
| P1 | 0.05-0.2 | 小视差变化惩罚系数 |
| P2 | 0.2-0.8 | 大视差变化惩罚系数 |
| num_directions | 4或8 | 聚合路径数,8更精确但更慢 |
| census_size | 5-9 | Census变换窗口大小 |

注:P1、P2 应与匹配代价处于同一量级,并满足 P1 < P2。本实现的匹配代价是窗口内 Census 汉明距离之和,最大可达 (census_size²−1)×window_size²。
参考代码 matlab实现的SGM算法 www.youwenfan.com/contentcsq/63909.html
使用建议
- 测试图像:使用标准的立体图像对(如Middlebury数据集)
- 内存优化:对于大图像,考虑分块处理或降低 max_disparity
- 性能提升:
  - 使用Mex/C++实现关键函数
  - 并行化方向聚合
  - 使用GPU加速
功能扩展
matlab
% 可选的扩展功能
function disparity_map = sgm_advanced(I_left, I_right, params) %#ok<INUSD,STOUT>
% Placeholder for extended SGM variants. Planned features:
%   1. multi-scale (image pyramid) processing
%   2. adaptive P2 penalty
%   3. confidence map computation
%   4. multi-view SGM
%   5. combination with learned (deep) features
% Nothing is implemented yet. Raise an explicit, identifiable error instead
% of failing with MATLAB's generic "output argument not assigned" error.
error('sgm_advanced:notImplemented', 'sgm_advanced is not implemented yet.');
end