基于反向传播算法实现手写数字识别的MATLAB实现

一、数据预处理(MNIST数据集)

matlab 复制代码
%% 加载MNIST数据集
[XTrain, YTrain] = digitTrain4DArrayData;
[XTest, YTest] = digitTest4DArrayData;

%% 数据归一化与格式转换
XTrain = double(reshape(XTrain, [], 28 * 28)) / 255;  % 60000x784
XTest = double(reshape(XTest, [], 28 * 28)) / 255;    % 10000x784

% 标签转换为one-hot编码
YTrain = full(ind2vec(YTrain' + 1));  % 60000x10
YTest = full(ind2vec(YTest' + 1));    % 10000x10

二、神经网络架构设计

matlab 复制代码
%% 网络参数设置
inputSize = 784;    % 输入层节点数
hiddenSize = 100;   % 隐藏层节点数
outputSize = 10;    % 输出层节点数
learningRate = 0.1; % 初始学习率
epochs = 50;        % 训练轮次
batchSize = 64;     % 批处理大小

%% 初始化权重参数(He初始化)
W1 = randn(inputSize, hiddenSize) * sqrt(2/inputSize);
b1 = zeros(1, hiddenSize);
W2 = randn(hiddenSize, outputSize) * sqrt(2/hiddenSize);
b2 = zeros(1, outputSize);

三、反向传播算法实现

matlab 复制代码
%% 训练过程
numSamples = size(XTrain, 1);
lossHistory = zeros(epochs, 1);

for epoch = 1:epochs
    % 随机打乱数据
    idx = randperm(numSamples);
    XTrain = XTrain(idx, :);
    YTrain = YTrain(idx, :);
    
    for batch = 1:batchSize:numSamples
        % 获取当前批次数据
        X_batch = XTrain(batch:min(batch+batchSize-1, numSamples), :);
        Y_batch = YTrain(batch:min(batch+batchSize-1, numSamples), :);
        
        % 前向传播
        Z1 = X_batch * W1 + b1;
        A1 = 1./(1 + exp(-Z1));  % Sigmoid激活
        Z2 = A1 * W2 + b2;
        A2 = softmax(Z2, 2);     % Softmax输出
        
        % 计算损失(交叉熵)
        loss = -sum(sum(Y_batch .* log(A2 + eps))) / batchSize;
        lossHistory(epoch) = lossHistory(epoch) + loss;
        
        % 反向传播
        dZ2 = (A2 - Y_batch) / batchSize;
        dW2 = (A1' * dZ2)';
        db2 = sum(dZ2, 1);
        
        dA1 = dZ2 * W2';
        dZ1 = dA1 .* A1 .* (1 - A1);
        dW1 = (X_batch' * dZ1)';
        db1 = sum(dZ1, 1);
        
        % 参数更新(动量优化)
        momentum = 0.9;
        W1 = W1 - learningRate * dW1 + momentum*(W1 - W1_prev);
        b1 = b1 - learningRate * db1 + momentum*(b1 - b1_prev);
        W2 = W2 - learningRate * dW2 + momentum*(W2 - W2_prev);
        b2 = b2 - learningRate * db2 + momentum*(b2 - b2_prev);
        
        % 保存当前参数
        W1_prev = W1;
        b1_prev = b1;
        W2_prev = W2;
        b2_prev = b2;
    end
    
    % 学习率衰减
    learningRate = learningRate * 0.95;
end

%% 绘制损失曲线
plot(1:epochs, lossHistory);
xlabel('Epoch'); ylabel('Loss'); title('Training Loss');

四、模型测试与评估

matlab 复制代码
%% 测试集预测
Z1_test = XTest * W1 + b1;
A1_test = 1./(1 + exp(-Z1_test));
Z2_test = A1_test * W2 + b2;
[~, Y_pred] = max(Z2_test, [], 2);
[~, Y_true] = max(YTest, [], 2);

%% 性能指标计算
accuracy = mean(Y_pred == Y_true);
fprintf('测试集准确率: %.2f%%
', accuracy*100);

%% 混淆矩阵分析
confMat = confusionmat(Y_true, Y_pred);
confusionchart(confMat);
title('Confusion Matrix');

五、GUI界面实现(手写数字识别)

matlab 复制代码
%% 创建GUI界面
fig = uifigure('Name','手写数字识别', 'Position',[100 100 800 600]);
ax = uiaxes(fig, 'Position',[0.1 0.2 0.7 0.7]);
btn = uibutton(fig, 'Text','开始识别', 'Position',[300 50 150 40], ...
    'ButtonPushedFcn', @(btn,event) predict_digit());

%% 图像预处理函数
function img = preprocess(img)
    img = imresize(img, [28 28]);  % 调整尺寸
    img = imbinarize(rgb2gray(img));  % 二值化
    img = double(img(:)');  % 转换为向量
end

%% 识别回调函数
function predict_digit()
    % 获取绘图区域图像
    img = getimage(ax);
    
    % 预处理
    img_processed = preprocess(img);
    
    % 网络预测
    Z1 = img_processed * W1 + b1;
    A1 = 1./(1 + exp(-Z1));
    Z2 = A1 * W2 + b2;
    [~, pred] = max(Z2);
    
    % 显示结果
    imshow(img, 'Parent', ax);
    title(sprintf('识别结果: %d', pred-1));
end

参考代码 基于matlab实现反向传播算法训练神经网络模型,实现手写体数字的识别 www.youwenfan.com/contentcsm/79659.html

六、常见问题解决

  1. 收敛速度慢 采用Adam优化器替代SGD 增加批量大小至256 使用He初始化方法
  2. 过拟合问题 添加Dropout层(比率0.3-0.5) 采用数据增强(旋转±10°,平移±2像素) 增加L2正则化强度
  3. 识别准确率低 检查标签编码是否正确(0-9对应1-10) 验证网络前向传播计算过程 调整学习率(建议初始值0.01-0.1)
相关推荐
谎言西西里11 小时前
LeetCode 热题100 --- 双指针专区
算法
代码or搬砖13 小时前
String字符串
android·java·开发语言
leo__52014 小时前
基于两步成像算法的聚束模式SAR MATLAB实现
开发语言·算法·matlab
前端小白在前进14 小时前
力扣刷题:在排序数组中查找元素的第一个和最后一个位置
数据结构·算法·leetcode
Macbethad15 小时前
自动化测试技术报告
开发语言·lua
不会画画的画师15 小时前
Go开发指南:io/ioutil包应用和迁移指南
开发语言·后端·golang
2503_9284115615 小时前
12.22 wxml语法
开发语言·前端·javascript
59803541515 小时前
【java工具类】小数、整数转中文大写
android·java·开发语言
JIngJaneIL15 小时前
基于java + vue个人博客系统(源码+数据库+文档)
java·开发语言·前端·数据库·vue.js·spring boot
某林21215 小时前
基于SLAM Toolbox的移动机器人激光建图算法原理与工程实现
stm32·嵌入式硬件·算法·slam