基于反向传播算法实现手写数字识别的MATLAB实现

一、数据预处理(MNIST数据集)

matlab 复制代码
%% 加载MNIST数据集
[XTrain, YTrain] = digitTrain4DArrayData;
[XTest, YTest] = digitTest4DArrayData;

%% 数据归一化与格式转换
XTrain = double(reshape(XTrain, [], 28 * 28)) / 255;  % 60000x784
XTest = double(reshape(XTest, [], 28 * 28)) / 255;    % 10000x784

% 标签转换为one-hot编码
YTrain = full(ind2vec(YTrain' + 1));  % 60000x10
YTest = full(ind2vec(YTest' + 1));    % 10000x10

二、神经网络架构设计

matlab 复制代码
%% 网络参数设置
inputSize = 784;    % 输入层节点数
hiddenSize = 100;   % 隐藏层节点数
outputSize = 10;    % 输出层节点数
learningRate = 0.1; % 初始学习率
epochs = 50;        % 训练轮次
batchSize = 64;     % 批处理大小

%% 初始化权重参数(He初始化)
W1 = randn(inputSize, hiddenSize) * sqrt(2/inputSize);
b1 = zeros(1, hiddenSize);
W2 = randn(hiddenSize, outputSize) * sqrt(2/hiddenSize);
b2 = zeros(1, outputSize);

三、反向传播算法实现

matlab 复制代码
%% 训练过程
numSamples = size(XTrain, 1);
lossHistory = zeros(epochs, 1);

for epoch = 1:epochs
    % 随机打乱数据
    idx = randperm(numSamples);
    XTrain = XTrain(idx, :);
    YTrain = YTrain(idx, :);
    
    for batch = 1:batchSize:numSamples
        % 获取当前批次数据
        X_batch = XTrain(batch:min(batch+batchSize-1, numSamples), :);
        Y_batch = YTrain(batch:min(batch+batchSize-1, numSamples), :);
        
        % 前向传播
        Z1 = X_batch * W1 + b1;
        A1 = 1./(1 + exp(-Z1));  % Sigmoid激活
        Z2 = A1 * W2 + b2;
        A2 = softmax(Z2, 2);     % Softmax输出
        
        % 计算损失(交叉熵)
        loss = -sum(sum(Y_batch .* log(A2 + eps))) / batchSize;
        lossHistory(epoch) = lossHistory(epoch) + loss;
        
        % 反向传播
        dZ2 = (A2 - Y_batch) / batchSize;
        dW2 = (A1' * dZ2)';
        db2 = sum(dZ2, 1);
        
        dA1 = dZ2 * W2';
        dZ1 = dA1 .* A1 .* (1 - A1);
        dW1 = (X_batch' * dZ1)';
        db1 = sum(dZ1, 1);
        
        % 参数更新(动量优化)
        momentum = 0.9;
        W1 = W1 - learningRate * dW1 + momentum*(W1 - W1_prev);
        b1 = b1 - learningRate * db1 + momentum*(b1 - b1_prev);
        W2 = W2 - learningRate * dW2 + momentum*(W2 - W2_prev);
        b2 = b2 - learningRate * db2 + momentum*(b2 - b2_prev);
        
        % 保存当前参数
        W1_prev = W1;
        b1_prev = b1;
        W2_prev = W2;
        b2_prev = b2;
    end
    
    % 学习率衰减
    learningRate = learningRate * 0.95;
end

%% 绘制损失曲线
plot(1:epochs, lossHistory);
xlabel('Epoch'); ylabel('Loss'); title('Training Loss');

四、模型测试与评估

matlab 复制代码
%% 测试集预测
Z1_test = XTest * W1 + b1;
A1_test = 1./(1 + exp(-Z1_test));
Z2_test = A1_test * W2 + b2;
[~, Y_pred] = max(Z2_test, [], 2);
[~, Y_true] = max(YTest, [], 2);

%% 性能指标计算
accuracy = mean(Y_pred == Y_true);
fprintf('测试集准确率: %.2f%%
', accuracy*100);

%% 混淆矩阵分析
confMat = confusionmat(Y_true, Y_pred);
confusionchart(confMat);
title('Confusion Matrix');

五、GUI界面实现(手写数字识别)

matlab 复制代码
%% 创建GUI界面
fig = uifigure('Name','手写数字识别', 'Position',[100 100 800 600]);
ax = uiaxes(fig, 'Position',[0.1 0.2 0.7 0.7]);
btn = uibutton(fig, 'Text','开始识别', 'Position',[300 50 150 40], ...
    'ButtonPushedFcn', @(btn,event) predict_digit());

%% 图像预处理函数
function img = preprocess(img)
    img = imresize(img, [28 28]);  % 调整尺寸
    img = imbinarize(rgb2gray(img));  % 二值化
    img = double(img(:)');  % 转换为向量
end

%% 识别回调函数
function predict_digit()
    % 获取绘图区域图像
    img = getimage(ax);
    
    % 预处理
    img_processed = preprocess(img);
    
    % 网络预测
    Z1 = img_processed * W1 + b1;
    A1 = 1./(1 + exp(-Z1));
    Z2 = A1 * W2 + b2;
    [~, pred] = max(Z2);
    
    % 显示结果
    imshow(img, 'Parent', ax);
    title(sprintf('识别结果: %d', pred-1));
end

参考代码 基于matlab实现反向传播算法训练神经网络模型,实现手写体数字的识别 www.youwenfan.com/contentcsm/79659.html

六、常见问题解决

  1. 收敛速度慢 采用Adam优化器替代SGD 增加批量大小至256 使用He初始化方法
  2. 过拟合问题 添加Dropout层(比率0.3-0.5) 采用数据增强(旋转±10°,平移±2像素) 增加L2正则化强度
  3. 识别准确率低 检查标签编码是否正确(0-9对应1-10) 验证网络前向传播计算过程 调整学习率(建议初始值0.01-0.1)
相关推荐
老欧学视觉1 小时前
0013机器学习聚类算法(无监督算法)
算法·机器学习·聚类
小鱼小鱼.oO2 小时前
C++ 算法基础知识
c++·算法·哈希算法
曹牧2 小时前
Java String[] 数组的 contains
java·开发语言·windows
yong99902 小时前
LSD直线提取算法 MATLAB
开发语言·算法·matlab
一只专注api接口开发的技术猿2 小时前
构建电商数据中台:基于淘宝 API 关键词搜索接口的设计与实现
大数据·开发语言·数据库
MobotStone2 小时前
一文看懂AI智能体架构:工程师依赖的8种LLM,到底怎么分工?
后端·算法·llm
浩瀚地学2 小时前
【Java】String
java·开发语言·经验分享·笔记·学习
lengxuenong2 小时前
潍坊一中第四届编程挑战赛(初赛)题解
算法
松涛和鸣3 小时前
25、数据结构:树与二叉树的概念、特性及递归实现
linux·开发语言·网络·数据结构·算法