目的
利用 TensorFlow.js 搭建并训练神经网络模型，完成手写数字识别。
设计
简单三层神经网络
- 输入层：28*28（即 784）个神经元，分别对应手写数字图片中每个像素的灰度值
- 隐藏层：100 个神经元
- 输出层：10 个神经元，分别代表数字 0–9
代码
// 导入 TensorFlow.js 库
import tf from "@tensorflow/tfjs";
import * as tfjsnode from "@tensorflow/tfjs-node";
import * as tfvis from "@tensorflow/tfjs-vis";
import fs from "fs";
import plot from "nodeplotlib";
// Build the network described in the design notes:
// 784 inputs -> 100 hidden ReLU units -> 10 softmax outputs.
const model = tf.sequential();
// Hidden layer. `inputShape: [784]` declares the input layer implicitly
// (one value per pixel of a flattened 28x28 image); a dense layer that
// carries `inputShape` is itself the first hidden layer, not the input layer.
model.add(
  tf.layers.dense({ units: 100, inputShape: [784], activation: "relu" })
);
// Output layer: one unit per digit class; softmax turns the scores into
// class probabilities.
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));
// Plain SGD with categorical cross-entropy, tracking accuracy during training.
model.compile({
  optimizer: "sgd",
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
});
// Number of MNIST samples requested for training and for testing.
const trainDataLen = 3000;
const testDataLen = 2000;

// Load the MNIST dataset.
import pkg from "mnist";
const { set: Dataset } = pkg;
const set = Dataset(trainDataLen, testDataLen);
const trainingSet = set.training;
const testSet = set.test;

// Split every sample into its pixel vector (`input`) and a class index.
// NOTE(review): `output` is assumed to be a one-hot array; the class index is
// taken as the position of its 1.
const trainXs = trainingSet.map((sample) => sample.input);
const trainLabels = trainingSet.map((sample) => sample.output.indexOf(1));
const testXs = testSet.map((sample) => sample.input);
const testLabels = testSet.map((sample) => sample.output.indexOf(1));

// Prepare the training tensors. The row count comes from the samples actually
// returned rather than the requested count, so a shorter set cannot trigger a
// shape-mismatch error.
const trainXsTensor = tf.tensor(trainXs, [trainXs.length, 784]);
// One-hot encode the class indices into 10 columns, matching the 10 outputs.
const trainYsOneHot = tf.oneHot(trainLabels, 10);
// Per-epoch training loss and accuracy, recorded so the learning curves can be
// plotted once training finishes.
const accPlot = [];
const lossPlot = [];

/**
 * Train the model, plot the per-epoch loss/accuracy curves, evaluate on the
 * test split, and save the trained model to ./my-model.
 *
 * @returns {Promise<void>} Resolves after the model is saved; rejects if
 *   training, evaluation, or saving fails.
 */
async function trainEvaluateAndSave() {
  const history = await model.fit(trainXsTensor, trainYsOneHot, {
    batchSize: 64,
    epochs: 100,
    // Reserve 20% of the training data for validation.
    validationSplit: 0.2,
    callbacks: {
      onEpochBegin: (epoch) => console.log(`Epoch ${epoch + 1} started...`),
      onEpochEnd: async (epoch, logs) => {
        console.log(
          `Epoch ${epoch + 1} completed. Loss: ${logs.loss.toFixed(
            3
          )}, Accuracy: ${logs.acc.toFixed(3)}`
        );
        // Record numbers (not the strings toFixed() returns) so the plot
        // receives numeric y values.
        accPlot.push(logs.acc);
        lossPlot.push(logs.loss);
        await tf.nextFrame(); // yield to the event loop so training does not block it
      },
      onBatchEnd: async (batch, logs) => {
        console.log(
          `Batch ${batch} completed. Loss: ${logs.loss.toFixed(
            3
          )}, Accuracy: ${logs.acc.toFixed(3)}`
        );
        await tf.nextFrame(); // yield to the event loop so training does not block it
      },
    },
  });
  console.log("Training completed!", history);

  // Plot loss and accuracy against the 1-based epoch number.
  const epochs = Array.from({ length: lossPlot.length }, (_, i) => i + 1);
  plot.plot(
    [
      { x: epochs, y: lossPlot, name: "Loss" },
      { x: epochs, y: accPlot, name: "Accuracy" },
    ],
    // NOTE(review): `filename` does not appear to be a Plotly layout property,
    // so this probably does not write loss_acc.png to disk — confirm against
    // the nodeplotlib documentation.
    {
      filename: "loss_acc.png",
    }
  );

  // Evaluate on the test split; evaluate() yields [loss, accuracy] here.
  const testXsTensor = tf.tensor(testXs, [testXs.length, 784]);
  const testYsOneHot = tf.oneHot(testLabels, 10);
  const [lossScalar, accScalar] = model.evaluate(testXsTensor, testYsOneHot);
  const testLoss = lossScalar.dataSync()[0];
  const testAccuracy = accScalar.dataSync()[0];
  // Release the tensors created for evaluation.
  tf.dispose([testXsTensor, testYsOneHot, lossScalar, accScalar]);
  console.log(`Test loss: ${testLoss.toFixed(3)}`);
  console.log(`Test accuracy: ${testAccuracy.toFixed(3)}`);

  // Save the model. NOTE(review): the file:// scheme relies on the IO handler
  // provided by @tensorflow/tfjs-node, which this file imports at the top.
  await model.save("file://./my-model");
  console.log("Model saved!");
}

// Run the pipeline, reporting failures instead of leaving the rejection unhandled.
trainEvaluateAndSave().catch((err) => {
  console.error("Training failed:", err);
  process.exitCode = 1;
});
package.json
{
"name": "neural_network",
"version": "1.0.0",
"description": "",
"type": "module",
"main": "mlpTest.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"author": "",
"license": "ISC",
"dependencies": {
"@tensorflow/tfjs": "^4.17.0",
"@tensorflow/tfjs-node": "^4.17.0",
"@tensorflow/tfjs-vis": "^1.0.0",
"mnist": "^1.1.0",
"nodeplotlib": "^0.7.7"
},
"devDependencies": {
"@babel/core": "^7.0.0",
"@babel/preset-env": "^7.0.0",
"babel-loader": "^8.0.0",
"webpack": "^5.0.0",
"webpack-cli": "^4.0.0"
}
}