一、数据预处理(MATLAB自带的手写数字数据集,图像格式与MNIST相同:28×28灰度图)
matlab
%% Load the digit data set
% digitTrain4DArrayData / digitTest4DArrayData return the images as a
% 28x28x1xN array and the labels as an N-element categorical vector.
[XTrain, YTrain] = digitTrain4DArrayData;
[XTest, YTest] = digitTest4DArrayData;

%% Flatten the images and normalise the pixel range
% Collapse every image into one column (784xN) and then transpose, so that
% each ROW is one sample. Reshaping directly to [] x 784 would interleave
% pixels belonging to different images.
XTrain = double(reshape(XTrain, 28 * 28, []))'; % N x 784
XTest = double(reshape(XTest, 28 * 28, []))'; % N x 784

% Rescale only when the data is stored in the 0..255 range; data that is
% already in [0,1] must not be divided again. The decision is taken once,
% from the training set, so both sets are scaled identically.
pixelScale = 1;
if max(XTrain(:)) > 1
    pixelScale = 255;
end
XTrain = XTrain / pixelScale;
XTest = XTest / pixelScale;

%% One-hot encode the labels
% Arithmetic is not defined for categorical arrays, so convert with
% double(), which yields the category index.
% NOTE(review): assumes the categories are '0'..'9' in that order, so that
% index k corresponds to digit k-1 -- confirm with categories(YTrain).
numClasses = 10;
% ind2vec expects a row vector of indices and returns a sparse
% numClasses x N matrix; transpose to get one sample per row.
YTrain = full(ind2vec(double(YTrain(:))', numClasses))'; % N x 10
YTest = full(ind2vec(double(YTest(:))', numClasses))'; % N x 10
二、神经网络架构设计
matlab
%% Layer sizes of the fully connected network
inputSize  = 784;   % number of input units
hiddenSize = 100;   % number of hidden units
outputSize = 10;    % number of output units

%% Optimisation settings
epochs       = 50;  % passes over the training set
batchSize    = 64;  % samples per mini-batch
learningRate = 0.1; % initial learning rate

%% Parameter initialisation
% Biases start at zero; weights are Gaussian, scaled by sqrt(2 / fan_in)
% (He initialisation). W1 is drawn before W2, so the random stream is
% consumed in the same order as before.
b1 = zeros([1, hiddenSize]);
b2 = zeros([1, outputSize]);
W1 = sqrt(2 / inputSize) .* randn([inputSize, hiddenSize]);
W2 = sqrt(2 / hiddenSize) .* randn([hiddenSize, outputSize]);
三、反向传播算法实现
matlab
%% Training: mini-batch gradient descent with momentum
% Expects XTrain (N x inputSize), YTrain (N x outputSize, one-hot) and the
% parameters W1, b1, W2, b2 together with learningRate, epochs and
% batchSize to exist in the workspace.
numSamples = size(XTrain, 1);
lossHistory = zeros(epochs, 1); % mean cross-entropy per sample, per epoch

momentum = 0.9;
% Velocity buffers. They carry the running update direction from one step
% to the next and must exist before the first update.
vW1 = zeros(size(W1));
vb1 = zeros(size(b1));
vW2 = zeros(size(W2));
vb2 = zeros(size(b2));

for epoch = 1:epochs
    % Shuffle the samples; images and labels use the same permutation
    idx = randperm(numSamples);
    XTrain = XTrain(idx, :);
    YTrain = YTrain(idx, :);

    for batch = 1:batchSize:numSamples
        % Rows of the current mini-batch (the last one may be shorter)
        rows = batch:min(batch + batchSize - 1, numSamples);
        X_batch = XTrain(rows, :);
        Y_batch = YTrain(rows, :);
        m = numel(rows); % actual batch size, used for all averages

        % Forward pass
        Z1 = X_batch * W1 + b1;
        A1 = 1 ./ (1 + exp(-Z1)); % sigmoid activation
        Z2 = A1 * W2 + b2;
        % Row-wise softmax written out explicitly; subtracting the row
        % maximum keeps exp() from overflowing without changing the result.
        expZ2 = exp(Z2 - max(Z2, [], 2));
        A2 = expZ2 ./ sum(expZ2, 2);

        % Cross-entropy loss
        batchLossSum = -sum(sum(Y_batch .* log(A2 + eps)));
        loss = batchLossSum / m;
        % Accumulate as a per-sample mean over the whole epoch
        lossHistory(epoch) = lossHistory(epoch) + batchLossSum / numSamples;

        % Backward pass. Gradients have the same shape as the parameter
        % they belong to: dW2 is hiddenSize x outputSize, dW1 is
        % inputSize x hiddenSize.
        dZ2 = (A2 - Y_batch) / m;
        dW2 = A1' * dZ2;
        db2 = sum(dZ2, 1);
        dA1 = dZ2 * W2';
        dZ1 = dA1 .* A1 .* (1 - A1); % sigmoid derivative is A1 .* (1 - A1)
        dW1 = X_batch' * dZ1;
        db1 = sum(dZ1, 1);

        % Momentum update: v <- momentum * v - lr * grad, then p <- p + v
        vW1 = momentum * vW1 - learningRate * dW1;
        vb1 = momentum * vb1 - learningRate * db1;
        vW2 = momentum * vW2 - learningRate * dW2;
        vb2 = momentum * vb2 - learningRate * db2;
        W1 = W1 + vW1;
        b1 = b1 + vb1;
        W2 = W2 + vW2;
        b2 = b2 + vb2;
    end

    % Learning-rate decay after every epoch
    learningRate = learningRate * 0.95;
end

%% Plot the loss curve
plot(1:epochs, lossHistory);
xlabel('Epoch'); ylabel('Loss'); title('Training Loss');
四、模型测试与评估
matlab
%% Predict on the test set
% Expects XTest (N x inputSize), YTest (N x outputSize, one-hot) and the
% trained parameters W1, b1, W2, b2 in the workspace.
Z1_test = XTest * W1 + b1;
A1_test = 1 ./ (1 + exp(-Z1_test));
Z2_test = A1_test * W2 + b2;
% Softmax is monotonic, so the arg-max of the raw scores is already the
% predicted class; the probabilities are not needed here.
[~, Y_pred] = max(Z2_test, [], 2);
[~, Y_true] = max(YTest, [], 2);

%% Accuracy
accuracy = mean(Y_pred == Y_true);
fprintf('测试集准确率: %.2f%%\n', accuracy*100);

%% Confusion matrix
% Fix the class order so the matrix has one row/column per class even if
% some class never occurs among the predictions.
numClasses = size(W2, 2);
confMat = confusionmat(Y_true, Y_pred, 'Order', 1:numClasses);
% Open a new figure so the chart does not replace an existing plot.
figure;
% Class index k stands for digit k-1, so label the chart with the digits.
cm = confusionchart(confMat, string(0:numClasses-1));
cm.Title = 'Confusion Matrix';
五、GUI界面实现(手写数字识别)
matlab
%% Build the GUI
fig = uifigure('Name','手写数字识别', 'Position',[100 100 800 600]);
% uiaxes positions are given in pixels, not in normalised units; these
% values correspond to [0.1 0.2 0.7 0.7] of the 800x600 figure.
ax = uiaxes(fig, 'Position',[80 120 560 420]);
% Local functions cannot see the script workspace, so the axes handle and
% the trained parameters are captured here and handed to the callback.
btn = uibutton(fig, 'Text','开始识别', 'Position',[300 50 150 40], ...
    'ButtonPushedFcn', @(src,event) predict_digit(ax, W1, b1, W2, b2));

%% Image preprocessing
function img = preprocess(img)
% PREPROCESS  Turn an image into a 1x784 row vector for the network.
%   img : grayscale (MxN) or RGB (MxNx3) image
%   Returns the image resized to 28x28, binarised and flattened
%   column by column into a 1x784 double row vector of 0/1 values.
%   NOTE(review): the network is trained on whatever XTrain contains; if
%   that is bright digits on a dark background, an input with dark strokes
%   on a light background would need inverting first -- confirm.
    img = imresize(img, [28 28]);
    if size(img, 3) == 3
        % Only colour images need the grayscale conversion
        img = rgb2gray(img);
    end
    if ~islogical(img)
        img = imbinarize(img);
    end
    img = double(img(:)');
end

%% Recognition callback
function predict_digit(ax, W1, b1, W2, b2)
% PREDICT_DIGIT  Classify the image shown in AX and display the result.
%   ax             : axes that contains the image to classify
%   W1, b1, W2, b2 : trained weights and biases of the two layers
    % Fetch the image currently displayed in the axes
    img = getimage(ax);
    if isempty(img)
        uialert(ancestor(ax, 'figure'), '请先在坐标区中显示一幅图像。', '没有图像');
        return;
    end

    % Preprocess
    img_processed = preprocess(img);

    % Forward pass through the network
    Z1 = img_processed * W1 + b1;
    A1 = 1 ./ (1 + exp(-Z1));
    Z2 = A1 * W2 + b2;
    [~, pred] = max(Z2, [], 2);

    % Show the image and the result in the GUI axes. The axes is passed
    % explicitly to title() so the text is placed on it.
    imshow(img, 'Parent', ax);
    % Class index k stands for digit k-1
    title(ax, sprintf('识别结果: %d', pred-1));
end
参考代码 基于matlab实现反向传播算法训练神经网络模型,实现手写体数字的识别 www.youwenfan.com/contentcsm/79659.html
六、常见问题解决
- 收敛速度慢:采用Adam优化器替代SGD;增加批量大小至256;使用He初始化方法。
- 过拟合问题:添加Dropout层(比率0.3-0.5);采用数据增强(旋转±10°,平移±2像素);增加L2正则化强度。
- 识别准确率低:检查标签编码是否正确(数字0-9对应类别索引1-10);验证网络前向传播计算过程;调整学习率(建议初始值0.01-0.1)。