深度学习神经网络实战:多层感知机,手写数字识别

目的

利用tensorflow.js训练模型,搭建神经网络模型,完成手写数字识别

设计

简单三层神经网络

  • 输入层
    28*28个神经原,代表每一张手写数字图片的灰度
  • 隐藏层
    100个神经原
  • 输出层
    -10个神经原,分别代表10个数字

代码

复制代码
// 导入 TensorFlow.js 库
import tf from "@tensorflow/tfjs";
import * as tfjsnode from "@tensorflow/tfjs-node";
import * as tfvis from "@tensorflow/tfjs-vis";
import fs from "fs";
import plot from "nodeplotlib";
// 定义模型
const model = tf.sequential();

// 添加输入层
model.add(
  tf.layers.dense({ units: 64, inputShape: [784], activation: "relu" })
);

// 添加隐藏层
model.add(tf.layers.dense({ units: 100, activation: "relu" }));

// 添加输出层
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));

// 编译模型
model.compile({
  optimizer: "sgd",
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
});
const trainDataLen = 3000;
const testDataLen = 2000;

// 加载 MNIST 数据集
import pkg from "mnist";
const { set: Dataset } = pkg;
const set = Dataset(trainDataLen, testDataLen);
const trainingSet = set.training;
const testSet = set.test;

const trainXs = [];
const testXs = [];

const trainLabels = [];
const testLabels = [];

for (let i = 0; i < trainingSet.length; i++) {
  trainXs.push(trainingSet[i].input);
  trainLabels.push(trainingSet[i].output.indexOf(1));
}

for (let i = 0; i < testSet.length; i++) {
  testXs.push(testSet[i].input);
  testLabels.push(testSet[i].output.indexOf(1));
}

// 准备数据
const trainXsTensor = tf.tensor(trainXs, [trainDataLen, 784]);
const trainYsOneHot = tf.oneHot(trainLabels, 10);

//记录每轮模型训练中的损失和精度,为了绘制曲线图
var accPlot = [];
var lossPlot = [];

// 模型训练
model
  .fit(trainXsTensor, trainYsOneHot, {
    batchSize: 64,
    epochs: 100,
    validationSplit: 0.2,
    callbacks: {
      onEpochBegin: (epoch) => console.log(`Epoch ${epoch + 1} started...`),
      onEpochEnd: async (epoch, logs) => {
        console.log(
          `Epoch ${epoch + 1} completed. Loss: ${logs.loss.toFixed(
            3
          )}, Accuracy: ${logs.acc.toFixed(3)}`
        );
        //记录loss和acc,绘制曲线图
        accPlot.push(logs.acc.toFixed(3));
        lossPlot.push(logs.loss.toFixed(3));

        await tf.nextFrame(); // 防止阻塞
      },
      onBatchEnd: async (batch, logs) => {
        console.log(
          `Batch ${batch} completed. Loss: ${logs.loss.toFixed(
            3
          )}, Accuracy: ${logs.acc.toFixed(3)}`
        );
        await tf.nextFrame(); // 防止阻塞
      },
    },
  })
  .then((history) => {
    console.log("Training completed!", history);
    //绘制模型训练过程中的损失函数和模型精度曲线变化
    const epochs = Array.from({ length: lossPlot.length }, (_, i) => i + 1);
    plot.plot(
      [
        { x: epochs, y: lossPlot, name: "Loss" },
        { x: epochs, y: accPlot, name: "Accuracy" },
      ],
      {
        filename: "loss_acc.png",
      }
    );

    //模型评估
    const testXsTensor = tf.tensor(testXs, [testDataLen, 784]);
    const testYsOneHot = tf.oneHot(testLabels, 10);

    const result = model.evaluate(testXsTensor, testYsOneHot);
    const testLoss = result[0].dataSync()[0];
    const testAccuracy = result[1].dataSync()[0];

    console.log(`Test loss: ${testLoss.toFixed(3)}`);
    console.log(`Test accuracy: ${testAccuracy.toFixed(3)}`);
    //保存模型
    model.save("file://./my-model").then(() => {
      console.log("Model saved!");
    });
  });

package.json

复制代码
{
  "name": "neural_network",
  "version": "1.0.0",
  "description": "",
  "type": "module",
  "main": "mlpTest.js",
  "scripts": {
    "test": "echo \"Error: no test specified\" && exit 1",
  },
  "author": "",
  "license": "ISC",
  "dependencies": {
    "@tensorflow/tfjs": "^4.17.0",
    "@tensorflow/tfjs-node": "^4.17.0",
    "@tensorflow/tfjs-vis": "^1.0.0",
    "mnist": "^1.1.0",
    "nodeplotlib": "^0.7.7"
  },
  "devDependencies": {
    "@babel/core": "^7.0.0",
    "@babel/preset-env": "^7.0.0",
    "babel-loader": "^8.0.0",
    "webpack": "^5.0.0",
    "webpack-cli": "^4.0.0"
  }
}

模型结果

损失函数与模型精度变化

相关推荐
AiTEN_Robotics36 分钟前
AMR机器人:如何满足现代物料搬运的需求
人工智能·机器人·自动化
产品人卫朋39 分钟前
卫朋:IPD流程落地 - 市场地图拆解篇
大数据·人工智能·物联网
沛沛老爹1 小时前
跨平台Agent Skills开发:适配器模式赋能提示词优化与多AI应用无缝集成
人工智能·agent·适配器模式·rag·企业转型·skills
zhangshuang-peta1 小时前
适用于MCP的Nginx类代理:为何AI工具集成需要网关层
人工智能·ai agent·mcp·peta
Network_Engineer1 小时前
从零手写RNN&BiRNN:从原理到双向实现
人工智能·rnn·深度学习·神经网络
机器学习之心1 小时前
Bayes-TCN+SHAP分析贝叶斯优化深度学习多变量分类预测可解释性分析!Matlab完整代码
深度学习·matlab·分类·贝叶斯优化深度学习
想进部的张同学1 小时前
week1-day5-CNN卷积补充感受野-CUDA 一、CUDA 编程模型基础 1.1 CPU vs GPU 架构线程索引与向量乘法
人工智能·神经网络·cnn
WGS.1 小时前
fastenhancer DPRNN torch 实现
pytorch·深度学习
机器学习之心1 小时前
TCN+SHAP分析深度学习多变量分类预测可解释性分析!Matlab完整代码
深度学习·matlab·分类·多变量分类预测可解释性分析
睡醒了叭1 小时前
目标检测-深度学习-SSD模型项目
人工智能·深度学习·目标检测