MATLAB SGM(半全局匹配)算法实现

MATLAB SGM(半全局匹配)算法实现,用于立体视觉中的稠密视差图计算。

SGM算法MATLAB实现

matlab 复制代码
%% SGM(半全局匹配)算法MATLAB实现
clear all; close all; clc;

%% 1. 读取立体图像对
% 确保有left.png和right.png在相同目录下,或提供完整路径
I_left = imread('left.png');
I_right = imread('right.png');

% 转换为灰度图(如果是彩色图像)
if size(I_left, 3) == 3
    I_left_gray = rgb2gray(I_left);
    I_right_gray = rgb2gray(I_right);
else
    I_left_gray = I_left;
    I_right_gray = I_right;
end

% 转换为double类型以便计算
I_left_gray = im2double(I_left_gray);
I_right_gray = im2double(I_right_gray);

%% 2. 算法参数设置
params = struct();
params.max_disparity = 64;      % 最大视差搜索范围
params.window_size = 3;         % 匹配窗口大小(奇数)
params.P1 = 0.1;                % 惩罚系数P1(小视差变化)
params.P2 = 0.3;                % 惩罚系数P2(大视差变化)
params.num_directions = 8;      % 聚合路径方向数(通常4或8)
params.census_size = 5;         % Census变换窗口大小(奇数)
params.enable_subpixel = true;  % 启用亚像素优化
params.enable_lr_check = true;  % 启用左右一致性检查
params.lr_threshold = 1;        % 左右一致性检查阈值

%% 3. Census变换(计算特征描述子)
fprintf('进行Census变换...\n');
[C_left, C_right] = census_transform(I_left_gray, I_right_gray, params.census_size);
fprintf('Census变换完成\n');

%% 4. 计算匹配代价(基于Hamming距离)
fprintf('计算初始匹配代价...\n');
cost_volume = compute_matching_cost(C_left, C_right, params.max_disparity, params.window_size);
fprintf('匹配代价计算完成\n');

%% 5. 代价聚合(SGM核心)
fprintf('进行SGM代价聚合...\n');
aggregated_cost = sgm_cost_aggregation(cost_volume, params);
fprintf('代价聚合完成\n');

%% 6. 视差计算(赢家通吃)
fprintf('计算初始视差图...\n');
disparity_map = compute_disparity(aggregated_cost);
fprintf('视差计算完成\n');

%% 7. 亚像素优化(可选)
if params.enable_subpixel
    fprintf('进行亚像素优化...\n');
    disparity_map = subpixel_refinement(disparity_map, aggregated_cost);
    fprintf('亚像素优化完成\n');
end

%% 8. 左右一致性检查(可选)
if params.enable_lr_check
    fprintf('进行左右一致性检查...\n');
    % 计算右视差图
    [~, C_right_flip] = census_transform(I_right_gray, I_left_gray, params.census_size);
    cost_volume_right = compute_matching_cost(C_right_flip, C_left, params.max_disparity, params.window_size);
    aggregated_cost_right = sgm_cost_aggregation(cost_volume_right, params);
    disparity_map_right = compute_disparity(aggregated_cost_right);
    
    % 一致性检查
    disparity_map = left_right_check(disparity_map, disparity_map_right, params.lr_threshold);
    fprintf('一致性检查完成\n');
end

%% 9. 视差图后处理
fprintf('进行视差图后处理...\n');
disparity_map = disparity_postprocessing(disparity_map, params.max_disparity);
fprintf('后处理完成\n');

%% 10. 结果显示
figure('Position', [100, 100, 1200, 500]);

% 原始图像
subplot(2, 3, 1);
imshow(I_left_gray);
title('左图像');
subplot(2, 3, 2);
imshow(I_right_gray);
title('右图像');

% Census特征图
subplot(2, 3, 3);
imshow(C_left(:,:,1), []);
title('Census特征图(第一通道)');

% 代价体积切片(在特定行)
subplot(2, 3, 4);
row_to_show = round(size(cost_volume, 1)/2);
cost_slice = squeeze(cost_volume(row_to_show, :, :));
imagesc(cost_slice');
xlabel('图像列'); ylabel('视差'); title('匹配代价切片');
colorbar; axis xy;

% 视差图
subplot(2, 3, 5);
imagesc(disparity_map, [0 params.max_disparity]);
colormap('jet'); colorbar;
axis image; title('SGM视差图');

% 3D点云显示(可选)
subplot(2, 3, 6);
show_point_cloud = true;
if show_point_cloud
    % 生成简单的3D显示(需要调整参数)
    [h, w] = size(disparity_map);
    [X, Y] = meshgrid(1:w, 1:h);
    Z = disparity_map;
    
    % 下采样以加快显示
    step = 4;
    X_small = X(1:step:end, 1:step:end);
    Y_small = Y(1:step:end, 1:step:end);
    Z_small = Z(1:step:end, 1:step:end);
    
    surf(X_small, Y_small, Z_small, 'EdgeColor', 'none', 'FaceAlpha', 0.7);
    colormap('jet'); colorbar;
    axis tight; view(-30, 30);
    title('3D视差表面');
    xlabel('X'); ylabel('Y'); zlabel('视差');
end

fprintf('\nSGM算法完成!\n');
fprintf('视差范围:[%.2f, %.2f]\n', min(disparity_map(:)), max(disparity_map(:)));

%% 关键函数实现

% ==============================================
% 函数1: Census变换
% ==============================================
function [C_left, C_right] = census_transform(I_left, I_right, window_size)
    [h, w] = size(I_left);
    half_win = floor(window_size/2);
    
    % 扩展图像边界
    I_left_pad = padarray(I_left, [half_win, half_win], 'replicate');
    I_right_pad = padarray(I_right, [half_win, half_win], 'replicate');
    
    % 初始化Census特征
    C_left = zeros(h, w, window_size^2-1, 'uint8');
    C_right = zeros(h, w, window_size^2-1, 'uint8');
    
    for i = 1:h
        for j = 1:w
            % 提取局部窗口
            window_left = I_left_pad(i:i+window_size-1, j:j+window_size-1);
            window_right = I_right_pad(i:i+window_size-1, j:j+window_size-1);
            
            % 中心像素值
            center_left = window_left(half_win+1, half_win+1);
            center_right = window_right(half_win+1, half_win+1);
            
            % 计算Census特征(与中心像素比较)
            census_left = [];
            census_right = [];
            idx = 1;
            
            for m = 1:window_size
                for n = 1:window_size
                    if (m == half_win+1) && (n == half_win+1)
                        continue; % 跳过中心像素
                    end
                    
                    % 左图像Census特征
                    if window_left(m, n) >= center_left
                        census_left(idx) = 1;
                    else
                        census_left(idx) = 0;
                    end
                    
                    % 右图像Census特征
                    if window_right(m, n) >= center_right
                        census_right(idx) = 1;
                    else
                        census_right(idx) = 0;
                    end
                    
                    idx = idx + 1;
                end
            end
            
            C_left(i, j, :) = census_left;
            C_right(i, j, :) = census_right;
        end
    end
end

% ==============================================
% 函数2: 计算匹配代价
% ==============================================
function cost_volume = compute_matching_cost(C_left, C_right, max_disparity, window_size)
    [h, w, feature_dim] = size(C_left);
    half_win = floor(window_size/2);
    
    % 初始化代价体积
    cost_volume = inf(h, w, max_disparity+1, 'single');
    
    % 遍历所有像素
    for i = half_win+1:h-half_win
        for j = half_win+1:w-half_win
            % 提取左图像特征窗口
            left_window = C_left(i-half_win:i+half_win, j-half_win:j+half_win, :);
            
            % 遍历所有可能视差
            for d = 0:max_disparity
                if j-d < half_win+1
                    continue; % 超出图像边界
                end
                
                % 提取右图像对应窗口
                right_window = C_right(i-half_win:i+half_win, j-d-half_win:j-d+half_win, :);
                
                % 计算Hamming距离(特征不同的位数)
                diff_bits = sum(sum(sum(left_window ~= right_window, 3), 2), 1);
                cost_volume(i, j, d+1) = diff_bits;
            end
        end
        % 显示进度
        if mod(i, 50) == 0
            fprintf('  处理进度: %.1f%%\n', 100*i/h);
        end
    end
end

% ==============================================
% 函数3: SGM代价聚合
% ==============================================
function aggregated_cost = sgm_cost_aggregation(cost_volume, params)
    [h, w, num_disp] = size(cost_volume);
    aggregated_cost = zeros(h, w, num_disp, 'single');
    
    % 定义聚合方向(8方向)
    directions = [
         1,  0;  % 右
         0,  1;  % 下
         1,  1;  % 右下
        -1,  1;  % 左下
        -1,  0;  % 左
         0, -1;  % 上
        -1, -1;  % 左上
         1, -1;  % 右上
    ];
    
    num_directions = min(params.num_directions, size(directions, 1));
    
    % 对每个方向进行聚合
    for dir_idx = 1:num_directions
        fprintf('  聚合方向 %d/%d\n', dir_idx, num_directions);
        
        dx = directions(dir_idx, 1);
        dy = directions(dir_idx, 2);
        
        % 沿当前方向聚合
        dir_cost = aggregate_direction(cost_volume, dx, dy, params.P1, params.P2);
        aggregated_cost = aggregated_cost + dir_cost;
    end
    
    aggregated_cost = aggregated_cost / num_directions;
end

% ==============================================
% 函数4: 单方向代价聚合
% ==============================================
function dir_cost = aggregate_direction(cost_volume, dx, dy, P1, P2)
    [h, w, num_disp] = size(cost_volume);
    dir_cost = zeros(h, w, num_disp, 'single');
    
    % 确定遍历顺序
    if dx >= 0 && dy >= 0
        start_i = 1; end_i = h; step_i = 1;
        start_j = 1; end_j = w; step_j = 1;
    elseif dx >= 0 && dy < 0
        start_i = h; end_i = 1; step_i = -1;
        start_j = 1; end_j = w; step_j = 1;
    elseif dx < 0 && dy >= 0
        start_i = 1; end_i = h; step_i = 1;
        start_j = w; end_j = 1; step_j = -1;
    else
        start_i = h; end_i = 1; step_i = -1;
        start_j = w; end_j = 1; step_j = -1;
    end
    
    % 动态规划聚合
    for i = start_i:step_i:end_i
        for j = start_j:step_j:end_j
            % 前一个像素位置
            prev_i = i - dy;
            prev_j = j - dx;
            
            % 如果是起始像素,直接使用初始代价
            if prev_i < 1 || prev_i > h || prev_j < 1 || prev_j > w
                dir_cost(i, j, :) = cost_volume(i, j, :);
                continue;
            end
            
            % 获取前一个像素的聚合代价
            L_prev = squeeze(dir_cost(prev_i, prev_j, :));
            
            % 获取当前像素的初始代价
            C_current = squeeze(cost_volume(i, j, :));
            
            % 计算当前像素的聚合代价
            L_current = compute_path_cost(L_prev, C_current, P1, P2);
            
            dir_cost(i, j, :) = L_current;
        end
    end
end

% ==============================================
% 函数5: 计算路径代价
% ==============================================
function L_current = compute_path_cost(L_prev, C_current, P1, P2)
    num_disp = length(C_current);
    L_current = zeros(num_disp, 1, 'single');
    
    % 找到前一个像素的最小代价
    min_L_prev = min(L_prev);
    
    % 对每个视差计算路径代价
    for d = 1:num_disp
        % 计算各种情况的代价
        cost1 = L_prev(d);                         % 相同视差
        cost2 = min(L_prev) + P1;                  % 小视差变化
        cost3 = min(L_prev(1:num_disp ~= d)) + P2; % 大视差变化
        
        % 取最小值
        L_current(d) = C_current(d) + min([cost1, cost2, cost3]) - min_L_prev;
    end
end

% ==============================================
% 函数6: 计算视差图
% ==============================================
function disparity_map = compute_disparity(aggregated_cost)
    [h, w, num_disp] = size(aggregated_cost);
    disparity_map = zeros(h, w, 'single');
    
    % 赢家通吃(WTA)策略
    for i = 1:h
        for j = 1:w
            % 找到最小代价对应的视差
            [~, best_disp] = min(aggregated_cost(i, j, :));
            disparity_map(i, j) = best_disp - 1; % 转换为0-based视差
        end
    end
end

% ==============================================
% 函数7: 亚像素优化
% ==============================================
function disparity_refined = subpixel_refinement(disparity_map, aggregated_cost)
    [h, w] = size(disparity_map);
    disparity_refined = disparity_map;
    
    for i = 2:h-1
        for j = 2:w-1
            d = round(disparity_map(i, j)) + 1; % 转换为1-based索引
            
            if d < 2 || d > size(aggregated_cost, 3)-1
                continue; % 边界情况
            end
            
            % 获取当前视差及其相邻视差的代价
            C_minus1 = aggregated_cost(i, j, d-1);
            C_0 = aggregated_cost(i, j, d);
            C_plus1 = aggregated_cost(i, j, d+1);
            
            % 抛物线拟合
            delta = (C_minus1 - C_plus1) / (2 * (C_minus1 - 2*C_0 + C_plus1 + eps));
            disparity_refined(i, j) = disparity_map(i, j) + delta;
        end
    end
end

% ==============================================
% 函数8: 左右一致性检查
% ==============================================
function disparity_filtered = left_right_check(disparity_left, disparity_right, threshold)
    disparity_filtered = disparity_left;
    [h, w] = size(disparity_left);
    
    for i = 1:h
        for j = 1:w
            d_left = disparity_left(i, j);
            
            % 对应的右图像素位置
            j_right = round(j - d_left);
            
            if j_right >= 1 && j_right <= w
                d_right = disparity_right(i, j_right);
                
                % 检查一致性
                if abs(d_left - d_right) > threshold
                    disparity_filtered(i, j) = 0; % 标记为无效
                end
            else
                disparity_filtered(i, j) = 0; % 边界无效
            end
        end
    end
end

% ==============================================
% 函数9: 视差图后处理
% ==============================================
function disparity_processed = disparity_postprocessing(disparity_map, max_disparity)
    disparity_processed = disparity_map;
    
    % 1. 中值滤波(去除噪声)
    disparity_processed = medfilt2(disparity_processed, [3, 3]);
    
    % 2. 空洞填充(简单的最近邻填充)
    [h, w] = size(disparity_processed);
    for i = 2:h-1
        for j = 2:w-1
            if disparity_processed(i, j) == 0
                % 查找最近的合法视差
                neighbors = disparity_processed(i-1:i+1, j-1:j+1);
                valid_neighbors = neighbors(neighbors > 0);
                if ~isempty(valid_neighbors)
                    disparity_processed(i, j) = median(valid_neighbors);
                end
            end
        end
    end
    
    % 3. 双边滤波(保边平滑)
    disparity_processed = imbilatfilt(disparity_processed);
end

算法关键参数说明

参数 推荐值 作用
max_disparity 32-128 视差搜索范围,越大计算越慢
window_size 3-9 匹配窗口大小,奇数
P1 0.05-0.2 小视差变化惩罚系数
P2 0.2-0.8 大视差变化惩罚系数
num_directions 4或8 聚合路径数,8更精确但更慢
census_size 5-9 Census变换窗口大小

参考代码 matlab实现的SGM算法 www.youwenfan.com/contentcsq/63909.html

使用建议

  1. 测试图像:使用标准的立体图像对(如Middlebury数据集)
  2. 内存优化 :对于大图像,考虑分块处理或降低max_disparity
  3. 性能提升
    • 使用Mex/C++实现关键函数
    • 并行化方向聚合
    • 使用GPU加速

功能扩展

matlab 复制代码
% 可选的扩展功能
function disparity_map = sgm_advanced(I_left, I_right, params)
    % 1. 多尺度处理(金字塔)
    % 2. 自适应P2参数
    % 3. 置信度图计算
    % 4. 多视图SGM
    % 5. 深度学习特征结合
end
相关推荐
独自破碎E2 小时前
大整数哈希
算法·哈希算法
纤纡.2 小时前
逻辑回归实战进阶:交叉验证与采样技术破解数据痛点(二)
算法·机器学习·逻辑回归
czhc11400756632 小时前
协议 25
java·开发语言·算法
范纹杉想快点毕业2 小时前
状态机设计与嵌入式系统开发完整指南从面向过程到面向对象,从理论到实践的全面解析
linux·服务器·数据库·c++·算法·mongodb·mfc
fish-man2 小时前
测试加粗效果
算法
晓13132 小时前
第二章 【C语言篇:入门】 C 语言基础入门
c语言·算法
yong99902 小时前
MATLAB面波频散曲线反演程序
开发语言·算法·matlab
JicasdC123asd3 小时前
【工业检测】基于YOLO13-C3k2-EIEM的铸造缺陷检测与分类系统_1
人工智能·算法·分类
会编程的土豆3 小时前
新手前端小细节
前端·css·html·项目