机器学习入门:TensorFlow.js实战

机器学习入门:TensorFlow.js实战

大家好,我是欧阳瑞(Rich Own)。今天想和大家聊聊机器学习这个热门话题。作为一个全栈开发者,我最近一直在研究TensorFlow.js,它可以让我们在浏览器中直接运行机器学习模型。今天就来分享一下TensorFlow.js的基础知识和实战经验。

什么是TensorFlow.js?

TensorFlow.js是Google开发的一个开源机器学习库,可以在浏览器和Node.js环境中运行。它允许开发者:

  • 在浏览器中训练和运行模型
  • 使用预训练模型进行推理
  • 将模型部署到Web应用中

环境准备

html 复制代码
<!-- 在HTML中引入TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.10.0/dist/tf.min.js"></script>

或者使用npm:

bash 复制代码
npm install @tensorflow/tfjs

核心概念

Tensor(张量)

javascript 复制代码
// 创建张量
const tensor1d = tf.tensor1d([1, 2, 3, 4]);
const tensor2d = tf.tensor2d([[1, 2], [3, 4]]);
const tensor3d = tf.tensor3d([[[1], [2]], [[3], [4]]]);

// 张量操作
const a = tf.tensor2d([[1, 2], [3, 4]]);
const b = tf.tensor2d([[5, 6], [7, 8]]);

const sum = a.add(b);
const product = a.matMul(b);

// 打印张量
sum.print();

变量

javascript 复制代码
// 创建变量
const initialValue = tf.tensor([[1, 2], [3, 4]]);
const weights = tf.variable(initialValue);

// 更新变量
const newValue = tf.tensor([[5, 6], [7, 8]]);
weights.assign(newValue);

模型

javascript 复制代码
// 创建一个简单的模型
const model = tf.sequential({
  layers: [
    tf.layers.dense({inputShape: [784], units: 32, activation: 'relu'}),
    tf.layers.dense({units: 10, activation: 'softmax'})
  ]
});

// 编译模型
model.compile({
  optimizer: 'sgd',
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

实战:手写数字识别

加载数据集

javascript 复制代码
// 加载MNIST数据集
async function loadData() {
  const mnistData = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mnist_model/model.json');
  return mnistData;
}

创建模型

javascript 复制代码
function createModel() {
  const model = tf.sequential();
  
  model.add(tf.layers.conv2d({
    inputShape: [28, 28, 1],
    kernelSize: 3,
    filters: 16,
    activation: 'relu'
  }));
  
  model.add(tf.layers.maxPooling2d({poolSize: 2}));
  
  model.add(tf.layers.conv2d({
    kernelSize: 3,
    filters: 32,
    activation: 'relu'
  }));
  
  model.add(tf.layers.maxPooling2d({poolSize: 2}));
  
  model.add(tf.layers.flatten());
  
  model.add(tf.layers.dense({units: 64, activation: 'relu'}));
  
  model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
  
  model.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });
  
  return model;
}

训练模型

javascript 复制代码
async function trainModel(model, data) {
  const { images, labels } = data;
  
  const batchSize = 32;
  const epochs = 5;
  
  await model.fit(images, labels, {
    batchSize,
    epochs,
    shuffle: true,
    validationSplit: 0.1
  });
}

进行预测

javascript 复制代码
async function predict(model, image) {
  const tensor = tf.browser.fromPixels(image)
    .resizeNearestNeighbor([28, 28])
    .mean(2)
    .expandDims(0)
    .expandDims(-1)
    .toFloat()
    .div(tf.scalar(255));
  
  const prediction = model.predict(tensor);
  const result = await prediction.data();
  
  return result.indexOf(Math.max(...result));
}

预训练模型

使用MobileNet

javascript 复制代码
async function loadMobileNet() {
  const model = await tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');
  return model;
}

async function classifyImage(model, image) {
  const tensor = tf.browser.fromPixels(image)
    .resizeNearestNeighbor([224, 224])
    .toFloat()
    .div(tf.scalar(127.5))
    .sub(tf.scalar(1));
  
  const predictions = await model.predict(tensor.expandDims(0)).data();
  return predictions;
}

使用Face API

javascript 复制代码
// 加载人脸识别模型
async function loadFaceModel() {
  await faceapi.loadSsdMobilenetv1Model('/models');
  await faceapi.loadFaceLandmarkModel('/models');
  await faceapi.loadFaceRecognitionModel('/models');
}

async function detectFaces(image) {
  const detections = await faceapi.detectAllFaces(image)
    .withFaceLandmarks()
    .withFaceDescriptors();
  
  return detections;
}

自定义模型训练

准备数据

javascript 复制代码
// 创建模拟数据
function generateData(numSamples) {
  const features = [];
  const labels = [];
  
  for (let i = 0; i < numSamples; i++) {
    const x = Math.random() * 10 - 5;
    const y = Math.sin(x) + (Math.random() - 0.5) * 0.1;
    
    features.push([x]);
    labels.push([y]);
  }
  
  return {
    features: tf.tensor2d(features),
    labels: tf.tensor2d(labels)
  };
}

训练回归模型

javascript 复制代码
async function trainRegressionModel() {
  const data = generateData(1000);
  
  const model = tf.sequential();
  model.add(tf.layers.dense({inputShape: [1], units: 10, activation: 'relu'}));
  model.add(tf.layers.dense({units: 10, activation: 'relu'}));
  model.add(tf.layers.dense({units: 1}));
  
  model.compile({
    optimizer: tf.train.adam(0.01),
    loss: 'meanSquaredError'
  });
  
  await model.fit(data.features, data.labels, {
    epochs: 100,
    batchSize: 32,
    verbose: 1
  });
  
  return model;
}

性能优化

使用WebGL加速

javascript 复制代码
// 检查是否支持WebGL
console.log(tf.getBackend());

// 切换到WebGL后端
tf.setBackend('webgl');

内存管理

javascript 复制代码
// 使用tf.tidy清理中间张量
const result = tf.tidy(() => {
  const a = tf.tensor([1, 2, 3]);
  const b = tf.tensor([4, 5, 6]);
  return a.add(b);
});

// 手动清理张量
const tensor = tf.tensor([1, 2, 3]);
tensor.dispose();

// 检查内存使用
console.log(tf.memory());

模型优化

javascript 复制代码
// 量化模型
const quantizedModel = await tf.loadLayersModel('model_quantized.json');

// 使用tf.data API处理大数据
const dataset = tf.data.generator(function*() {
  for (let i = 0; i < 1000; i++) {
    yield {xs: tf.tensor([i]), ys: tf.tensor([i * 2])};
  }
}).batch(32);

实战案例:情绪识别

javascript 复制代码
// 创建情绪识别模型
async function createEmotionModel() {
  const model = tf.sequential();
  
  model.add(tf.layers.conv2d({
    inputShape: [48, 48, 1],
    kernelSize: 3,
    filters: 32,
    activation: 'relu'
  }));
  
  model.add(tf.layers.maxPooling2d({poolSize: 2}));
  
  model.add(tf.layers.conv2d({
    kernelSize: 3,
    filters: 64,
    activation: 'relu'
  }));
  
  model.add(tf.layers.maxPooling2d({poolSize: 2}));
  
  model.add(tf.layers.flatten());
  
  model.add(tf.layers.dense({units: 128, activation: 'relu'}));
  model.add(tf.layers.dropout({rate: 0.5}));
  
  model.add(tf.layers.dense({units: 7, activation: 'softmax'}));
  
  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy']
  });
  
  return model;
}

总结

TensorFlow.js是一个强大的工具,可以让我们在浏览器中运行机器学习模型。从简单的张量操作到复杂的神经网络,TensorFlow.js都能胜任。

我的鬃狮蜥Hash对机器学习也有自己的理解------它总是通过观察来学习我的行为模式,这也许就是自然界的"监督学习"吧!

如果你对机器学习感兴趣,欢迎留言交流!我是欧阳瑞,极客之路,永无止境!


技术栈:TensorFlow.js · 机器学习 · WebGL · 神经网络

相关推荐
mutourend17 小时前
Zcash 与量子计算机
区块链·量子计算·后量子密码学
TechubNews18 小时前
稳定币下一战:不是谁发币,而是谁掌握结算通道
人工智能·web3·区块链
mutourend20 小时前
量子计算与区块链:让紧迫性与真实威胁相匹配
区块链·量子计算·后量子密码学
多年小白1 天前
A股算力租赁板块 深度分析
大数据·人工智能·ai·金融·区块链
架构源启1 天前
Spring AI 进阶系列- Agent 智能体开发:ReAct模式、多步推理与自主Agent实战
人工智能·spring·react·ai agent·智能体·springai
小牛itbull1 天前
ReactPress 3.0 :一分钟拥有自己的 CMS & 博客
开源·cms·react·博客系统·reactpress
黄焖鸡能干四碗1 天前
固定资产管理系统建设方案和源码(Java源码)
大数据·数据库·人工智能·物联网·区块链
master-dragon1 天前
DeFi 基础: 流动性、池子、AMM、滑点
区块链
多年小白2 天前
【本周复盘】2026年5月11日-5月15日
人工智能·ai·金融·区块链