用 TensorFlow.js Node 实现猫图像识别(教学版逐步分解)

本教程将一步步带你在 Node.js 环境下,使用 TensorFlow.js(tfjs-node) 搭建并训练一个简单的卷积神经网络,实现 猫 vs 非猫 图像识别。我们会采用"代码分块 + 详细讲解"的形式,保证即使是入门者也能看懂。


1. 环境准备

代码:

js 复制代码
// Step 1: 加载依赖
const tf = require('@tensorflow/tfjs-node'); // TensorFlow.js Node 版本
const fs = require('fs');                   // Node.js 文件系统模块
const path = require('path');               // Node.js 路径处理模块

讲解:

  • @tensorflow/tfjs-node:提供在 Node.js 环境下运行深度学习模型的能力。
  • fs:负责读取本地的图片文件。
  • path:用来安全地拼接目录和文件路径,避免跨平台出错。

2. 数据路径设置

代码:

js 复制代码
// Step 2: 定义数据目录
const CAT_DIR = path.join(__dirname, 'data/cats');      // 猫图像所在文件夹
const NOT_CAT_DIR = path.join(__dirname, 'data/not_cats'); // 非猫图像所在文件夹

讲解: 我们需要两个文件夹:

  • data/cats:存放猫的图片。
  • data/not_cats:存放非猫的图片(狗、人、风景等都行)。

模型通过比较这两类图像来学习"猫"的特征。


3. 数据加载与预处理函数

代码:

js 复制代码
// Step 3: 加载单张图片并预处理
function loadAndPreprocessImage(filePath) {
  const buffer = fs.readFileSync(filePath);            // 读取文件为二进制
  let imageTensor = tf.node.decodeImage(buffer, 3);    // 解码为 3 通道(RGB)张量
  imageTensor = tf.image.resizeBilinear(imageTensor, [128, 128]); // 调整大小到 128x128
  imageTensor = imageTensor.div(255.0);                // 归一化到 [0,1]
  return imageTensor;
}

讲解:

  1. fs.readFileSync:把图片文件加载到内存中。
  2. tf.node.decodeImage:把二进制数据转成张量,支持 RGB。
  3. resizeBilinear:统一图片大小(否则模型无法批量处理)。
  4. div(255.0):将像素值从 0255 转换为 01,更利于训练。

4. 构建数据集

代码:

js 复制代码
// Step 4: 构造数据集
function createDataset(catDir, notCatDir) {
  const catFiles = fs.readdirSync(catDir).map(f => path.join(catDir, f));
  const notCatFiles = fs.readdirSync(notCatDir).map(f => path.join(notCatDir, f));

  const images = [];
  const labels = [];

  catFiles.forEach(file => {
    images.push(loadAndPreprocessImage(file));
    labels.push(1); // 猫标记为 1
  });

  notCatFiles.forEach(file => {
    images.push(loadAndPreprocessImage(file));
    labels.push(0); // 非猫标记为 0
  });

  // 打包成张量
  return {
    xs: tf.stack(images),
    ys: tf.tensor1d(labels, 'int32')
  };
}

讲解:

  • 逐个读取 catsnot_cats 的图片,转换成张量。
  • 标签(labels)采用二分类方式:猫=1,非猫=0。
  • tf.stack(images):把单张图片张量堆叠成一个大的四维张量 [数量, 高, 宽, 通道数]
  • ys 是标签张量。

5. 模型结构设计

代码:

js 复制代码
// Step 5: 定义 CNN 模型
function createModel() {
  const model = tf.sequential();

  // 第一卷积层:边缘检测
  model.add(tf.layers.conv2d({
    inputShape: [128, 128, 3],
    filters: 32,
    kernelSize: 3,
    activation: 'relu'
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

  // 第二卷积层:纹理检测
  model.add(tf.layers.conv2d({ filters: 64, kernelSize: 3, activation: 'relu' }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

  // 第三卷积层:高层结构检测
  model.add(tf.layers.conv2d({ filters: 128, kernelSize: 3, activation: 'relu' }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));

  // 展平 + 全连接层
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({ units: 128, activation: 'relu' }));
  model.add(tf.layers.dropout({ rate: 0.5 }));
  model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));

  return model;
}

讲解:

  • Conv2D:用卷积核提取特征。
  • BatchNormalization:让数值分布更稳定,防止梯度爆炸/消失。
  • MaxPooling2D:缩小图片尺寸,提取主要特征,减少计算量。
  • Flatten:把卷积结果拉平为向量。
  • Dense:经典全连接层,结合特征进行分类。
  • Dropout:随机丢弃部分神经元,防止过拟合。
  • Sigmoid 输出:二分类概率(0~1,数值越大越可能是猫)。

6. 模型编译与训练

代码:

js 复制代码
// Step 6: 编译与训练
async function trainModel(model, data) {
  model.compile({
    optimizer: tf.train.adam(),
    loss: 'binaryCrossentropy',
    metrics: ['accuracy']
  });

  const history = await model.fit(data.xs, data.ys, {
    epochs: 10,
    batchSize: 16,
    validationSplit: 0.2,
    callbacks: tf.callbacks.earlyStopping({ monitor: 'val_loss', patience: 3 })
  });

  console.log('训练完成!');
  console.log('历史记录:', history.history);
}

讲解:

  • Adam 优化器:常用的自适应学习率优化算法。
  • Binary CrossEntropy:二分类专用的损失函数。
  • metrics:计算准确率。
  • validationSplit=0.2:20% 数据用于验证。
  • EarlyStopping:验证集 3 次没有改进就停止训练,避免浪费时间。

7. 主流程

代码:

js 复制代码
// Step 7: 主入口
(async () => {
  const data = createDataset(CAT_DIR, NOT_CAT_DIR); // 构建数据集
  const model = createModel();                      // 定义模型
  await trainModel(model, data);                    // 训练模型
  await model.save('file://./saved-model');         // 保存模型

  // 用一张猫图测试预测
  const testImage = loadAndPreprocessImage(path.join(CAT_DIR, fs.readdirSync(CAT_DIR)[0]));
  const prediction = model.predict(testImage.expandDims(0));
  prediction.print();
})();

讲解:

  1. 加载数据。
  2. 创建模型。
  3. 开始训练并输出训练日志。
  4. 保存模型到 saved-model 文件夹。
  5. 随机取一张猫图进行预测,打印预测概率。

总结

本文通过完整的代码分块与逐步讲解,带你从零实现了 猫 vs 非猫 分类器:

  • 数据准备 → 预处理 → 数据集构建
  • 模型设计(3 层卷积 + 全连接)
  • 模型训练与保存
  • 单张图片预测

这就是一个完整的深度学习入门案例,理解了它,就能进一步扩展到更多分类任务。

⚠️ 提示:本文内容部分由人工智能(GPT)生成,仅供学习、教学与技术参考使用。内容讲解与代码示例经过整理以便理解,但请读者在实际项目中进行验证和测试。

相关推荐
Hilaku23 分钟前
我用 Gemini 3 Pro 手搓了一个并发邮件群发神器(附源码)
前端·javascript·github
IT_陈寒23 分钟前
Java性能调优实战:5个被低估却提升30%效率的JVM参数
前端·人工智能·后端
快手技术24 分钟前
AAAI 2026|全面发力!快手斩获 3 篇 Oral,12 篇论文入选!
前端·后端·算法
颜酱26 分钟前
前端算法必备:滑动窗口从入门到很熟练(最长/最短/计数三大类型)
前端·后端·算法
全栈前端老曹34 分钟前
【包管理】npm init 项目名后底层发生了什么的完整逻辑
前端·javascript·npm·node.js·json·包管理·底层原理
HHHHHY40 分钟前
mathjs简单实现一个数学计算公式及校验组件
前端·javascript·vue.js
boooooooom43 分钟前
Vue3 provide/inject 跨层级通信:最佳实践与避坑指南
前端·vue.js
一颗烂土豆43 分钟前
Vue 3 + Three.js 打造轻量级 3D 图表库 —— chart3
前端·vue.js·数据可视化
青莲84344 分钟前
Android 动画机制完整详解
android·前端·面试
iReachers1 小时前
HTML打包APK(安卓APP)中下载功能常见问题和详细介绍
前端·javascript·html·html打包apk·网页打包app·下载功能