初识 Tensorflow.js【Plan - June - Week 3】

一、TensorFlow.js

TensorFlow.js 是 TensorFlow 的 JavaScript 实现,支持在浏览器或 Node.js 环境中训练和部署机器学习模型。


1、TensorFlow.js 能做什么?

  • 在浏览器中训练机器学习模型
  • 加载并使用已有的模型(TensorFlow SavedModel、Keras 模型、TensorFlow Hub 等)
  • 在 Node.js 环境中训练和部署模型
  • 将模型从 Python TensorFlow 转换成 JS 可用格式

2、安装 TensorFlow.js

浏览器

html 复制代码
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

或使用 包管理器

bash 复制代码
npm install @tensorflow/tfjs

Node.js

bash 复制代码
npm install @tensorflow/tfjs-node

Node.js 环境下还可以选择 @tensorflow/tfjs-node-gpu 获得 GPU 加速支持。


3、核心概念

Tensor (张量)

  • 数据容器:Tensor 是多维数组,类似于 NumPy 数组。
  • 示例:
js 复制代码
const tensor = tf.tensor([1, 2, 3, 4], [2, 2]);
tensor.print();

输出:

复制代码
[[1, 2],
 [3, 4]]

操作 (Operations)

TensorFlow.js 提供了丰富的张量运算,如矩阵乘法、加减、转置等。

js 复制代码
const a = tf.tensor([[1, 2], [3, 4]]);
const b = tf.tensor([[5, 6], [7, 8]]);
const result = tf.matMul(a, b);
result.print();

模型 (Models)

顺序模型 (Sequential)
js 复制代码
const model = tf.sequential();
model.add(tf.layers.dense({units: 32, inputShape: [50], activation: 'relu'}));
model.add(tf.layers.dense({units: 1, activation: 'linear'}));
函数式模型 (Functional)
js 复制代码
const input = tf.input({shape: [50]});
const dense1 = tf.layers.dense({units: 32, activation: 'relu'}).apply(input);
const output = tf.layers.dense({units: 1, activation: 'linear'}).apply(dense1);
const model = tf.model({inputs: input, outputs: output});

训练 (Training)

js 复制代码
model.compile({
  optimizer: 'sgd',
  loss: 'meanSquaredError'
});

// 假设 xs 和 ys 是 Tensor
await model.fit(xs, ys, {
  epochs: 10,
  batchSize: 32
});

加载/保存模型

加载预训练模型
js 复制代码
const model = await tf.loadLayersModel('https://example.com/model.json');
保存模型
js 复制代码
await model.save('downloads://my-model');  // 浏览器下载
await model.save('file://path-to-save');   // Node.js 保存到文件系统

4、内存管理

TensorFlow.js 不会自动回收内存,需手动处理:

js 复制代码
tf.tidy(() => {
  const y = tf.add(a, b);
  y.print();
}); // 自动释放内部中间张量

手动释放:

js 复制代码
tensor.dispose();

5、部署场景

场景 技术栈
浏览器推理/训练 TensorFlow.js + WebGL/WebGPU
Node.js 推理/训练 TensorFlow.js Node Bindings
混合端部署 将模型转换成不同端适用格式

6、工具链支持

  • TensorFlow Converter:Python 模型 → TensorFlow.js 格式
  • TensorFlow Hub:直接加载预训练模型
  • tfjs-vis:用于可视化训练过程、权重分布等

二、张量与运算

张量是 TensorFlow.js 的核心数据结构,用于存储和操作数据。TensorFlow.js 提供了丰富的运算 API 实现高效的数值计算。


1、什么是张量(Tensor)?

  • 张量 = 多维数组,类似于 NumPy ndarray。
  • 数据类型支持:float32(默认)、int32boolcomplex64string

张量维度示例

维度 含义 示例
0D 标量 tf.scalar(42)
1D 向量 tf.tensor1d([1, 2, 3])
2D 矩阵 tf.tensor2d([[1, 2], [3, 4]])
3D+ 高维数组(张量) tf.tensor3d([...])

2、创建张量

从数组创建

js 复制代码
const t1 = tf.tensor([1, 2, 3, 4], [2, 2]);
const t2 = tf.tensor2d([[1, 2], [3, 4]]);

特殊张量

js 复制代码
tf.zeros([2, 3]).print();
tf.ones([2, 2]).print();
tf.eye(3).print();  // 单位矩阵

生成序列

js 复制代码
tf.range(0, 10, 2).print();  // 0, 2, 4, 6, 8

3、张量运算

TensorFlow.js 提供无副作用的函数式 API,每次运算都会返回新张量。

基本运算

js 复制代码
const a = tf.tensor([1, 2]);
const b = tf.tensor([3, 4]);
tf.add(a, b).print();   // [4, 6]
tf.mul(a, b).print();   // [3, 8]

矩阵乘法

js 复制代码
const m1 = tf.tensor2d([[1, 2], [3, 4]]);
const m2 = tf.tensor2d([[5, 6], [7, 8]]);
tf.matMul(m1, m2).print();

广播机制

js 复制代码
const x = tf.tensor1d([1, 2, 3]);
const y = tf.scalar(2);
tf.mul(x, y).print();  // 每个元素乘2

取子集与变形

js 复制代码
const t = tf.tensor2d([[1,2,3],[4,5,6]]);
t.reshape([3,2]).print();
t.slice([0,1], [2,2]).print();  // 从 (0,1) 开始取 2x2 子矩阵

4、内存管理

张量和中间结果默认不会自动释放,需要手动管理。

自动释放临时张量

js 复制代码
tf.tidy(() => {
  const a = tf.tensor1d([1, 2, 3]);
  const b = a.square();
  b.print();
});  // 离开 tidy 后自动释放 a 和 b

手动释放

js 复制代码
const t = tf.tensor([1, 2, 3]);
t.dispose();

查看内存使用

js 复制代码
console.log(tf.memory());

5、异步操作

某些操作是异步的,例如从张量取值:

js 复制代码
const t = tf.tensor([1, 2, 3]);
t.data().then(data => console.log(data));  // TypedArray
t.array().then(arr => console.log(arr));   // 普通数组

在 Node.js 或浏览器 async/await 环境:

js 复制代码
const data = await t.data();
console.log(data);

6、计算设备

  • TensorFlow.js 默认会使用 WebGLWebGPU 在浏览器加速。
  • 在 Node.js 环境可选用 @tensorflow/tfjs-node@tensorflow/tfjs-node-gpu 获得更高性能。

三、MobileNet 模型

MobileNet 是一种轻量级卷积神经网络,用于高效的图像分类与特征提取,适合在浏览器、移动端和边缘设备运行。


1、MobileNet 的主要特点

  • 轻量化,推理速度快,适合实时应用
  • 支持多种输入尺寸和不同的模型深度(通过 α 参数调整)
  • 可作为特征提取器用于迁移学习
  • 预训练权重来自 ImageNet 数据集

2、快速使用

安装 TensorFlow.js

bash 复制代码
npm install @tensorflow/tfjs

在 HTML 中引入

html 复制代码
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>

加载 MobileNet 模型

js 复制代码
import * as mobilenet from '@tensorflow-models/mobilenet';

// 或在浏览器中
// const mobilenet = window.mobilenet;

const model = await mobilenet.load({
  version: 2,          // MobileNet V2
  alpha: 1.0           // 模型宽度系数(0.25 ~ 1.0)
});

图像分类示例

js 复制代码
const img = document.getElementById('img');  // HTMLImageElement 或 Canvas
const predictions = await model.classify(img);

console.log(predictions);
// [
//   {className: 'Egyptian cat', probability: 0.75},
//   {className: 'tabby, tabby cat', probability: 0.15},
//   ...
// ]

提取特征用于迁移学习

js 复制代码
const activation = model.infer(img, true);
activation.print();  // 返回一个张量,可用于后续自定义分类器训练

3、参数说明

参数 含义 默认值
version 模型版本(1 或 2) 1
alpha 宽度缩放系数,减小模型大小 1.0(完整模型)

推荐:

  • 快速推理:version: 2, alpha: 0.5
  • 更高准确率:version: 2, alpha: 1.0

4、模型输入要求

  • 输入图像可以是:
    • HTMLImageElement
    • HTMLCanvasElement
    • HTMLVideoElement
    • Tensor (tf.browser.fromPixels(img))
  • 输入尺寸会自动调整为模型要求的大小。

5、性能优化

  • 使用 WebGL 或 WebGPU 提升推理性能(TensorFlow.js 自动选择)
  • 建议在调用前将图像转换为固定尺寸并缓存模型
  • 使用 tf.tidy 自动释放中间张量内存

6、示例代码片段

html 复制代码
<img id="img" src="cat.jpg" />
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
<script>
  async function run() {
    const img = document.getElementById('img');
    const model = await mobilenet.load({version: 2, alpha: 0.75});
    const predictions = await model.classify(img);
    console.log(predictions);
  }
  run();
</script>

学习资料来源

TensorFolw.js
张量和运算
MobileNet 模型

相关推荐
舒一笑24 分钟前
基础RAG实现,最佳入门选择(六)
人工智能
SLAM必须dunk1 小时前
DL___线性神经网络
人工智能·深度学习·神经网络
甜辣uu1 小时前
第七届人工智能技术与应用国际学术会议
人工智能·ei会议·中文核心·国际学术会议
艾立泰智能包装1 小时前
艾立泰智能物流载具管理方案
大数据·人工智能
舒一笑1 小时前
基础RAG实现,最佳入门选择(五)
人工智能
爱看科技1 小时前
谷歌Gemini 2.5全系领跑AI赛道,微美全息加码构建AI+多模态交互生态新范式
人工智能
love530love1 小时前
Python 开发环境全栈隔离架构:从 Anaconda 到 PyCharm 的四级防护体系
运维·ide·人工智能·windows·python·架构·pycharm
m0_751336392 小时前
机器学习赋能多尺度材料模拟:前沿技术会议邀您共探
人工智能·深度学习·机器学习·第一性原理·分子动力学·vasp·复合材料
一休哥助手2 小时前
稳定币:从支付工具到金融基础设施的技术演进与全球竞争新格局
人工智能·金融
泡芙萝莉酱2 小时前
2011-2023年 省级-数字普惠金融指数-社科经管实证数据
大数据·人工智能·深度学习·数据挖掘·数据分析·数据统计·实证数据