本教程将一步步带你在 Node.js 环境下,使用 TensorFlow.js(tfjs-node) 搭建并训练一个简单的卷积神经网络,实现 猫 vs 非猫 图像识别。我们会采用"代码分块 + 详细讲解"的形式,保证即使是入门者也能看懂。
1. 环境准备
代码:
js
// Step 1: 加载依赖
const tf = require('@tensorflow/tfjs-node'); // TensorFlow.js Node 版本
const fs = require('fs'); // Node.js 文件系统模块
const path = require('path'); // Node.js 路径处理模块
讲解:
@tensorflow/tfjs-node
:提供在 Node.js 环境下运行深度学习模型的能力。fs
:负责读取本地的图片文件。path
:用来安全地拼接目录和文件路径,避免跨平台出错。
2. 数据路径设置
代码:
js
// Step 2: 定义数据目录
const CAT_DIR = path.join(__dirname, 'data/cats'); // 猫图像所在文件夹
const NOT_CAT_DIR = path.join(__dirname, 'data/not_cats'); // 非猫图像所在文件夹
讲解: 我们需要两个文件夹:
data/cats
:存放猫的图片。data/not_cats
:存放非猫的图片(狗、人、风景等都行)。
模型通过比较这两类图像来学习"猫"的特征。
3. 数据加载与预处理函数
代码:
js
// Step 3: 加载单张图片并预处理
function loadAndPreprocessImage(filePath) {
const buffer = fs.readFileSync(filePath); // 读取文件为二进制
let imageTensor = tf.node.decodeImage(buffer, 3); // 解码为 3 通道(RGB)张量
imageTensor = tf.image.resizeBilinear(imageTensor, [128, 128]); // 调整大小到 128x128
imageTensor = imageTensor.div(255.0); // 归一化到 [0,1]
return imageTensor;
}
讲解:
fs.readFileSync
:把图片文件加载到内存中。tf.node.decodeImage
:把二进制数据转成张量,支持 RGB。resizeBilinear
:统一图片大小(否则模型无法批量处理)。div(255.0)
:将像素值从 0255 转换为 01,更利于训练。
4. 构建数据集
代码:
js
// Step 4: 构造数据集
function createDataset(catDir, notCatDir) {
const catFiles = fs.readdirSync(catDir).map(f => path.join(catDir, f));
const notCatFiles = fs.readdirSync(notCatDir).map(f => path.join(notCatDir, f));
const images = [];
const labels = [];
catFiles.forEach(file => {
images.push(loadAndPreprocessImage(file));
labels.push(1); // 猫标记为 1
});
notCatFiles.forEach(file => {
images.push(loadAndPreprocessImage(file));
labels.push(0); // 非猫标记为 0
});
// 打包成张量
return {
xs: tf.stack(images),
ys: tf.tensor1d(labels, 'int32')
};
}
讲解:
- 逐个读取
cats
和not_cats
的图片,转换成张量。 - 标签(
labels
)采用二分类方式:猫=1,非猫=0。 tf.stack(images)
:把单张图片张量堆叠成一个大的四维张量[数量, 高, 宽, 通道数]
。ys
是标签张量。
5. 模型结构设计
代码:
js
// Step 5: 定义 CNN 模型
function createModel() {
const model = tf.sequential();
// 第一卷积层:边缘检测
model.add(tf.layers.conv2d({
inputShape: [128, 128, 3],
filters: 32,
kernelSize: 3,
activation: 'relu'
}));
model.add(tf.layers.batchNormalization());
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
// 第二卷积层:纹理检测
model.add(tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: 'relu' }));
model.add(tf.layers.batchNormalization());
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
// 第三卷积层:高层结构检测
model.add(tf.layers.conv2d({ filters: 128, kernelSize: 3, activation: 'relu' }));
model.add(tf.layers.batchNormalization());
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
// 展平 + 全连接层
model.add(tf.layers.flatten());
model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
model.add(tf.layers.dropout({ rate: 0.5 }));
model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
return model;
}
讲解:
- Conv2D:用卷积核提取特征。
- BatchNormalization:让数值分布更稳定,防止梯度爆炸/消失。
- MaxPooling2D:缩小图片尺寸,提取主要特征,减少计算量。
- Flatten:把卷积结果拉平为向量。
- Dense:经典全连接层,结合特征进行分类。
- Dropout:随机丢弃部分神经元,防止过拟合。
- Sigmoid 输出:二分类概率(0~1,数值越大越可能是猫)。
6. 模型编译与训练
代码:
js
// Step 6: 编译与训练
async function trainModel(model, data) {
model.compile({
optimizer: tf.train.adam(),
loss: 'binaryCrossentropy',
metrics: ['accuracy']
});
const history = await model.fit(data.xs, data.ys, {
epochs: 10,
batchSize: 16,
validationSplit: 0.2,
callbacks: tf.callbacks.earlyStopping({ monitor: 'val_loss', patience: 3 })
});
console.log('训练完成!');
console.log('历史记录:', history.history);
}
讲解:
- Adam 优化器:常用的自适应学习率优化算法。
- Binary CrossEntropy:二分类专用的损失函数。
- metrics:计算准确率。
- validationSplit=0.2:20% 数据用于验证。
- EarlyStopping:验证集 3 次没有改进就停止训练,避免浪费时间。
7. 主流程
代码:
js
// Step 7: 主入口
(async () => {
const data = createDataset(CAT_DIR, NOT_CAT_DIR); // 构建数据集
const model = createModel(); // 定义模型
await trainModel(model, data); // 训练模型
await model.save('file://./saved-model'); // 保存模型
// 用一张猫图测试预测
const testImage = loadAndPreprocessImage(path.join(CAT_DIR, fs.readdirSync(CAT_DIR)[0]));
const prediction = model.predict(testImage.expandDims(0));
prediction.print();
})();
讲解:
- 加载数据。
- 创建模型。
- 开始训练并输出训练日志。
- 保存模型到
saved-model
文件夹。 - 随机取一张猫图进行预测,打印预测概率。
总结
本文通过完整的代码分块与逐步讲解,带你从零实现了 猫 vs 非猫 分类器:
- 数据准备 → 预处理 → 数据集构建
- 模型设计(3 层卷积 + 全连接)
- 模型训练与保存
- 单张图片预测
这就是一个完整的深度学习入门案例,理解了它,就能进一步扩展到更多分类任务。
⚠️ 提示:本文内容部分由人工智能(GPT)生成,仅供学习、教学与技术参考使用。内容讲解与代码示例经过整理以便理解,但请读者在实际项目中进行验证和测试。