识别手写数字,居然可以只靠前端?

前言

之前一篇的神经网络文章,居然意外的受欢迎,有一万多的掘友们看过。github 的 star 数也是破了新高,非常感谢~

文章链接:当一个前端学了很久的神经网络...👈🤣

github 链接:github.com/imoo666/neu...

但是之前边调研边写代码,还是有些乱的,我重新组织了一下代码,让大家能更清晰的了解前端使用神经网络的流程。

不过只是重新讲解一下流程就太水了,这篇就再来一个识别手写数字的项目,顺便理一下我们的思路。

步骤

很多同学反馈 担心前端入坑神经网络很难,但其实就是按部就班的几步,许多步骤都是调用 api,并不需要我们全部手写,还是比较容易的。

核心步骤有下:

  1. 加载和准备数据
  2. 定义模型
  3. 训练模型
  4. 使用模型进行预测

加载和准备数据

既然是手写数字识别,我们首先还是需要一些手写数字的图片,数据集我一般是去 kaggle 找的。

下载链接:www.kaggle.com/code/cdeott...

不过这次的数据是 csv 而非图片压缩包,先下载打开看看怎么个事。

可以观察到是一个 784 * n 的一个表格,表格中的数在 0-255 之间,对图片敏感的同学应该已经反应过来了,784 === 28 * 28,也就是每一行代表了一个 28 * 28 的灰度 图片。

可以简单写一个渲染图片的方法来看一下效果:

看起来跟我们猜想的一样,另外,第一行是表头,第一列是该行的实际数字,用于做验证。

知道这些,那就可以开始加载数据了,目标是将这堆数据转化为可以供 模型训练 的数据。

js 复制代码
  const loadCsvData = async () => {
  
    // 先加载
    const response = await fetch("src/pages/mnist/assets/mnist.csv");
    const text = await response.text();
    
    // 忽略第一行的表头
    const lines = text.trim().split("\n").slice(1);
    
    // 将每一行都转化为张量
    const samples: DigitSample[] = lines.map((line) => {
      const values = line.split(",").map(Number);
      const label = values[0];
      const pixels = tf
        .tensor3d(values.slice(1), [28, 28, 1])
        .div(255) as tf.Tensor3D;
      return { pixels, label };
    });
    
    // 打乱数组
    tf.util.shuffle(samples);
    
    // 将后 50 条作为测试集,其余作为训练集
    const train = samples.slice(0, samples.length - 50);
    const test = samples.slice(-50);

    // 独热编码,一共 10 个可能
    const xTrain = tf.stack(train.map((s) => s.pixels)) as tf.Tensor4D;
    const yTrain = tf.oneHot(
      train.map((s) => s.label),
      10
    ) as tf.Tensor2D;

    return { xTrain, yTrain, test };
  };

定义模型

这次是手写数字的识别,我们需要用到图片识别比较经典的 卷积层 + 最大池化层 的组合,除此之外,这次还添加了一个防止过拟合的 dropout 层。

js 复制代码
  const defineModel = () => {
    const model = tf.sequential({
      layers: [
        // 最大池化层,用于降低图片大小
        tf.layers.maxPooling2d({
          poolSize: 2,
          strides: 2,
          inputShape: [28, 28, 1],
        }),
        // 卷积层,用 32个卷积核进行提取特征
        tf.layers.conv2d({
          filters: 32,
          kernelSize: 3,
          activation: "relu",
          padding: "same",
        }),
        // 将提取结果平铺
        tf.layers.flatten(),
        // 一个普通的隐藏层计算关系
        tf.layers.dense({ units: 64, activation: "relu" }),
        // 防止过拟合
        tf.layers.dropout({ rate: 0.3 }),
        // 分类
        tf.layers.dense({
          units: 10,
          activation: "softmax",
        }),
      ],
    });

    model.compile({
      optimizer: "adam",
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    });

    return model;
  };

训练模型

训练模型没什么需要写的,只是需要配置几个参数(如轮数,批处理数量等),然后按照固定逻辑调用 api 即可

js 复制代码
 const trainModel = async () => {
    setModelState({ model: null, isTraining: true, logs: [] });

    const model = defineModel();
    const { xTrain, yTrain, test } = await loadCsvData();

    await model.fit(xTrain, yTrain, {
      epochs: 20, // 轮数
      batchSize: 8, // 批处理数量
      validationSplit: 0.2, // 用于验证的比例
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          if (!logs) return;
          setModelState((prev) => ({
            ...prev,
            logs: [
              ...prev.logs,
              {
                epoch: epoch + 1,
                loss: Number(logs.loss?.toFixed(4)),
                accuracy: Number((logs.acc ?? logs.accuracy ?? 0).toFixed(4)),
              },
            ],
          }));
        },
      },
    });

    predict(model, test);
    setModelState((prev) => ({ ...prev, model, isTraining: false }));
    tf.dispose([xTrain, yTrain]);
  };

等待模型训练完毕后,model 就是可用的模型,可以用其去预测不同的图片,我选择了 50 张图片用于我们测试正确率。

使用模型进行预测

核心就是调用一下 model.predict() 这个方法用于预测,不过最后给出的结果会是一个十个元素的数组,分别代表是某个数字的概率,我们需要手动取出最高概率的一个作为我们的预测结果。

js 复制代码
const predict = (model: tf.Sequential, samples: DigitSample[]) => {
    const results: PredictionResult[] = samples.map((sample) => {
      const input = sample.pixels.expandDims(0); // 格式化
      const output = model.predict(input) as tf.Tensor; // 预测
      const probs = output.dataSync(); // 张量转数组
      const predicted = output.argMax(1).dataSync()[0]; // 拿到最大的位
      const confidence = Number((probs[predicted] * 100).toFixed(1));
      tf.dispose([input, output]);
      return {
        imageTensor: sample.pixels,
        actual: sample.label,
        predicted,
        confidence,
        correct: predicted === sample.label,
      };
    });
    setPredictions(results);
};

其他

最后可以看一下我们的完整页面

感兴趣的同学可以查看源码,相较于之前的版本做了许多整理工作,都按照本文的步骤进行了函数的划分:github.com/imoo666/neu...

又变强了一步!一起加油前端仔!

相关推荐
耶啵奶膘2 小时前
uniapp+firstUI——上传视频组件fui-upload-video
前端·javascript·uni-app
视频砖家2 小时前
移动端Html5播放器按钮变小的问题解决方法
前端·javascript·viewport功能
lyj1689973 小时前
vue-i18n+vscode+vue 多语言使用
前端·vue.js·vscode
小白变怪兽4 小时前
一、react18+项目初始化(vite)
前端·react.js
ai小鬼头4 小时前
AIStarter如何快速部署Stable Diffusion?**新手也能轻松上手的AI绘图
前端·后端·github
墨菲安全5 小时前
NPM组件 betsson 等窃取主机敏感信息
前端·npm·node.js·软件供应链安全·主机信息窃取·npm组件投毒
GISer_Jing5 小时前
Monorepo+Pnpm+Turborepo
前端·javascript·ecmascript
天涯学馆5 小时前
前端开发也能用 WebAssembly?这些场景超实用!
前端·javascript·面试
我在北京coding6 小时前
TypeError: Cannot read properties of undefined (reading ‘queryComponents‘)
前端·javascript·vue.js
前端开发与ui设计的老司机7 小时前
UI前端与数字孪生结合实践探索:智慧物流的货物追踪与配送优化
前端·ui