JAX is a machine learning framework for transforming numerical functions.[2][3][4] It is described as bringing together a modified version of autograd (which automatically obtains the gradient function of a function by differentiating it) and OpenXLA's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[5][6] The primary functions of JAX are:[2]
grad: automatic differentiation
jit: compilation
vmap: auto-vectorization
pmap: parallelization across devices
grad
The code below demonstrates the grad function's automatic differentiation.
# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
The final line should output:
0.19661194
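As a sanity check on the value above, the result of grad can be compared with the derivative worked out by hand. The sketch below assumes the standard closed form for the derivative of the logistic function, s(x)(1 - s(x)); the helper name logistic_derivative is introduced here for illustration and is not part of JAX.

```python
from jax import grad
import jax.numpy as jnp

def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

def logistic_derivative(x):
    # hand-derived closed form (hypothetical helper): s(x) * (1 - s(x))
    s = logistic(x)
    return s * (1 - s)

# derivative obtained by automatic differentiation
auto = grad(logistic)(1.0)

# derivative obtained from the closed form
analytic = logistic_derivative(1.0)

# both values should agree to within floating-point error
print(auto, analytic)
```

The two printed numbers should match to roughly single-precision accuracy, since JAX uses 32-bit floats by default.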
jit
The code below demonstrates the jit function's optimization through fusion.
# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
The computation time for jit_cube (the final call) should be noticeably shorter than that for cube (the call before it). Increasing the dimensions of x passed to jnp.ones will further widen the difference.
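Measuring this difference takes some care: JAX dispatches work asynchronously, and the first call to a jitted function also includes tracing and compilation. The following is a minimal timing sketch, not a rigorous benchmark. The helper time_call is hypothetical, and the array is smaller than the one above (an assumption made to keep the run short).

```python
import time

from jax import jit
import jax.numpy as jnp

def cube(x):
    return x * x * x

jit_cube = jit(cube)

# smaller than the 10000 x 10000 array above, to keep the run short
x = jnp.ones((2000, 2000))

# warm-up call: triggers tracing and compilation, excluded from the timing
jit_cube(x).block_until_ready()

def time_call(fn, arg):
    # hypothetical helper: wall-clock time of one call, in seconds
    start = time.perf_counter()
    fn(arg).block_until_ready()  # wait for the asynchronous computation to finish
    return time.perf_counter() - start

eager_seconds = time_call(cube, x)
jit_seconds = time_call(jit_cube, x)
print(f"eager: {eager_seconds:.4f} s, jit: {jit_seconds:.4f} s")
```

The actual timings depend on the hardware and backend, so no particular speed-up is implied here.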
vmap
The GIF on the right of this section illustrates the notion of vectorized addition.
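Vectorized addition of this kind can be expressed with JAX's vmap transformation, which maps a function written for single elements over a batch axis. The sketch below is illustrative only; the function add and the sample arrays are assumptions, not taken from the source.

```python
from jax import vmap
import jax.numpy as jnp

def add(a, b):
    # written as if a and b were single scalars
    return a + b

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([10.0, 20.0, 30.0])

# map add over the leading axis of both arguments
batched_add = vmap(add)
print(batched_add(xs, ys))  # elementwise sums: 11, 22, 33

# in_axes=(0, None) maps over xs while holding the second argument fixed
add_to_each = vmap(add, in_axes=(0, None))
print(add_to_each(xs, 100.0))  # 101, 102, 103
```

For plain addition, NumPy-style broadcasting already gives the same result. vmap becomes more useful when the per-element function is harder to write in batched form by hand.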
pmap
The code below demonstrates the pmap function's parallelization of matrix multiplication. It assumes that at least two devices are available, since it maps over two random keys.
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
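Because pmap requires the mapped axis to be no larger than the number of available devices, the example above fails on a machine with a single device. The following variant is a sketch that sizes the mapped axis with jax.local_device_count(), so it should also run on one CPU. The smaller matrix dimensions are an assumption made to keep the run short.

```python
import jax
from jax import pmap, random
import jax.numpy as jnp

# number of devices visible to this process (1 on a typical CPU-only machine)
n_devices = jax.local_device_count()

# one random key, and therefore one matrix, per device
keys = random.split(random.PRNGKey(0), n_devices)
matrices = pmap(lambda key: random.normal(key, (500, 600)))(keys)

# local matrix multiplication on each device
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# one mean per device
means = pmap(jnp.mean)(outputs)
print(means.shape)
```

With a single device this runs one copy of each computation, so it shows the programming model rather than any speed-up.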