JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学

JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解。

长话短说:

  • 使用 import jax.numpy 访问 NumPy 函数,使用 import jax.scipy 访问 SciPy 函数。
  • 通过使用 @jax.jit 进行装饰,可以加快即时编译速度。
  • 使用 jax.grad 求导。
  • 使用 jax.vmap 进行矢量化,并使用 jax.pmap 进行跨设备并行化。

函数式编程

JAX 遵循函数式编程哲学。这意味着您的函数必须是独立的或纯粹的:不允许有副作用。本质上,纯函数看起来像数学函数(图 1)。有输入进来,有东西出来,但与外界没有沟通。

例子#1

以下代码片段是一个非功能纯的示例。

复制代码
import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias
   return total

注意 impure_example 之外的偏差。在编译期间(见下文),偏差可能会被缓存,因此不再反映偏差的变化。

例子#2

这是一个pure的例子。

复制代码
def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation

在这里,pure_example 是独立的:所有参数都作为参数传递。

确定性采样器

在计算机中,不存在真正的随机性。相反,NumPy 和 TensorFlow 等库会跟踪伪随机数状态来生成"随机"样本。

函数式编程的直接后果是随机函数的工作方式不同。由于不再允许全局状态,因此每次采样随机数时都需要显式传入伪随机数生成器 (PRNG) 密钥

复制代码
import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)

此外,您有责任为任何后续调用推进"随机状态"。

复制代码
key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

..

jit

您可以通过即时编译 JAX 指令来加快代码速度。例如,要编译缩放指数线性单位 (SELU) 函数,请使用 jax.numpy 中的 NumPy 函数并将 jax.jit 装饰器添加到该函数,如下所示:

复制代码
from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

JAX 会跟踪您的指令并将其转换为 jaxpr。这使得加速线性代数 (XLA) 编译器能够为您的加速器生成非常高效的优化代码。

gard

JAX 最强大的功能之一是您可以轻松获取 gard。使用 jax.grad,您可以定义一个新函数,即符号导数。

复制代码
from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))

正如您在示例中看到的,您不仅限于一阶导数。您可以通过简单地按顺序链接 grad 函数 n 次来获取 n 阶导数。

vmap 和 pmap

矩阵乘法使所有批次尺寸正确需要非常细心。 JAX 的矢量化映射函数 vmap 通过对函数进行矢量化来减轻这种负担。基本上,每个按元素应用函数 f 的代码块都是由 vmap 替换的候选者。让我们看一个例子。

计算线性函数:

复制代码
def linear(x):
 return weights @ x

在一批示例 [x₁, x2,..] 中,我们可以天真地(没有 vmap)实现它,如下所示:

复制代码
def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])

相反,通过使用 vmap 对线性进行向量化,我们可以一次性计算整个批次:

复制代码
def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)

本文由mdnice多平台发布

相关推荐
九章云极AladdinEdu17 小时前
超参数自动化调优指南:Optuna vs. Ray Tune 对比评测
运维·人工智能·深度学习·ai·自动化·gpu算力
研梦非凡20 小时前
ICCV 2025|从粗到细:用于高效3D高斯溅射的可学习离散小波变换
人工智能·深度学习·学习·3d
通街市密人有1 天前
IDF: Iterative Dynamic Filtering Networks for Generalizable Image Denoising
人工智能·深度学习·计算机视觉
智数研析社1 天前
9120 部 TMDb 高分电影数据集 | 7 列全维度指标 (评分 / 热度 / 剧情)+API 权威源 | 电影趋势分析 / 推荐系统 / NLP 建模用
大数据·人工智能·python·深度学习·数据分析·数据集·数据清洗
七元权1 天前
论文阅读-Correlate and Excite
论文阅读·深度学习·注意力机制·双目深度估计
ViperL11 天前
[智能算法]可微的神经网络搜索算法-FBNet
人工智能·深度学习·神经网络
2202_756749691 天前
LLM大模型-大模型微调(常见微调方法、LoRA原理与实战、LLaMA-Factory工具部署与训练、模型量化QLoRA)
人工智能·深度学习·llama
人有一心1 天前
深度学习中显性特征组合的网络结构crossNet
人工智能·深度学习
猫天意1 天前
【目标检测】metrice_curve和loss_curve对比图可视化
人工智能·深度学习·目标检测·计算机视觉·cv
蒋星熠1 天前
如何在Anaconda中配置你的CUDA & Pytorch & cuNN环境(2025最新教程)
开发语言·人工智能·pytorch·python·深度学习·机器学习·ai