基于梯度下降、随机梯度下降和牛顿法的逻辑回归MATLAB实现

一、核心算法
1. 数据预处理与模型初始化
matlab 复制代码
%% 数据准备
% 生成示例数据(二分类)
[X, y] = make_classification(n_samples=100, n_features=2, n_redundant=0, n_clusters_per_class=1);
X = [ones(size(X,1),1) X]; % 添加截距项

%% 参数设置
learning_rate = 0.1; % 学习率
max_iter = 1000;     % 最大迭代次数
tol = 1e-4;          % 收敛阈值
2. Sigmoid函数定义
matlab 复制代码
function g = sigmoid(z)
    g = 1.0 ./ (1.0 + exp(-z));
end
3. 损失函数与梯度计算
matlab 复制代码
function [J, grad] = computeCost(X, y, theta)
    m = length(y);
    h = sigmoid(X * theta);
    J = (-1/m) * sum(y .* log(h) + (1-y) .* log(1-h));
    grad = (1/m) * X' * (h - y);
end

二、梯度下降法实现
matlab 复制代码
function theta = gradientDescent(X, y, alpha, max_iter)
    [m, n] = size(X);
    theta = zeros(n,1);
    J_history = zeros(max_iter,1);
    
    for iter = 1:max_iter
        [J, grad] = computeCost(X, y, theta);
        theta = theta - alpha * grad;
        J_history(iter) = J;
        
        % 收敛判断
        if iter > 1 && abs(J_history(iter) - J_history(iter-1)) < tol
            break;
        end
    end
end

三、随机梯度下降法实现
matlab 复制代码
function theta = stochasticGradientDescent(X, y, alpha, max_epoch)
    [m, n] = size(X);
    theta = zeros(n,1);
    J_history = zeros(max_epoch,1);
    
    for epoch = 1:max_epoch
        % 随机打乱样本顺序
        idx = randperm(m);
        X = X(idx,:);
        y = y(idx);
        
        for i = 1:m
            xi = X(i,:)';
            yi = y(i);
            [h, grad] = computeCost(X, y, theta);
            theta = theta - alpha * grad;
        end
        
        % 记录平均损失
        J_history(epoch) = computeCost(X, y, theta);
    end
end

四、牛顿法实现
matlab 复制代码
function theta = newtonMethod(X, y, max_iter)
    [m, n] = size(X);
    theta = zeros(n,1);
    H = zeros(n,n);
    
    for iter = 1:max_iter
        h = sigmoid(X * theta);
        grad = (1/m) * X' * (h - y);
        
        % 计算Hessian矩阵(添加正则化项)
        S = diag(h .* (1-h));
        H = (1/m) * X' * S * X + 1e-4 * eye(n); % 正则化防止奇异
        
        % 参数更新
        theta = theta - H \ grad;
        
        % 收敛判断
        if iter > 1 && norm(grad) < tol
            break;
        end
    end
end

五、性能对比与可视化
matlab 复制代码
%% 训练与比较
theta_gd = gradientDescent(X, y, 0.1, 1000);
theta_sgd = stochasticGradientDescent(X, y, 0.1, 1000);
theta_newton = newtonMethod(X, y, 100);

%% 决策边界绘制
figure;
hold on;
scatter(X(y==1,2), X(y==1,3), 'r', 'filled');
scatter(X(y==0,2), X(y==0,3), 'b', 'filled');

% 绘制各方法决策边界
x1 = linspace(min(X(:,2)), max(X(:,2)), 100);
x2_gd = (-theta_gd(1) - theta_gd(2)*x1) / theta_gd(3);
x2_sgd = (-theta_sgd(1) - theta_sgd(2)*x1) / theta_sgd(3);
x2_newton = (-theta_newton(1) - theta_newton(2)*x1) / theta_newton(3);

plot(x1, x2_gd, 'g', 'LineWidth', 2);
plot(x1, x2_sgd, 'm', 'LineWidth', 2);
plot(x1, x2_newton, 'k', 'LineWidth', 2);
legend('Positive', 'Negative', 'GD', 'SGD', 'Newton');
title('不同优化方法的决策边界对比');

六、优化
  1. 学习率调整
    • 动态衰减策略:alpha = initial_alpha / (1 + decay_rate * iter)
    • 自适应方法:结合AdaGrad或RMSProp
  2. 正则化增强
    • L2正则化:在损失函数中添加 lambda/2 * sum(theta(2:end).^2)
    • L1正则化:使用次梯度法处理稀疏性
  3. 数值稳定性优化
    • Sigmoid函数截断:sigmoid(z) = max(min(z, 30), -30)
    • Hessian矩阵正则化:添加小量对角项防止奇异

参考代码 使用梯度下降法、随机梯度下降法和牛顿法实现的逻辑回归算法 www.youwenfan.com/contentcsi/59877.html

七、扩展应用示例
matlab 复制代码
%% 多分类扩展(One-vs-All)
function models = oneVsAll(X, y, num_classes, method)
    models = cell(num_classes,1);
    for c = 1:num_classes
        % 二分类标签转换
        binary_y = (y == c);
        % 训练单个分类器
        switch method
            case 'gd'
                models{c} = gradientDescent(X, binary_y);
            case 'sgd'
                models{c} = stochasticGradientDescent(X, binary_y);
            case 'newton'
                models{c} = newtonMethod(X, binary_y);
        end
    end
end
相关推荐
熊猫_豆豆3 小时前
目前顶尖AI所用算法,包含的数学内容,详细列举
人工智能·算法
野犬寒鸦3 小时前
从零起步学习Redis || 第二章:Redis中数据类型的深层剖析讲解(下)
java·redis·后端·算法·哈希算法
java1234_小锋3 小时前
Scikit-learn Python机器学习 - 回归分析算法 - 弹性网络 (Elastic-Net)
python·算法·机器学习
hn小菜鸡4 小时前
LeetCode 2570.合并两个二维数组-求和法
数据结构·算法·leetcode
hn小菜鸡4 小时前
LeetCode 524.通过删除字母匹配到字典里最长单词
算法·leetcode·职场和发展
Greedy Alg4 小时前
LeetCode 226. 翻转二叉树
算法
我要成为c嘎嘎大王4 小时前
【C++】模版专题
c++·算法
jndingxin4 小时前
算法面试(5)------NMS(非极大值抑制)原理 Soft-NMS、DIoU-NMS 是什么?
人工智能·算法·目标跟踪
苏纪云4 小时前
算法<java>——排序(冒泡、插入、选择、归并、快速、计数、堆、桶、基数)
java·开发语言·算法