识别手写数字,居然可以只靠前端?

前言

之前一篇的神经网络文章,居然意外的受欢迎,有一万多的掘友们看过。github 的 star 数也是破了新高,非常感谢~

文章链接:当一个前端学了很久的神经网络...👈🤣

github 链接:github.com/imoo666/neu...

但是之前边调研边写代码,还是有些乱的,我重新组织了一下代码,让大家能更清晰的了解前端使用神经网络的流程。

不过只是重新讲解一下流程就太水了,这篇就再来一个识别手写数字的项目,顺便理一下我们的思路。

步骤

很多同学反馈 担心前端入坑神经网络很难,但其实就是按部就班的几步,许多步骤都是调用 api,并不需要我们全部手写,还是比较容易的。

核心步骤有下:

  1. 加载和准备数据
  2. 定义模型
  3. 训练模型
  4. 使用模型进行预测

加载和准备数据

既然是手写数字识别,我们首先还是需要一些手写数字的图片,数据集我一般是去 kaggle 找的。

下载链接:www.kaggle.com/code/cdeott...

不过这次的数据是 csv 而非图片压缩包,先下载打开看看怎么个事。

可以观察到是一个 784 * n 的一个表格,表格中的数在 0-255 之间,对图片敏感的同学应该已经反应过来了,784 === 28 * 28,也就是每一行代表了一个 28 * 28 的灰度 图片。

可以简单写一个渲染图片的方法来看一下效果:

看起来跟我们猜想的一样,另外,第一行是表头,第一列是该行的实际数字,用于做验证。

知道这些,那就可以开始加载数据了,目标是将这堆数据转化为可以供 模型训练 的数据。

js 复制代码
  const loadCsvData = async () => {
  
    // 先加载
    const response = await fetch("src/pages/mnist/assets/mnist.csv");
    const text = await response.text();
    
    // 忽略第一行的表头
    const lines = text.trim().split("\n").slice(1);
    
    // 将每一行都转化为张量
    const samples: DigitSample[] = lines.map((line) => {
      const values = line.split(",").map(Number);
      const label = values[0];
      const pixels = tf
        .tensor3d(values.slice(1), [28, 28, 1])
        .div(255) as tf.Tensor3D;
      return { pixels, label };
    });
    
    // 打乱数组
    tf.util.shuffle(samples);
    
    // 将后 50 条作为测试集,其余作为训练集
    const train = samples.slice(0, samples.length - 50);
    const test = samples.slice(-50);

    // 独热编码,一共 10 个可能
    const xTrain = tf.stack(train.map((s) => s.pixels)) as tf.Tensor4D;
    const yTrain = tf.oneHot(
      train.map((s) => s.label),
      10
    ) as tf.Tensor2D;

    return { xTrain, yTrain, test };
  };

定义模型

这次是手写数字的识别,我们需要用到图片识别比较经典的 卷积层 + 最大池化层 的组合,除此之外,这次还添加了一个防止过拟合的 dropout 层。

js 复制代码
  const defineModel = () => {
    const model = tf.sequential({
      layers: [
        // 最大池化层,用于降低图片大小
        tf.layers.maxPooling2d({
          poolSize: 2,
          strides: 2,
          inputShape: [28, 28, 1],
        }),
        // 卷积层,用 32个卷积核进行提取特征
        tf.layers.conv2d({
          filters: 32,
          kernelSize: 3,
          activation: "relu",
          padding: "same",
        }),
        // 将提取结果平铺
        tf.layers.flatten(),
        // 一个普通的隐藏层计算关系
        tf.layers.dense({ units: 64, activation: "relu" }),
        // 防止过拟合
        tf.layers.dropout({ rate: 0.3 }),
        // 分类
        tf.layers.dense({
          units: 10,
          activation: "softmax",
        }),
      ],
    });

    model.compile({
      optimizer: "adam",
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    });

    return model;
  };

训练模型

训练模型没什么需要写的,只是需要配置几个参数(如轮数,批处理数量等),然后按照固定逻辑调用 api 即可

js 复制代码
 const trainModel = async () => {
    setModelState({ model: null, isTraining: true, logs: [] });

    const model = defineModel();
    const { xTrain, yTrain, test } = await loadCsvData();

    await model.fit(xTrain, yTrain, {
      epochs: 20, // 轮数
      batchSize: 8, // 批处理数量
      validationSplit: 0.2, // 用于验证的比例
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          if (!logs) return;
          setModelState((prev) => ({
            ...prev,
            logs: [
              ...prev.logs,
              {
                epoch: epoch + 1,
                loss: Number(logs.loss?.toFixed(4)),
                accuracy: Number((logs.acc ?? logs.accuracy ?? 0).toFixed(4)),
              },
            ],
          }));
        },
      },
    });

    predict(model, test);
    setModelState((prev) => ({ ...prev, model, isTraining: false }));
    tf.dispose([xTrain, yTrain]);
  };

等待模型训练完毕后,model 就是可用的模型,可以用其去预测不同的图片,我选择了 50 张图片用于我们测试正确率。

使用模型进行预测

核心就是调用一下 model.predict() 这个方法用于预测,不过最后给出的结果会是一个十个元素的数组,分别代表是某个数字的概率,我们需要手动取出最高概率的一个作为我们的预测结果。

js 复制代码
const predict = (model: tf.Sequential, samples: DigitSample[]) => {
    const results: PredictionResult[] = samples.map((sample) => {
      const input = sample.pixels.expandDims(0); // 格式化
      const output = model.predict(input) as tf.Tensor; // 预测
      const probs = output.dataSync(); // 张量转数组
      const predicted = output.argMax(1).dataSync()[0]; // 拿到最大的位
      const confidence = Number((probs[predicted] * 100).toFixed(1));
      tf.dispose([input, output]);
      return {
        imageTensor: sample.pixels,
        actual: sample.label,
        predicted,
        confidence,
        correct: predicted === sample.label,
      };
    });
    setPredictions(results);
};

其他

最后可以看一下我们的完整页面

感兴趣的同学可以查看源码,相较于之前的版本做了许多整理工作,都按照本文的步骤进行了函数的划分:github.com/imoo666/neu...

又变强了一步!一起加油前端仔!

相关推荐
前端小巷子7 分钟前
深入 npm 模块安装机制
前端·javascript·面试
cypking1 小时前
electron中IPC 渲染进程与主进程通信方法解析
前端·javascript·electron
西陵1 小时前
Nx带来极致的前端开发体验——借助playground开发提效
前端·javascript·架构
江城开朗的豌豆2 小时前
Element UI动态组件样式修改小妙招,轻松拿捏!
前端·javascript·vue.js
float_六七2 小时前
JavaScript:现代Web开发的核心动力
开发语言·前端·javascript
zhaoyang03012 小时前
vue3笔记(2)自用
前端·javascript·笔记
德育处主任Pro3 小时前
# JsSIP 从入门到实战:构建你的第一个 Web 电话
前端
拾光拾趣录3 小时前
setTimeout(1) 和 setTimeout(2) 的区别
前端·v8
拾光拾趣录3 小时前
内存泄漏的“隐形杀手”
前端·性能优化
摸鱼仙人~3 小时前
HttpServletRequest深度解析:Java Web开发的核心组件
java·开发语言·前端