【face-api.js】2️⃣ NetInput - 神经网络输入封装类

前言:今天来研究一下NetInput,它是 face-api.js 中将输入数据统一化,封装成神经网络输入的类。它将不同类型的输入(Canvas、Tensor等)统一管理,并提供批量处理能力。另外我在B站大学开始学习周志华老师的机器学习课程,非常推荐🙌🏻🙌🏻🙌🏻🙌🏻。后续希望可以写文章分享一下机器学习的学习心得。

定义的位置:src/dom/NetInput.ts

一、调用流程

我们还是以项目中的第一个功能为例,看一下关于 NetInput 的调用流程。我发现研究一个功能有很多角度,上一篇文章,我是总体的看了一下整个流程的执行,大致都干了什么事情,其中有非常多看不懂的地方。这一篇文章来针对 NetInput 类,看一下这个类在什么地方用,怎么用。话说,加木的歌也太太太好听了吧🥰🥰🥰🥰🥰

(一)开始处理图像

examples/examples-browser/views/faceDetection.html 这个文件是上图中的模块文件,页面初始化会获取图片开始处理。

javascript 复制代码
async function updateResults() {
   const inputImgEl = $('#inputImg').get(0)
   // 获取面部检测选项
   const options = getFaceDetectorOptions()
   // 进行面部检测
   const results = await faceapi.detectAllFaces(inputImgEl, options)
}

此时会把图片元素,HTMLElement类型,传入 faceapi.detectAllFacesinputImgEl 会作为 NetInput 类的输入参数。

(二)detectAllFaces

src/globalApi/detectFaces.ts

javascript 复制代码
export function detectAllFaces(
  input: TNetInput,
  options: FaceDetectionOptions = new SsdMobilenetv1Options()
): DetectAllFacesTask {
  return new DetectAllFacesTask(input, options)
}

TNetInput类:
src/dom/types.ts

HTML 中的图片、视频、Canvas元素;tf.Tensor3D 、tf.Tensor4D;NetInput元素

javascript 复制代码
import * as tf from '@tensorflow/tfjs-core';

import { NetInput } from './NetInput';

export type TMediaElement = HTMLImageElement | HTMLVideoElement | HTMLCanvasElement

export type TResolvedNetInput = TMediaElement | tf.Tensor3D | tf.Tensor4D

export type TNetInputArg = string | TResolvedNetInput

export type TNetInput = TNetInputArg | Array<TNetInputArg> | NetInput | tf.Tensor4D

(三)DetectAllFacesTask

src/globalApi/DetectFacesTasks.ts

这个文件中会调用这个方法:
nets.ssdMobilenetv1.locateFaces(input, options)

(四)locateFaces

src/ssdMobilenetv1/SsdMobilenetv1.ts

主角出场!toNetInput 方法会将各种输入类型(HTMLImageElement、Canvas、Tensor等)统一转换为 NetInput 格式

javascript 复制代码
const netInput = await toNetInput(input)

二、toNetInput 方法

功能概述:

这是一个输入标准化函数,用于将不同类型的输入(图像元素、视频元素、Canvas、张量等)

转换为统一的 NetInput 格式,以便神经网络可以处理。

主要作用:

  1. 输入类型转换:支持多种输入类型,统一转换为 NetInput
  2. 输入验证:检查输入是否有效,不符合要求时抛出错误
  3. 媒体加载等待:确保所有媒体元素(图片、视频)都已加载完成
  4. 批量处理支持:支持单个输入或批量输入(数组)

支持的输入类型(TNetInput):

  • HTMLImageElement: HTML 图片元素
  • HTMLVideoElement: HTML 视频元素
  • HTMLCanvasElement: HTML Canvas 元素
  • tf.Tensor3D: 3维张量 [height, width, channels]
  • tf.Tensor4D: 4维张量 [batch, height, width, channels]
  • string: 元素ID字符串(浏览器环境),会通过 document.getElementById() 解析
  • NetInput: 如果已经是 NetInput 对象,直接返回
  • Array: 上述类型的数组,用于批量处理

返回值:

@returns {Promise} NetInput 对象,包含:

  • batchSize: 批次大小(单个输入为1,数组输入为数组长度)
  • canvases: HTMLCanvasElement 数组(用于媒体元素)
  • imageTensors: 张量数组(用于张量输入)
  • 提供 toBatchTensor() 方法:将输入转换为批量张量

使用场景:

  • 在神经网络的前向传播前,统一输入格式
  • 处理来自不同来源的图像数据(文件上传、摄像头、Canvas绘制等)
  • 批量处理多张图像

(一)解析输入

javascript 复制代码
// 如果输入已经是 NetInput 对象,直接返回(避免重复转换)
if (inputs instanceof NetInput) {
  return inputs
}

// 将输入统一转换为数组格式,方便后续处理
// 单个输入:[input],数组输入:保持原样
let inputArgArray = Array.isArray(inputs)
  ? inputs
  : [inputs]

// 验证:确保输入数组不为空
if (!inputArgArray.length) {
  throw new Error('toNetInput - empty array passed as input')
}

// 辅助函数:生成错误提示中的索引信息
// 如果是数组输入,显示 "at input index X",否则为空字符串
const getIdxHint = (idx: number) => Array.isArray(inputs) ? ` at input index ${idx}:` : ''

// 步骤1:解析输入
// resolveInput 会将字符串ID转换为 DOM 元素(浏览器环境)
// 其他类型保持不变
const inputArray = inputArgArray.map(resolveInput)

(二)验证每个输入的类型

必须是以下类型之一:

  • HTMLImageElement | HTMLVideoElement | HTMLCanvasElement(媒体元素)
  • tf.Tensor3D(3维张量)
  • tf.Tensor4D(4维张量,但批次大小必须为1)
javascript 复制代码
// 步骤2:验证每个输入的类型
inputArray.forEach((input, i) => {
  // 检查是否为有效类型
  if (!isMediaElement(input) && !isTensor3D(input) && !isTensor4D(input)) {
    // 如果是字符串但无法解析为 DOM 元素
    if (typeof inputArgArray[i] === 'string') {
      throw new Error(`toNetInput -${getIdxHint(i)} string passed, but could not resolve HTMLElement for element id ${inputArgArray[i]}`)
    }

    // 其他无效类型
    throw new Error(`toNetInput -${getIdxHint(i)} expected media to be of type HTMLImageElement | HTMLVideoElement | HTMLCanvasElement | tf.Tensor3D, or to be an element id`)
  }

  // 特殊验证:如果输入是 Tensor4D,批次大小必须为1
  // 原因:在输入数组中,每个元素代表一个独立的输入,不应该包含批次维度
  // 如果需要批量处理,应该传入多个 Tensor3D,而不是一个 Tensor4D
  if (isTensor4D(input)) {
    const batchSize = input.shape[0]
    if (batchSize !== 1) {
      throw new Error(`toNetInput -${getIdxHint(i)} tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`)
    }
  }
})

(三)等待所有媒体元素加载完成

这是异步操作,确保图片和视频都已完全加载。

对于张量输入,awaitMediaLoaded 会返回 undefined,Promise.all 会忽略

javascript 复制代码
// 步骤3:等待所有媒体元素加载完成
await Promise.all(
  inputArray.map(input => isMediaElement(input) && awaitMediaLoaded(input))
)

看一下 awaitMediaLoaded 的处理,是监听的原生的 loaderror 事件,这个手法在实际项目中也可以借鉴。
src/dom/awaitMediaLoaded.ts

javascript 复制代码
export function awaitMediaLoaded(media: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement) {

  return new Promise((resolve, reject) => {
    if (media instanceof env.getEnv().Canvas || isMediaLoaded(media)) {
      return resolve()
    }

    function onLoad(e: Event) {
      if (!e.currentTarget) return
      e.currentTarget.removeEventListener('load', onLoad)
      e.currentTarget.removeEventListener('error', onError)
      resolve(e)
    }

    function onError(e: Event) {
      if (!e.currentTarget) return
      e.currentTarget.removeEventListener('load', onLoad)
      e.currentTarget.removeEventListener('error', onError)
      reject(e)
    }

    media.addEventListener('load', onLoad)
    media.addEventListener('error', onError)
  })
}

(四)创建并返回 NetInput 对象

第二个参数表示是否作为批量输入处理

  • true: 原始输入是数组,明确表示批量处理
  • false: 原始输入是单个元素,但被转换为数组格式
    当前输入的是一个图片元素,所以是第二种情况
javascript 复制代码
// 步骤4:创建并返回 NetInput 对象
return new NetInput(inputArray, Array.isArray(inputs))

三、NetInput类

(一)私有属性

支持的输入类型(TResolvedNetInput):

  • HTMLImageElement: HTML 图片元素
  • HTMLVideoElement: HTML 视频元素
  • HTMLCanvasElement: HTML Canvas 元素
  • tf.Tensor3D: 3维张量 [height, width, channels]
  • tf.Tensor4D: 4维张量 [batch=1, height, width, channels](批次大小必须为1)
js 复制代码
/** 张量输入数组:存储 Tensor3D 或 Tensor4D 类型的输入 */
private _imageTensors: Array<tf.Tensor3D | tf.Tensor4D> = []

/** Canvas 数组:存储 HTMLCanvasElement 类型的输入(从媒体元素转换而来) */
private _canvases: HTMLCanvasElement[] = []

/** 批次大小:输入的数量(单个输入为1,批量输入为数组长度) */
private _batchSize: number

/** 是否作为批量输入处理:true 表示明确作为批量输入,false 表示单个输入 */
private _treatAsBatchInput: boolean = false

/** 输入尺寸数组:每个输入的原始尺寸 [height, width, channels] */
private _inputDimensions: number[][] = []

/** 网络输入尺寸:调用 toBatchTensor() 后设置,表示调整后的输入尺寸(如 512) */
private _inputSize: number

(二)构造函数

1、基础

功能:

初始化 NetInput 对象,根据输入类型进行分类存储:

  • 张量输入 → 存储在 _imageTensors
  • 媒体元素 → 转换为 Canvas 存储在 _canvases
    参数:
    @param {Array} inputs - 输入数组,每个元素可以是:
  • HTMLImageElement | HTMLVideoElement | HTMLCanvasElement(媒体元素)
  • tf.Tensor3D(3维张量)
  • tf.Tensor4D(4维张量,批次大小必须为1)
    @param {boolean} treatAsBatchInput - 是否作为批量输入处理(默认:false)
  • true: 明确表示这是批量输入(即使只有一个元素)
  • false: 单个输入(即使被转换为数组格式)

处理逻辑:

  1. Tensor3D: 直接存储,记录形状 [height, width, channels]
  2. Tensor4D: 验证批次大小为1,存储,记录形状 [height, width, channels](去掉批次维度)
  3. 媒体元素: 转换为 Canvas,记录尺寸 [height, width, 3]

Channels(通道) 是图像中每个像素的颜色信息维度,当前使用的是 rgb 颜色表示法,Channels(通道) 的值为3。

javascript 复制代码
constructor(
  inputs: Array<TResolvedNetInput>,
  treatAsBatchInput: boolean = false
) {
  // 验证输入必须是数组
  if (!Array.isArray(inputs)) {
    throw new Error(`NetInput.constructor - expected inputs to be an Array of TResolvedNetInput or to be instanceof tf.Tensor4D, instead have ${inputs}`)
  }

  // 设置批量输入标志和批次大小
  this._treatAsBatchInput = treatAsBatchInput
  this._batchSize = inputs.length

  // 遍历每个输入,根据类型进行分类存储
  inputs.forEach((input, idx) => {
    // 情况1:Tensor3D 输入
    // 形状:[height, width, channels]
    if (isTensor3D(input)) {
      this._imageTensors[idx] = input
      this._inputDimensions[idx] = input.shape  // [height, width, channels]
      return
    }

    // 情况2:Tensor4D 输入
    // 形状:[batch, height, width, channels]
    // 注意:批次大小必须为1(在输入数组中,每个元素代表一个独立输入)
    if (isTensor4D(input)) {
      const batchSize = input.shape[0]
      if (batchSize !== 1) {
        throw new Error(`NetInput - tf.Tensor4D with batchSize ${batchSize} passed, but not supported in input array`)
      }

      this._imageTensors[idx] = input
      // 去掉批次维度,只记录 [height, width, channels]
      this._inputDimensions[idx] = input.shape.slice(1)
      return
    }

    // 情况3:媒体元素输入(HTMLImageElement、HTMLVideoElement、HTMLCanvasElement)
    // 如果是 Canvas,直接使用;否则转换为 Canvas
    const canvas = input instanceof env.getEnv().Canvas
      ? input
      : createCanvasFromMedia(input)

    this._canvases[idx] = canvas
    // 记录 Canvas 的尺寸:[height, width, 3](RGB 3通道)
    this._inputDimensions[idx] = [canvas.height, canvas.width, 3]
  })
}

2、元素类型转换

如果是非Canvas的HTML元素,需要转换成Canvas
src/dom/createCanvas.ts

javascript 复制代码
import { env } from '../env';
export function createCanvasFromMedia(media: HTMLImageElement | HTMLVideoElement | ImageData, dims?: IDimensions): HTMLCanvasElement {

  // 获取环境对象
  const { ImageData } = env.getEnv()

  if (!(media instanceof ImageData) && !isMediaLoaded(media)) {
    throw new Error('createCanvasFromMedia - media has not finished loading yet')
  }

  const { width, height } = dims || getMediaDimensions(media)
  const canvas = createCanvas({ width, height })

  if (media instanceof ImageData) {
    getContext2dOrThrow(canvas).putImageData(media, 0, 0)
  } else {
    getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
  }
  return canvas
}

3、env.getEnv()

我们来看一下 env.getEnv() 方法
src/env/index.ts
getEnv() 获取当前运行环境的 Environment 对象。

返回一个 Environment 对象,该对象包含了当前运行环境(浏览器或 Node.js)的 API 抽象。

这使得 face-api.js 可以在浏览器和 Node.js 环境中使用相同的代码。
返回的 Environment 对象包含:

  • Canvas: Canvas 元素的构造函数
  • Image: Image 元素的构造函数
  • Video: Video 元素的构造函数
  • ImageData: ImageData 构造函数
  • createCanvasElement: 创建 Canvas 元素的函数
  • createImageElement: 创建 Image 元素的函数
  • fetch: 网络请求函数
  • readFile: 文件读取函数(Node.js 环境)
    ** 使用场景:**
  • 需要创建 Canvas 元素时
  • 需要创建 Image 元素时
  • 需要判断输入类型时(instanceof 检查)
  • 需要访问环境特定的 API 时
javascript 复制代码
function getEnv(): Environment {
  if (!environment) {
    throw new Error('getEnv - environment is not defined, check isNodejs() and isBrowser()')
  }
  return environment
}

Environment类:

javascript 复制代码
export type FileSystem = {
  readFile: (filePath: string) => Promise<Buffer>
}

export type Environment = FileSystem & {
  Canvas: typeof HTMLCanvasElement
  CanvasRenderingContext2D: typeof CanvasRenderingContext2D
  Image: typeof HTMLImageElement
  ImageData: typeof ImageData
  Video: typeof HTMLVideoElement
  createCanvasElement: () => HTMLCanvasElement
  createImageElement: () => HTMLImageElement
  fetch: (url: string, init?: RequestInit) => Promise<Response>
}

4、创建canvas元素

javascript 复制代码
const { width, height } = dims || getMediaDimensions(media)
const canvas = createCanvas({ width, height })

if (media instanceof ImageData) {
  getContext2dOrThrow(canvas).putImageData(media, 0, 0)
} else {
  getContext2dOrThrow(canvas).drawImage(media, 0, 0, width, height)
}

这里用到了两个canvas原生的API

  • putImageData
    将数据从已有的 ImageData 对象绘制到画布上。
    参数为:
    • ImageData
      • data,描述了一个一维数组,包含以 RGBA 顺序的数据,数据使用 0 至 255(包含)的整数表示。
      • height
      • width
    • dx:目标画布中放置图像数据的水平位置(x 坐标)。
    • dy:目标画布中放置图像数据的垂直位置(y 坐标)。
  • drawImage
    在画布(Canvas)上绘制图像,参数如下
    • image,绘制到上下文的元素。允许任何的画布图像源,例如:HTMLImageElement、SVGImageElement、HTMLVideoElement、HTMLCanvasElement、ImageBitmap、OffscreenCanvas 或 VideoFrame
    • sx,需要绘制到目标上下文中的,源 image 的子矩形(裁剪)的左上角 X 轴坐标。
    • sy,需要绘制到目标上下文中的,源 image 的子矩形(裁剪)的左上角 Y 轴坐标。
    • sWidth,需要绘制到目标上下文中的,源 image 的子矩形(裁剪)的宽度。
    • sHeight,需要绘制到目标上下文中的,image的矩形(裁剪)选择框的高度。

最终会把canvas元素返回。

(三)公共属性访问器

提供可以访问 NetInout对象公共属性的接口

javascript 复制代码
/**
  * 获取张量输入数组
  * @returns {Array<tf.Tensor3D | tf.Tensor4D>} 张量输入数组
  */
 public get imageTensors(): Array<tf.Tensor3D | tf.Tensor4D> {
   return this._imageTensors
 }

 /**
  * 获取 Canvas 输入数组
  * @returns {HTMLCanvasElement[]} Canvas 元素数组
  */
 public get canvases(): HTMLCanvasElement[] {
   return this._canvases
 }

 /**
  * 判断是否为批量输入
  * @returns {boolean} true 表示批量输入,false 表示单个输入
  */
 public get isBatchInput(): boolean {
   return this.batchSize > 1 || this._treatAsBatchInput
 }

 /**
  * 获取批次大小
  * @returns {number} 输入的数量
  */
 public get batchSize(): number {
   return this._batchSize
 }

 /**
  * 获取所有输入的原始尺寸
  * @returns {number[][]} 每个输入的尺寸数组 [height, width, channels]
  */
 public get inputDimensions(): number[][] {
   return this._inputDimensions
 }

 /**
  * 获取网络输入尺寸
  * @returns {number | undefined} 调用 toBatchTensor() 后设置的输入尺寸(如 512),未调用则为 undefined
  */
 public get inputSize(): number | undefined {
   return this._inputSize
 }

 /**
  * 获取所有输入调整后的尺寸(调用 toBatchTensor 后的尺寸)
  * @returns {Dimensions[]} 每个输入调整后的尺寸数组
  */
 public get reshapedInputDimensions(): Dimensions[] {
   return range(this.batchSize, 0, 1).map(
     (_, batchIdx) => this.getReshapedInputDimensions(batchIdx)
   )
 }

(四)公共方法

javascript 复制代码
/**
 * 获取指定批次的输入
 * 优先返回 Canvas,如果没有则返回张量
 * 
 * @param {number} batchIdx - 批次索引(从 0 开始)
 * @returns {tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement} 输入元素
 */
public getInput(batchIdx: number): tf.Tensor3D | tf.Tensor4D | HTMLCanvasElement {
  return this.canvases[batchIdx] || this.imageTensors[batchIdx]
}

/**
 * 获取指定批次的原始输入尺寸
 * 
 * @param {number} batchIdx - 批次索引
 * @returns {number[]} 尺寸数组 [height, width, channels]
 */
public getInputDimensions(batchIdx: number): number[] {
  return this._inputDimensions[batchIdx]
}

/**
 * 获取指定批次的输入高度
 * 
 * @param {number} batchIdx - 批次索引
 * @returns {number} 图像高度(像素)
 */
public getInputHeight(batchIdx: number): number {
  return this._inputDimensions[batchIdx][0]
}

/**
 * 获取指定批次的输入宽度
 * 
 * @param {number} batchIdx - 批次索引
 * @returns {number} 图像宽度(像素)
 */
public getInputWidth(batchIdx: number): number {
  return this._inputDimensions[batchIdx][1]
}

/**
 * 获取指定批次调整后的输入尺寸
 * 
 * 功能:
 * 计算调用 toBatchTensor() 后,输入图像调整后的尺寸。
 * 这考虑了图像可能被填充(padding)或调整大小(resize)的情况。
 * 
 * 注意:
 * - 必须在调用 toBatchTensor() 之后才能使用此方法
 * - 如果未调用 toBatchTensor(),会抛出错误
 * 
 * @param {number} batchIdx - 批次索引
 * @returns {Dimensions} 调整后的尺寸对象 {width, height}
 * 
 * 示例:
 * ```typescript
 * const netInput = await toNetInput(imageElement)
 * const batchTensor = netInput.toBatchTensor(512, false)
 * const reshapedDims = netInput.getReshapedInputDimensions(0)
 * // reshapedDims = {width: 512, height: 512}(如果图像被调整为正方形)
 * ```
 */
public getReshapedInputDimensions(batchIdx: number): Dimensions {
  if (typeof this.inputSize !== 'number') {
    throw new Error('getReshapedInputDimensions - inputSize not set, toBatchTensor has not been called yet')
  }

  const width = this.getInputWidth(batchIdx)
  const height = this.getInputHeight(batchIdx)
  // 计算调整后的尺寸(考虑填充和缩放)
  return computeReshapedDimensions({ width, height }, this.inputSize)
}

(五)核心方法 toBatchTensor

这是 NetInput 的核心方法,将所有输入(Canvas 或张量)转换为统一的批量张量格式, 供神经网络使用。它会处理图像尺寸调整、填充、类型转换等操作。
主要作用:

  1. 统一输入格式:将所有输入转换为相同尺寸的张量
  2. 尺寸标准化:将图像调整为指定的 inputSize x inputSize
  3. 填充处理:将非正方形图像填充为正方形(可选居中填充)
  4. 批量堆叠:将所有输入堆叠成一个批量张量
  5. 类型转换:转换为 float 类型

参数:

  • @param {number} inputSize - 目标输入尺寸(高度和宽度,如 512)

    所有输入都会被调整为这个尺寸

  • @param {boolean} isCenterInputs - 是否居中填充(默认:true)

    • true: 在较短边两侧均匀填充,图像居中
    • false: 在右侧和底部填充

返回值:

@returns {tf.Tensor4D} 批量张量

  • 形状:[batchSize, inputSize, inputSize, 3]
  • 数据类型:float32
  • 值范围:[0, 255](像素值)

1、张量的格式化

javascript 复制代码
// 步骤1:处理每个输入,转换为统一格式的张量
const inputTensors = range(this.batchSize, 0, 1).map(batchIdx => {
  const input = this.getInput(batchIdx)

  // 情况1:输入是张量(Tensor3D 或 Tensor4D)
  if (input instanceof tf.Tensor) {
    // 统一为 Tensor4D 格式 [1, h, w, 3]
    let imgTensor = isTensor4D(input)
      ? input
      : input.expandDims<tf.Rank.R4>()

    // 步骤2:填充为正方形(如果非正方形)
    // isCenterInputs: true 表示居中填充,false 表示右侧/底部填充
    imgTensor = padToSquare(imgTensor, isCenterInputs)

    // 步骤3:调整尺寸到 inputSize x inputSize(如果需要)
    // 使用双线性插值进行缩放
    if (imgTensor.shape[1] !== inputSize || imgTensor.shape[2] !== inputSize) {
      imgTensor = tf.image.resizeBilinear(imgTensor, [inputSize, inputSize])
    }

    // 确保形状为 [inputSize, inputSize, 3]
    return imgTensor.as3D(inputSize, inputSize, 3)
  }
})

2、cancas的格式化

js 复制代码
// 步骤1:处理每个输入,转换为统一格式的张量
const inputTensors = range(this.batchSize, 0, 1).map(batchIdx => {
  const input = this.getInput(batchIdx)
  // 情况2:输入是 Canvas
  if (input instanceof env.getEnv().Canvas) {
    // imageToSquare() 将 Canvas 转换为正方形(填充)
    // tf.browser.fromPixels() 将 Canvas 转换为张量
    // 自动处理尺寸调整
    return tf.browser.fromPixels(imageToSquare(input, inputSize, isCenterInputs))
  }
})

tf.browser.fromPixels

用于创建指定图像的像素值的张量。

  • 参数:此函数接受两个参数,如下所示:
    • pixels:它是要从中构造张量的输入图像的像素。支持的图像类型均为4通道。不过这里我们是输入了一个canvas元素。
    • numchannels:它是输出张量的通道数。默认值为3,上限为4。
  • 返回值:此函数返回指定图像的已创建像素张量值。

3、将所有张量堆叠成批量张量

js 复制代码
// 1. 转换为 float 类型(toFloat())
// 2. 堆叠所有张量(stack())
// 3. 确保形状为 [batchSize, inputSize, inputSize, 3]
const batchTensor = tf.stack(inputTensors.map(t => t.toFloat()))
  .as4D(this.batchSize, inputSize, inputSize, 3)

tf.stack

将同一形状的张量堆叠起来,会提升张量的维度

  • 参数
    • tensors:同一形状和同一数据类型的张量对象列表
    • axis:指定在哪个位置创建新维度。默认是 0
  • 返回值

tf.stack

将张量的维度提升为 4 维

总结:NetInput - 神经网络输入封装类会将输入的张量或图像处理成方方正正的形状,并且处理为批量的四维张量,方便后续统一处理。

相关推荐
yongche_shi2 小时前
第九十九篇:Python在其他领域的应用:游戏开发、物联网、AIoT简介
开发语言·python·物联网·游戏开发·aiot
froginwe112 小时前
Node.js 回调函数
开发语言
期待のcode2 小时前
Java中的继承
java·开发语言
TAEHENGV2 小时前
关于应用模块 Cordova 与 OpenHarmony 混合开发实战
android·javascript·数据库
资深低代码开发平台专家2 小时前
MicroQuickJS:为极致资源而生的嵌入式JavaScript革命
开发语言·javascript·ecmascript
世转神风-2 小时前
qt-通信协议基础-固定长度-小端字节序补0x00指导
开发语言·qt
czlczl200209252 小时前
基于 Spring Boot 权限管理 RBAC 模型
前端·javascript·spring boot
期待のcode2 小时前
Java中的super关键字
java·开发语言
TM1Club2 小时前
Zoey的TM1聊天室|#3 合并报表提速:业财一体如何实现关联方对账自动化
大数据·开发语言·人工智能·经验分享·数据分析·自动化·数据库系统