Jax(Random、Numpy)常用函数

目录

Jax

vmap

Array

reshape

Random

PRNGKey

uniform

normal

split

choice

Numpy

expand_dims

linspace

jax.numpy.linalg[pkg]

dot

matmul

arange

interp

tile

reshape


Jax

jit

jax.jit(fun , in_shardings=UnspecifiedValue , out_shardings=UnspecifiedValue , static_argnums=None , static_argnames=None , donate_argnums=None , donate_argnames=None , keep_unused=False , device=None , backend=None , inline=False , abstracted_axes=None )[source]

注:jax.jit 是 JAX 中的一个装饰器,用于将 Python 函数编译为高效的机器代码,以提高运行速度。JIT(Just-In-Time)编译可以加速函数的执行,尤其是在循环或需要多次调用。

python 复制代码
>>>jax.jit(lambda x,y : x + y)
<PjitFunction of <function <lambda> at 0x7ea7b402f130>>
>>>jax.jit(lambda x,y : x + y)(1,2) #process jitfunc -> lambda fun
Array(3, dtype=int32, weak_type=True)
>>>@jax.jit
   def fun(x,y):
        return x + y
>>>fun
<PjitFunction of <function fun at 0x7ea7b402f5b0>>
>>>fun(1,2)
Array(3, dtype=int32, weak_type=True)

vmap

jax.vmap (fun , in_axes=0 , out_axes=0 , axis_name=None , axis_size=None , spmd_axis_name=None )[source]

注:对函数进行向量化处理,通常用于批量处理数据,而不需要显式地编写循环,函数映射调用,区别于pmap,vmap单个设备(CPU或GPU)上处理批量数据,pmap在多个设备(GPU或TPU)上并行处理数据(分布式)

python 复制代码
>>>f_xy = lambda x,y : x + y
>>>x = jax.numpy.array([[1, 2], 
                        [3, 4]])  # shape (2, 2)
>>>y = jax.numpy.array([[5, 6], 
                        [7, 8]])  # shape (2, 2)

# in this x and y array, axis 0 is row , axis 1 is col, ref shape index
# in x and y, axis -1 is shape[-1] , axis -2 is shape[-2]

>>>jax.vmap(f_xy,in_axes=(0,0))(x,y)      # default out_axes = 0,row ouput
# x row + y row , need x row dim equal y row dim
Array([[ 6,  8],
       [10, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,0),out_axes=1)(x,y) #show output by col
Array([[ 6,  8],
       [10, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1))(x,y) 
# x row + y col , need x row's dim equal y col's dim
Array([[ 6,  9],
       [ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1),out_axes=1)(x,y) #show output by col 
Array([[ 6,  9],
       [ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(None,0))(x,y) #no vector x by row or col, x is block
# x block + y row vector, x shape (2,2) , y shape(2,2), need x row equal y row
# return shape(y_dim_2,x_dim_1,x_dim2)
Array([[[ 6,  8],
        [ 8, 10]],
       [[ 8, 10],
        [10, 12]]], dtype=int32)

ref:Learning about JAX :axes in vmap()

Array

reshape

abstract Array.reshape(*args , order='C' )[source]

注:Array对象的实例方法,引用jax.numpy.reshape函数

Random

PRNGKey

jax.random.PRNGKey (seed , * , impl=None )[source]#

注:创建一个 PRNG key,作为生成随机数的种子Seed

eg:

python 复制代码
>>>jax.random.PRNGKey(0)
Array([0, 0], dtype=uint32)

uniform

jax.random.uniform (key , shape=() , dtype=<class 'float'> , minval=0.0 , maxval=1.0 )[source]

注:在给定的形状(shape)和数据类型(dtype)下,从 [minval, maxval) 区间内采样均匀分布的随机值

python 复制代码
>>>k = jax.random.PRNGKey(0)
>>>jax.random.uniform(k,shape=(1,))
Array([0.41845703], dtype=float32)

normal

normal (key , shape=() , dtype=<class 'float'> )[source]

注:在给定的形状shape和浮点数据类型dtype下,采样标准正态分布的随机值

python 复制代码
>>>k = jax.random.PRNGKey(0)
>>>jax.random.normal(k,shape=(1,))
Array([-0.20584226], dtype=float32)

split

jax.random.split(key , num=2 )[source]

注:用于生成伪随机数生成器(PRNG)状态的函数。它允许你从一个现有的 PRNG 状态中生成多个新的状态,从而实现随机数的可重复性和并行性。

python 复制代码
>>>k = jax.random.PRNGKey(1)
>>>k1,k2 = jax.random.split(k)
>>>k1
Array([2441914641, 1384938218], dtype=uint32)
>>>k2
Array([3819641963, 2025898573], dtype=uint32)

choice

jax.random.choice(key , a , shape=() , replace=True , p=None , axis=0 )[source]

注:从给定数组a中按shape生成随机样本,区别于numpy.random.choice函数。default choice one elem。

python 复制代码
>>>k = jax.random.PRNGKey(0)
>>>a = jax.numpy.array([1,2,3,4,5,6,7,8,9,0])
>>>jax.random.choice(k,a,(10,)) # random no seq
Array([9, 6, 8, 7, 8, 4, 1, 2, 3, 3], dtype=int32)
>>>jax.random.choice(k,a,(2,5))
Array([[9, 6, 8, 7, 8],
       [4, 1, 2, 3, 3]], dtype=int32)

Numpy

expand_dims

expand_dims (a , axis )[source]

注:为数组a的维度axis增加1维度

python 复制代码
>>>arr = jax.numpy.array([1,2,3])
>>>arr.shape
(3,)
>>>jax.numpy.expand_dims(arr,axis=0)
Array([[1, 2, 3]], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=0).shape
(1, 3)
>>>jax.numpy.expand_dims(arr,axis=1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=1).shape
(3, 1)

linspace

linspace (start: ArrayLike , stop: ArrayLike , num: int = 50 , endpoint: bool = True , retstep: Literal[False] = False , dtype: DTypeLike | None = None , axis: int = 0 , * , device: xc.Device | Sharding | None = None ) → Array[source]

注:在给定区间[start,stop]内返回均匀间隔的数字

python 复制代码
>>>jax.numpy.linspace(0,1,5)
Array([0.  , 0.25, 0.5 , 0.75, 1.  ], dtype=float32)

jax.numpy.linalg[pkg]

jax.numpy.linalg 是 JAX 库中用于线性代数操作的模块,对应numpy.linalg库实现

cholesky

jax.numpy.linalg.cholesky(a , * , upper=False )[source]

注:计算一个正定矩阵A的 Cholesky 分解,得到满足A=L@L.T等式的下三角或上三角矩阵L,@为Python1.5定义的矩阵乘运算(jax.numpy.matmul),L.T为L转置矩阵

python 复制代码
>>> d = jax.numpy.array([[2. , 1.],
                         [1. , 2.]])
>>>jax.numpy.linalg.cholesky(d)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)

>>>L = jax.numpy.linalg.cholesky(d)
>>>L@L.T
Array([[1.9999999 , 0.99999994],
       [0.99999994, 2.        ]], dtype=float32)
eigvalsh

jax.numpy.linalg.eigvalsh(a , UPLO='L' )[source]

注:计算 Hermitian 对称矩阵的特征值。对于一个给定的方阵 A,其特征值 λ 和特征向量 v满足以下关系Av=λv。cholesky分解矩阵需满足特征值>0。

python 复制代码
>>>jax.numpy.linalg.eigvalsh(jax.numpy.array([[1,-1],
                                              [-1,1]]))
Array([0., 2.], dtype=float32)
cond

jax.numpy.linalg.cond(x , p=None )[source]

注:用于计算矩阵的条件数(condition number),这是衡量矩阵在数值计算中稳定性的重要指标。高条件数警示需要谨慎对待矩阵的计算,尤其是在求解线性方程或进行其他数值计算时,如cholesky分解。

python 复制代码
>>>jax.numpy.linalg.cond(jax.numpy.array([[1,2],
                                          [2,1]]))
Array(3., dtype=float32)

allclose

jax.numpy.allclose(a , b , rtol=1e-05 , atol=1e-08 , equal_nan=False )[source]

注:检查两个数组的元素是否在容差范围内近似相等,cholesky分解矩阵需满足对称性。

python 复制代码
>>>A=jax.numpy.array([[4, 2],
                      [2, 3]])
>>>jax.numpy.allclose(A,A.T)
Array(True, dtype=bool)
# A 为对称矩阵

dot

dot(a , b , * , precision=None , preferred_element_type=None )[source]

注:用于计算两个数组的点积(dot product),对于一维数组,它计算的是向量的内积;对于二维数组(矩阵),它计算的是矩阵乘积;对于更高维度的数组,它执行的是逐元素的点积,并在最后一个轴上进行求和

  • 对于一维数组(向量)numpy.dot(a, b) 计算的是向量 ab 的点积,结果是一个标量。
  • 对于二维数组(矩阵)numpy.dot(A, B) 计算的是矩阵 AB 的乘积,其中 A 的列数必须与 B 的行数相等。结果是一个新的矩阵。
  • 对于更高维度的数组numpy.dot() 可以进行更复杂的广播和求和运算,但通常用于计算张量积(tensor product)的某个维度上的和。
python 复制代码
>>>jax.numpy.dot(jax.numpy.array([1,2,3]),2)
Array([2, 4, 6], dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
Array(14, dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([[1,2,3],
                                  [4,5,6]]),
                  jax.numpy.array([1,2,3]))
Array([14, 32], dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([[1,2],
                                  [4,5]]),
                 jax.numpy.array([[1,2],
                                  [4,5]]))
Array([[ 9, 12],
       [24, 33]], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.dot(a,b).shape
(1, 3, 1, 4) #matmul ret (1,3,4)

matmul

matmul (a , b , * , precision=None , preferred_element_type=None )[source]#

注:于执行矩阵乘法,也称为 @ 运算符(在 Python 3.5+ 中引入),对于一维数组(向量),它计算的是内积(与 dot 相同);对于二维数组(矩阵),它计算的是矩阵乘积(与 dot 相同);对于更高维度的数组,它执行的是逐元素的矩阵乘法,并保留其他轴

  • 对于一维数组(向量)numpy.matmul(a, b) 通常不被定义为向量之间的运算,除非 a 是一个二维数组(表示多个向量)的单个行或列,并且 b 的形状与之兼容。
  • 对于二维数组(矩阵)numpy.matmul(A, B) 计算的是矩阵 AB 的乘积,其中 A 的列数必须与 B 的行数相等。这与 numpy.dot() 对于二维数组的行为相同。
  • 对于更高维度的数组numpy.matmul() 遵循爱因斯坦求和约定(Einstein summation convention)的特定规则,允许在不同维度的数组之间执行矩阵乘法。这包括批处理矩阵乘法,其中每个批次独立地进行乘法运算。
python 复制代码
>>>jax.numpy.matmul(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
Array(14, dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([[1,2,3],
                                     [4,5,6]]),
                     jax.numpy.array([1,2,3]))
Array([14, 32], dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([[1,2],
                                     [4,5]]),
                    jax.numpy.array([[1,2],
                                     [4,5]]))
Array([[ 9, 12],
       [24, 33]], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.matmul(a,b).shape
(1, 3, 4) #dot ret (1,3,1,4)

arange

jax.numpy.arange(start , stop=None , step=None , dtype=None , * , device=None )[source]

注:default step 为1,在区间[start,stop)生成步长为1的数组,类似range函数

python 复制代码
>>>jax.numpy.arange(0,10,1)
Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

interp

interp (x , xp , fp , left=None , right=None , period=None )[source]

注:在xp点列表中线性插值x,线性插值满足,xi和xi+1表示xp数组相邻两点,插值x位于两点区间之间,xp点对于y值为fp,线性插值为保持符合fp = fun(xp)两点区间斜率的增量

python 复制代码
>>>xp = jax.numpy.arange(0,10,1)
>>>fp = jax.numpy.array(range(0,10,1)) * 2
>>>x = jax.numpy.array([1,2,3])
>>>jax.numpy.interp(x,xp,fp)
Array([2., 4., 6.], dtype=float32)

tile

jax.numpy.tile(A , reps )[source]

注:将A数组按reps重复化生成新Array

python 复制代码
a = jax.numpy.array([1,2,3])
>>>jax.numpy.tile(a,2)
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>>jax.numpy.tile(a,(2,))
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>>jax.numpy.tile(a,(1,1))
Array([[1, 2, 3]], dtype=int32)
>>>jax.numpy.tile(a,(2,1)) # repeat axis 0 (row) by 2, repeat axis 1 (col) by 1
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)

reshape

jax.numpy.reshape(a , shape=None , order='C' , * , newshape=Deprecated , copy=None )[source]

注:从定义Array a的shape形状为shape元组(),支持-1,推断dim数值

python 复制代码
>>>a = jax.numpy.array([[1, 2, 3],
                        [4, 5, 6]])
>>>jax.numpy.reshape(a,6) # equal reshape(a,(6,))
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>>jax.numpy.reshape(a,-1) # equal reshape(a,6)  -1 is inferred to be 3
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>>jax.numpy.reshape(a,(-1,2)) # equal reshape(a,(3,2)) , -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
>>>jax.numpy.reshape(a,(1,-1)) # not (n,) inferred to 2 d
Array([[1, 2, 3, 4, 5, 6]], dtype=int32)

meshgrid

jax.numpy.meshgrid(*xi , copy=True , sparse=False , indexing='xy' )[source]

注:创建坐标矩阵,将一维坐标向量xi(自变量x、y)转换为对应的二维坐标向量或矩阵,适用于计算网格点上的函数值(因变量z),默认indexing='xy'输出笛卡尔坐标(row为vector),indexing='ij'输出矩阵坐标(col为vector)

python 复制代码
>>>x = jax.numpy.array([1,2,3])
>>>y = jax.numpy.array([4,5])
>>>jax.numpy.meshgrid(x,y) #default indexing='xy'
[Array([[1, 2, 3],
        [1, 2, 3]], dtype=int32),
 Array([[4, 4, 4],
        [5, 5, 5]], dtype=int32)]
>>>jax.numpy.meshgrid(x,y,indexing='ij')
[Array([[1, 1],
        [2, 2],
        [3, 3]], dtype=int32),
 Array([[4, 5],
        [4, 5],
        [4, 5]], dtype=int32)]
>>>xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')
>>>xv
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>>yv
Array([[4, 4, 4],
       [5, 5, 5]], dtype=int32)
>>>xv.ravel()
Array([1, 2, 3, 1, 2, 3], dtype=int32) 
>>>yv.ravel()
Array([4, 4, 4, 5, 5, 5], dtype=int32)

#Array.ravel return a view of array (no memory),  flatten return a copy of array

自变量x shape(3,) 自变量y shape(2,),对应平面6个点, 对应值因变量z shape为(6,) 6个数值

,二维坐标可视化代码:

python 复制代码
import jax
import matplotlib.pyplot as plt

x = jax.numpy.array([1,2,3])
y = jax.numpy.array([4,5])

xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')

z = xv + yv

plt.scatter(xv.flatten(), yv.flatten(), c=z, cmap='viridis') #use xv , yv also show similar graph
plt.colorbar(label='u')
plt.xlim(0, 4)
plt.ylim(3, 6)
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.title('Grid Units Visualization')
plt.show()

尝试将点变多:

python 复制代码
import jax
import matplotlib.pyplot as plt

x = jax.numpy.linspace(0,10,100)
y = jax.numpy.linspace(0,10,100)

xv,yv = jax.numpy.meshgrid(x,y,indexing='xy')

z = xv + yv

plt.scatter(xv.flatten(), yv.flatten(), c=z, cmap='viridis')
plt.colorbar(label='z')
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.title('Grid Units Visualization')
plt.show()

eye

jax.numpy.eye(N , M=None , k=0 , dtype=None , * , device=None )[source]

注:用于创建单位矩阵的函数。单位矩阵是一种方阵,其主对角线上的元素为 1,其余元素为 0。

python 复制代码
>>>jax.numpy.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
相关推荐
算家云43 分钟前
PhotoMaker部署文档
人工智能·aigc·conda·图像生成·comfyui·工作流·文本转图像
小猪包3332 小时前
ai论文写作软件哪个好?分享5款ai论文题目生成器
人工智能·深度学习·计算机视觉·ai写作
luthane2 小时前
python 实现algorithm topo卡恩拓扑算法
数据结构·python·算法
云翼时代科技3 小时前
【探索艺术新纪元:Midjourney中文版,让创意无界!】
人工智能
KGback3 小时前
【项目记录】大模型基于llama.cpp在Qemu-riscv64向量扩展指令下的部署
人工智能·llama·riscv
ZPC82103 小时前
Pytorch详解-Pytorch核心模块
人工智能·pytorch·python·深度学习·机器学习
马甲是掉不了一点的<.<3 小时前
论文精读:基于渐进式转移的无监督域自适应舰船检测
人工智能·目标检测·计算机视觉·领域迁移
灵雀云3 小时前
CNAI趋势下,打造一体化AI赋能平台
人工智能
985小水博一枚呀3 小时前
【深度学习基础模型】极限学习机(Extreme Learning Machines, ELM)详细理解并附实现代码。
人工智能·python·深度学习·极限学习机