机器学习入门:TensorFlow.js实战
大家好,我是欧阳瑞(Rich Own)。今天想和大家聊聊机器学习这个热门话题。作为一个全栈开发者,我最近一直在研究TensorFlow.js,它可以让我们在浏览器中直接运行机器学习模型。今天就来分享一下TensorFlow.js的基础知识和实战经验。
什么是TensorFlow.js?
TensorFlow.js是Google开发的一个开源机器学习库,可以在浏览器和Node.js环境中运行。它允许开发者:
- 在浏览器中训练和运行模型
- 使用预训练模型进行推理
- 将模型部署到Web应用中
环境准备
html
<!-- 在HTML中引入TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.10.0/dist/tf.min.js"></script>
或者使用npm:
bash
npm install @tensorflow/tfjs
核心概念
Tensor(张量)
javascript
// 创建张量
const tensor1d = tf.tensor1d([1, 2, 3, 4]);
const tensor2d = tf.tensor2d([[1, 2], [3, 4]]);
const tensor3d = tf.tensor3d([[[1], [2]], [[3], [4]]]);
// 张量操作
const a = tf.tensor2d([[1, 2], [3, 4]]);
const b = tf.tensor2d([[5, 6], [7, 8]]);
const sum = a.add(b);
const product = a.matMul(b);
// 打印张量
sum.print();
变量
javascript
// 创建变量
const initialValue = tf.tensor([[1, 2], [3, 4]]);
const weights = tf.variable(initialValue);
// 更新变量
const newValue = tf.tensor([[5, 6], [7, 8]]);
weights.assign(newValue);
模型
javascript
// 创建一个简单的模型
const model = tf.sequential({
layers: [
tf.layers.dense({inputShape: [784], units: 32, activation: 'relu'}),
tf.layers.dense({units: 10, activation: 'softmax'})
]
});
// 编译模型
model.compile({
optimizer: 'sgd',
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
实战:手写数字识别
加载数据集
javascript
// 加载MNIST数据集
async function loadData() {
const mnistData = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mnist_model/model.json');
return mnistData;
}
创建模型
javascript
function createModel() {
const model = tf.sequential();
model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
kernelSize: 3,
filters: 16,
activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({poolSize: 2}));
model.add(tf.layers.conv2d({
kernelSize: 3,
filters: 32,
activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({poolSize: 2}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 64, activation: 'relu'}));
model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
model.compile({
optimizer: tf.train.adam(),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
return model;
}
训练模型
javascript
async function trainModel(model, data) {
const { images, labels } = data;
const batchSize = 32;
const epochs = 5;
await model.fit(images, labels, {
batchSize,
epochs,
shuffle: true,
validationSplit: 0.1
});
}
进行预测
javascript
async function predict(model, image) {
const tensor = tf.browser.fromPixels(image)
.resizeNearestNeighbor([28, 28])
.mean(2)
.expandDims(0)
.expandDims(-1)
.toFloat()
.div(tf.scalar(255));
const prediction = model.predict(tensor);
const result = await prediction.data();
return result.indexOf(Math.max(...result));
}
预训练模型
使用MobileNet
javascript
async function loadMobileNet() {
const model = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');
return model;
}
async function classifyImage(model, image) {
const tensor = tf.browser.fromPixels(image)
.resizeNearestNeighbor([224, 224])
.toFloat()
.div(tf.scalar(127.5))
.sub(tf.scalar(1));
const predictions = await model.predict(tensor.expandDims(0)).data();
return predictions;
}
使用Face API
javascript
// 加载人脸识别模型
async function loadFaceModel() {
await faceapi.loadSsdMobilenetv1Model('/models');
await faceapi.loadFaceLandmarkModel('/models');
await faceapi.loadFaceRecognitionModel('/models');
}
async function detectFaces(image) {
const detections = await faceapi.detectAllFaces(image)
.withFaceLandmarks()
.withFaceDescriptors();
return detections;
}
自定义模型训练
准备数据
javascript
// 创建模拟数据
function generateData(numSamples) {
const features = [];
const labels = [];
for (let i = 0; i < numSamples; i++) {
const x = Math.random() * 10 - 5;
const y = Math.sin(x) + (Math.random() - 0.5) * 0.1;
features.push([x]);
labels.push([y]);
}
return {
features: tf.tensor2d(features),
labels: tf.tensor2d(labels)
};
}
训练回归模型
javascript
async function trainRegressionModel() {
const data = generateData(1000);
const model = tf.sequential();
model.add(tf.layers.dense({inputShape: [1], units: 10, activation: 'relu'}));
model.add(tf.layers.dense({units: 10, activation: 'relu'}));
model.add(tf.layers.dense({units: 1}));
model.compile({
optimizer: tf.train.adam(0.01),
loss: 'meanSquaredError'
});
await model.fit(data.features, data.labels, {
epochs: 100,
batchSize: 32,
verbose: 1
});
return model;
}
性能优化
使用WebGL加速
javascript
// 检查是否支持WebGL
console.log(tf.getBackend());
// 切换到WebGL后端
tf.setBackend('webgl');
内存管理
javascript
// 使用tf.tidy清理中间张量
const result = tf.tidy(() => {
const a = tf.tensor([1, 2, 3]);
const b = tf.tensor([4, 5, 6]);
return a.add(b);
});
// 手动清理张量
const tensor = tf.tensor([1, 2, 3]);
tensor.dispose();
// 检查内存使用
console.log(tf.memory());
模型优化
javascript
// 量化模型
const quantizedModel = await tf.loadLayersModel('model_quantized.json');
// 使用tf.data API处理大数据
const dataset = tf.data.generator(function*() {
for (let i = 0; i < 1000; i++) {
yield {xs: tf.tensor([i]), ys: tf.tensor([i * 2])};
}
}).batch(32);
实战案例:情绪识别
javascript
// 创建情绪识别模型
async function createEmotionModel() {
const model = tf.sequential();
model.add(tf.layers.conv2d({
inputShape: [48, 48, 1],
kernelSize: 3,
filters: 32,
activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({poolSize: 2}));
model.add(tf.layers.conv2d({
kernelSize: 3,
filters: 64,
activation: 'relu'
}));
model.add(tf.layers.maxPooling2d({poolSize: 2}));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({units: 128, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.5}));
model.add(tf.layers.dense({units: 7, activation: 'softmax'}));
model.compile({
optimizer: tf.train.adam(0.001),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
return model;
}
总结
TensorFlow.js是一个强大的工具,可以让我们在浏览器中运行机器学习模型。从简单的张量操作到复杂的神经网络,TensorFlow.js都能胜任。
我的鬃狮蜥Hash对机器学习也有自己的理解------它总是通过观察来学习我的行为模式,这也许就是自然界的"监督学习"吧!
如果你对机器学习感兴趣,欢迎留言交流!我是欧阳瑞,极客之路,永无止境!
技术栈:TensorFlow.js · 机器学习 · WebGL · 神经网络