识别手写数字,居然可以只靠前端?

前言

之前一篇的神经网络文章,居然意外的受欢迎,有一万多的掘友们看过。github 的 star 数也是破了新高,非常感谢~

文章链接:当一个前端学了很久的神经网络...👈🤣

github 链接:github.com/imoo666/neu...

但是之前边调研边写代码,还是有些乱的,我重新组织了一下代码,让大家能更清晰的了解前端使用神经网络的流程。

不过只是重新讲解一下流程就太水了,这篇就再来一个识别手写数字的项目,顺便理一下我们的思路。

步骤

很多同学反馈 担心前端入坑神经网络很难,但其实就是按部就班的几步,许多步骤都是调用 api,并不需要我们全部手写,还是比较容易的。

核心步骤有下:

  1. 加载和准备数据
  2. 定义模型
  3. 训练模型
  4. 使用模型进行预测

加载和准备数据

既然是手写数字识别,我们首先还是需要一些手写数字的图片,数据集我一般是去 kaggle 找的。

下载链接:www.kaggle.com/code/cdeott...

不过这次的数据是 csv 而非图片压缩包,先下载打开看看怎么个事。

可以观察到是一个 784 * n 的一个表格,表格中的数在 0-255 之间,对图片敏感的同学应该已经反应过来了,784 === 28 * 28,也就是每一行代表了一个 28 * 28 的灰度 图片。

可以简单写一个渲染图片的方法来看一下效果:

看起来跟我们猜想的一样,另外,第一行是表头,第一列是该行的实际数字,用于做验证。

知道这些,那就可以开始加载数据了,目标是将这堆数据转化为可以供 模型训练 的数据。

js 复制代码
  const loadCsvData = async () => {
  
    // 先加载
    const response = await fetch("src/pages/mnist/assets/mnist.csv");
    const text = await response.text();
    
    // 忽略第一行的表头
    const lines = text.trim().split("\n").slice(1);
    
    // 将每一行都转化为张量
    const samples: DigitSample[] = lines.map((line) => {
      const values = line.split(",").map(Number);
      const label = values[0];
      const pixels = tf
        .tensor3d(values.slice(1), [28, 28, 1])
        .div(255) as tf.Tensor3D;
      return { pixels, label };
    });
    
    // 打乱数组
    tf.util.shuffle(samples);
    
    // 将后 50 条作为测试集,其余作为训练集
    const train = samples.slice(0, samples.length - 50);
    const test = samples.slice(-50);

    // 独热编码,一共 10 个可能
    const xTrain = tf.stack(train.map((s) => s.pixels)) as tf.Tensor4D;
    const yTrain = tf.oneHot(
      train.map((s) => s.label),
      10
    ) as tf.Tensor2D;

    return { xTrain, yTrain, test };
  };

定义模型

这次是手写数字的识别,我们需要用到图片识别比较经典的 卷积层 + 最大池化层 的组合,除此之外,这次还添加了一个防止过拟合的 dropout 层。

js 复制代码
  const defineModel = () => {
    const model = tf.sequential({
      layers: [
        // 最大池化层,用于降低图片大小
        tf.layers.maxPooling2d({
          poolSize: 2,
          strides: 2,
          inputShape: [28, 28, 1],
        }),
        // 卷积层,用 32个卷积核进行提取特征
        tf.layers.conv2d({
          filters: 32,
          kernelSize: 3,
          activation: "relu",
          padding: "same",
        }),
        // 将提取结果平铺
        tf.layers.flatten(),
        // 一个普通的隐藏层计算关系
        tf.layers.dense({ units: 64, activation: "relu" }),
        // 防止过拟合
        tf.layers.dropout({ rate: 0.3 }),
        // 分类
        tf.layers.dense({
          units: 10,
          activation: "softmax",
        }),
      ],
    });

    model.compile({
      optimizer: "adam",
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    });

    return model;
  };

训练模型

训练模型没什么需要写的,只是需要配置几个参数(如轮数,批处理数量等),然后按照固定逻辑调用 api 即可

js 复制代码
 const trainModel = async () => {
    setModelState({ model: null, isTraining: true, logs: [] });

    const model = defineModel();
    const { xTrain, yTrain, test } = await loadCsvData();

    await model.fit(xTrain, yTrain, {
      epochs: 20, // 轮数
      batchSize: 8, // 批处理数量
      validationSplit: 0.2, // 用于验证的比例
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          if (!logs) return;
          setModelState((prev) => ({
            ...prev,
            logs: [
              ...prev.logs,
              {
                epoch: epoch + 1,
                loss: Number(logs.loss?.toFixed(4)),
                accuracy: Number((logs.acc ?? logs.accuracy ?? 0).toFixed(4)),
              },
            ],
          }));
        },
      },
    });

    predict(model, test);
    setModelState((prev) => ({ ...prev, model, isTraining: false }));
    tf.dispose([xTrain, yTrain]);
  };

等待模型训练完毕后,model 就是可用的模型,可以用其去预测不同的图片,我选择了 50 张图片用于我们测试正确率。

使用模型进行预测

核心就是调用一下 model.predict() 这个方法用于预测,不过最后给出的结果会是一个十个元素的数组,分别代表是某个数字的概率,我们需要手动取出最高概率的一个作为我们的预测结果。

js 复制代码
const predict = (model: tf.Sequential, samples: DigitSample[]) => {
    const results: PredictionResult[] = samples.map((sample) => {
      const input = sample.pixels.expandDims(0); // 格式化
      const output = model.predict(input) as tf.Tensor; // 预测
      const probs = output.dataSync(); // 张量转数组
      const predicted = output.argMax(1).dataSync()[0]; // 拿到最大的位
      const confidence = Number((probs[predicted] * 100).toFixed(1));
      tf.dispose([input, output]);
      return {
        imageTensor: sample.pixels,
        actual: sample.label,
        predicted,
        confidence,
        correct: predicted === sample.label,
      };
    });
    setPredictions(results);
};

其他

最后可以看一下我们的完整页面

感兴趣的同学可以查看源码,相较于之前的版本做了许多整理工作,都按照本文的步骤进行了函数的划分:github.com/imoo666/neu...

又变强了一步!一起加油前端仔!

相关推荐
BillKu2 分钟前
Vue3 + Vite 中使用 Lodash-es 的防抖 debounce 详解
前端·javascript·vue.js
一只小风华~10 分钟前
HTML前端开发:JavaScript的条分支语句if,Switch
前端·javascript·html5
橙子家10 分钟前
Select 组件实现【全选】(基于 Element)
前端
超级土豆粉11 分钟前
HTML 语义化
前端·html
bingbingyihao18 分钟前
UI框架-通知组件
前端·javascript·vue
wordbaby20 分钟前
React Router 预渲染的工作原理和价值(Pre-rendering)
前端·react.js
依旧天真无邪1 小时前
Chrome 优质插件计划
前端·chrome
逝缘~1 小时前
小白学Pinia状态管理
前端·javascript·vue.js·vscode·es6·pinia
光影少年1 小时前
vite原理
前端·javascript·vue.js
C MIKE1 小时前
ztree.js前端插件样式文字大小文字背景修改
开发语言·前端·javascript