以下是关于AI集成(TensorFlow.js模型推理)的基本知识点总结:
一、TensorFlow.js 核心概念
-
Tensor(张量)
-
定义:多维数组,是数据的基本单位(如标量、向量、矩阵)。
-
操作:数学运算(加减乘除)、形状变换(reshape)、广播(broadcasting)。
-
示例 :
javascriptconst t = tf.tensor([1, 2, 3]); // 创建张量 t.add(1).print(); // 张量加法
-
-
模型(Model)
- 预训练模型:直接加载已训练好的模型(如 MobileNet、PoseNet)。
- 自定义模型 :通过
tf.sequential()
或函数式 API 构建模型。
-
模型推理流程
- 输入处理 → 模型预测 → 输出解析。
二、TensorFlow.js 环境搭建
1. 安装方式
-
浏览器环境 (直接引入):
html<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"></script>
-
Node.js 环境 :
bashnpm install @tensorflow/tfjs-node # CPU 版本 npm install @tensorflow/tfjs-node-gpu # GPU 版本(需 CUDA)
2. 模型转换工具
-
转换 Keras 模型 :
bashtensorflowjs_converter --input_format=keras_saved_model model.h5 ./js_model/
-
转换 TensorFlow SavedModel :
bashtensorflowjs_converter --input_format=tf_saved_model ./saved_model/ ./js_model/
三、模型加载与推理
1. 加载预训练模型
javascript
// 浏览器中加载模型
const model = await tf.loadLayersModel('https://path/to/model.json');
// Node.js 中加载模型
const model = await tf.node.loadSavedModel('./path/to/saved_model/');
2. 输入预处理
javascript
// 图像预处理示例(归一化 + 调整尺寸)
const img = document.getElementById('my-image');
const tensor = tf.browser.fromPixels(img)
.resizeNearestNeighbor([224, 224]) // 调整尺寸
.toFloat()
.div(255) // 归一化到 [0,1]
.expandDims(); // 添加批次维度 [1, 224, 224, 3]
3. 执行推理
javascript
const predictions = await model.predict(tensor).data();
console.log(predictions); // 输出预测结果数组
4. 输出后处理
javascript
// 获取分类标签(假设输出为概率数组)
const labels = ['cat', 'dog'];
const maxIndex = predictions.indexOf(Math.max(...predictions));
console.log('预测结果:', labels[maxIndex]);
四、高级特性与优化
1. **WebGL 加速
-
自动启用:浏览器中默认使用 WebGL 后端加速计算。
-
手动切换后端 :
javascripttf.setBackend('webgl'); // 强制使用 WebGL
2. **模型量化
- 目的:减小模型体积,提升推理速度。
- 方法 :使用
tensorflowjs_converter
时添加--quantize_float16
或--quantize_uint8
参数。
3. **模型分片加载
-
分片模型:将大模型拆分为多个文件,按需加载。
-
实现 :
javascriptconst model = await tf.loadLayersModel({ modelUrl: 'model.json', weightsUrlPrefix: './shards/' });
4. **内存管理
-
手动释放内存 :
javascripttf.dispose(tensor); // 释放单个张量 tf.disposeVariables(); // 释放所有变量
-
自动内存回收 :使用
tf.tidy()
包裹作用域:javascriptconst result = tf.tidy(() => { const a = tf.tensor([1, 2]); return a.square(); });
五、实战案例:图像分类
1. 加载 MobileNet 模型
javascript
const model = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');
2. 图像预处理与推理
javascript
async function classifyImage(imgElement) {
const tensor = tf.browser.fromPixels(imgElement)
.resizeBilinear([224, 224])
.toFloat()
.div(255)
.expandDims();
const predictions = await model.predict(tensor).data();
return predictions;
}
3. 结合 UI 交互
html
<input type="file" id="upload" accept="image/*">
<img id="preview" width="224" height="224">
<div id="result"></div>
<script>
document.getElementById('upload').addEventListener('change', async (e) => {
const img = document.getElementById('preview');
img.src = URL.createObjectURL(e.target.files[0]);
const predictions = await classifyImage(img);
document.getElementById('result').textContent = `分类结果: ${getTopClass(predictions)}`;
});
</script>
六、调试与性能优化
1. 性能监控
-
查看张量内存 :
javascriptconsole.log(tf.memory()); // 输出内存使用情况
-
帧率监控 :使用
stats.js
库监控渲染性能。
2. 错误排查
- 张量形状不匹配 :检查输入是否符合模型预期(如
model.input.shape
)。 - 模型加载失败:检查网络请求或文件路径是否正确。
3. 最佳实践
-
使用异步加载:避免阻塞主线程。
-
预热模型 :提前运行一次推理以减少首次延迟。
javascript// 预热模型 const warmupTensor = tf.zeros(model.input.shape); model.predict(warmupTensor); tf.dispose(warmupTensor);