Google JAX — фреймворкмашинного обучения для преобразования числовых функций.[2][3][4] Представляет объединение измененной версии autograd (автоматическое получение градиентной функции через дифференцирование функции) и TensorFlow's XLA (Ускоренная линейная алгебра (Accelerated Linear Algebra)). Спроектирован таким образом, чтобы максимально соответствовать структуре и рабочему процессу NumPy для работы с различными существующими фреймворками, такими как TensorFlow и PyTorch.[5][6] Основными функциями JAX являются:[2]
Код представленный ниже демонстрирует функцию автоматического дифференцирования пакета grad.
# importsfromjaximportgradimportjax.numpyasjnp# define the logistic functiondeflogistic(x):returnjnp.exp(x)/(jnp.exp(x)+1)# obtain the gradient function of the logistic functiongrad_logistic=grad(logistic)# evaluate the gradient of the logistic function at x = 1 grad_log_out=grad_logistic(1.0)print(grad_log_out)
Код представленный ниже демонстрирует функцию оптимизации через слияние пакета jit.
# importsfromjaximportjitimportjax.numpyasjnp# define the cube functiondefcube(x):returnx*x*x# generate datax=jnp.ones((10000,10000))# create the jit version of the cube functionjit_cube=jit(cube)# apply the cube and jit_cube functions to the same data for speed comparisoncube(x)jit_cube(x)
Вычислительное время для jit_cube (строка 17) должно быть заметно короче, чем для cube (строка 16). Увеличение значения в строке 7, будет увеличивать разницу.
Код представленный ниже демонстрирует распараллеливание для умножения матриц пакета pmap.
# import pmap and random from JAX; import JAX NumPyfromjaximportpmap,randomimportjax.numpyasjnp# generate 2 random matrices of dimensions 5000 x 6000, one per devicerandom_keys=random.split(random.PRNGKey(0),2)matrices=pmap(lambdakey:random.normal(key,(5000,6000)))(random_keys)# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPUoutputs=pmap(lambdax:jnp.dot(x,x.T))(matrices)# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separatelymeans=pmap(jnp.mean)(outputs)print(means)
Последняя строка должна напечатать значенияː
[1.15665951.1805978]
Библиотеки, использующие Jax
Несколько библиотек Python используют Jax в качестве бэкенда, включая:
Flax — высокоуровневая библиотека для нейронных сетей изначально разработанная Google Brain.[7]
Haiku — объектно-ориентированная библиотека для нейронных сетей разработанная DeepMind.[8]
Equinox — библиотека, основанная на идеи представления параметризованных функций (включая нейронные сети) как PyTrees. Она была создана Патриком Кидгером.[9]
Optax — библиотека для градиентной обработки и оптимизации разработанная DeepMind.[10]