用 TensorFlow.js Node 实现猫图像识别(教学版逐步分解)

本教程将一步步带你在 Node.js 环境下,使用 TensorFlow.js(tfjs-node) 搭建并训练一个简单的卷积神经网络,实现 猫 vs 非猫 图像识别。我们会采用"代码分块 + 详细讲解"的形式,保证即使是入门者也能看懂。


1. 环境准备

代码:

js 复制代码
// Step 1: 加载依赖
const tf = require('@tensorflow/tfjs-node'); // TensorFlow.js Node 版本
const fs = require('fs');                   // Node.js 文件系统模块
const path = require('path');               // Node.js 路径处理模块

讲解:

  • @tensorflow/tfjs-node:提供在 Node.js 环境下运行深度学习模型的能力。
  • fs:负责读取本地的图片文件。
  • path:用来安全地拼接目录和文件路径,避免跨平台出错。

2. 数据路径设置

代码:

js 复制代码
// Step 2: 定义数据目录
const CAT_DIR = path.join(__dirname, 'data/cats');      // 猫图像所在文件夹
const NOT_CAT_DIR = path.join(__dirname, 'data/not_cats'); // 非猫图像所在文件夹

讲解: 我们需要两个文件夹:

  • data/cats:存放猫的图片。
  • data/not_cats:存放非猫的图片(狗、人、风景等都行)。

模型通过比较这两类图像来学习"猫"的特征。


3. 数据加载与预处理函数

代码:

js 复制代码
// Step 3: 加载单张图片并预处理
function loadAndPreprocessImage(filePath) {
  const buffer = fs.readFileSync(filePath);            // 读取文件为二进制
  let imageTensor = tf.node.decodeImage(buffer, 3);    // 解码为 3 通道(RGB)张量
  imageTensor = tf.image.resizeBilinear(imageTensor, [128, 128]); // 调整大小到 128x128
  imageTensor = imageTensor.div(255.0);                // 归一化到 [0,1]
  return imageTensor;
}

讲解:

  1. fs.readFileSync:把图片文件加载到内存中。
  2. tf.node.decodeImage:把二进制数据转成张量,支持 RGB。
  3. resizeBilinear:统一图片大小(否则模型无法批量处理)。
  4. div(255.0):将像素值从 0255 转换为 01,更利于训练。

4. 构建数据集

代码:

js 复制代码
// Step 4: 构造数据集
function createDataset(catDir, notCatDir) {
  const catFiles = fs.readdirSync(catDir).map(f => path.join(catDir, f));
  const notCatFiles = fs.readdirSync(notCatDir).map(f => path.join(notCatDir, f));

  const images = [];
  const labels = [];

  catFiles.forEach(file => {
    images.push(loadAndPreprocessImage(file));
    labels.push(1); // 猫标记为 1
  });

  notCatFiles.forEach(file => {
    images.push(loadAndPreprocessImage(file));
    labels.push(0); // 非猫标记为 0
  });

  // 打包成张量
  return {
    xs: tf.stack(images),
    ys: tf.tensor1d(labels, 'int32')
  };
}

讲解:

  • 逐个读取 catsnot_cats 的图片,转换成张量。
  • 标签(labels)采用二分类方式:猫=1,非猫=0。
  • tf.stack(images):把单张图片张量堆叠成一个大的四维张量 [数量, 高, 宽, 通道数]
  • ys 是标签张量。

5. 模型结构设计

代码:

js 复制代码
// Step 5: 定义 CNN 模型
function createModel() {
  const model = tf.sequential();

  // 第一卷积层:边缘检测
  model.add(tf.layers.conv2d({
    inputShape: [128, 128, 3],
    filters: 32,
    kernelSize: 3,
    activation: 'relu'
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

  // 第二卷积层:纹理检测
  model.add(tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: 'relu' }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

  // 第三卷积层:高层结构检测
  model.add(tf.layers.conv2d({ filters: 128, kernelSize: 3, activation: 'relu' }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

  // 展平 + 全连接层
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));

  return model;
}

讲解:

  • Conv2D:用卷积核提取特征。
  • BatchNormalization:让数值分布更稳定,防止梯度爆炸/消失。
  • MaxPooling2D:缩小图片尺寸,提取主要特征,减少计算量。
  • Flatten:把卷积结果拉平为向量。
  • Dense:经典全连接层,结合特征进行分类。
  • Dropout:随机丢弃部分神经元,防止过拟合。
  • Sigmoid 输出:二分类概率(0~1,数值越大越可能是猫)。

6. 模型编译与训练

代码:

js 复制代码
// Step 6: 编译与训练
async function trainModel(model, data) {
  model.compile({
    optimizer: tf.train.adam(),
    loss: 'binaryCrossentropy',
    metrics: ['accuracy']
  });

  const history = await model.fit(data.xs, data.ys, {
    epochs: 10,
    batchSize: 16,
    validationSplit: 0.2,
    callbacks: tf.callbacks.earlyStopping({ monitor: 'val_loss', patience: 3 })
  });

  console.log('训练完成!');
  console.log('历史记录:', history.history);
}

讲解:

  • Adam 优化器:常用的自适应学习率优化算法。
  • Binary CrossEntropy:二分类专用的损失函数。
  • metrics:计算准确率。
  • validationSplit=0.2:20% 数据用于验证。
  • EarlyStopping:验证集 3 次没有改进就停止训练,避免浪费时间。

7. 主流程

代码:

js 复制代码
// Step 7: 主入口
(async () => {
  const data = createDataset(CAT_DIR, NOT_CAT_DIR); // 构建数据集
  const model = createModel();                      // 定义模型
  await trainModel(model, data);                    // 训练模型
  await model.save('file://./saved-model');         // 保存模型

  // 用一张猫图测试预测
  const testImage = loadAndPreprocessImage(path.join(CAT_DIR, fs.readdirSync(CAT_DIR)[0]));
  const prediction = model.predict(testImage.expandDims(0));
  prediction.print();
})();

讲解:

  1. 加载数据。
  2. 创建模型。
  3. 开始训练并输出训练日志。
  4. 保存模型到 saved-model 文件夹。
  5. 随机取一张猫图进行预测,打印预测概率。

总结

本文通过完整的代码分块与逐步讲解,带你从零实现了 猫 vs 非猫 分类器:

  • 数据准备 → 预处理 → 数据集构建
  • 模型设计(3 层卷积 + 全连接)
  • 模型训练与保存
  • 单张图片预测

这就是一个完整的深度学习入门案例,理解了它,就能进一步扩展到更多分类任务。

⚠️ 提示:本文内容部分由人工智能(GPT)生成,仅供学习、教学与技术参考使用。内容讲解与代码示例经过整理以便理解,但请读者在实际项目中进行验证和测试。

相关推荐
gnip6 小时前
JavaScript事件流
前端·javascript
赵得C6 小时前
【前端技巧】Element Table 列标题如何优雅添加 Tooltip 提示?
前端·elementui·vue·table组件
wow_DG6 小时前
【Vue2 ✨】Vue2 入门之旅 · 进阶篇(一):响应式原理
前端·javascript·vue.js
weixin_456904276 小时前
UserManagement.vue和Profile.vue详细解释
前端·javascript·vue.js
资深前端之路6 小时前
react 面试题 react 有什么特点?
前端·react.js·面试·前端框架
aaaweiaaaaaa6 小时前
HTML和CSS学习
前端·css·学习·html
秋秋小事6 小时前
React Hooks useContext
前端·javascript·react.js
Jinuss6 小时前
Vue3源码reactivity响应式篇之reactive响应式对象的track与trigger
前端·vue3
striver_#6 小时前
百度前端社招面经二
前端