Google JAX

Google JAX
Логотип программы Google JAX
Скриншот программы Google JAX
Тип Machine learning
Разработчик Google (компания)
Написана на Python, C++
Операционные системы Linux, macOS, Windows
Первый выпуск 12 декабря 2018
Аппаратные платформы Python, NumPy
Последняя версия
Тестовая версия v0.3.13 (16 мая 2022; 2 года назад (2022-05-16))
Репозиторий github.com/google/jax
Лицензия Apache 2.0
Сайт jax.readthedocs.io/en/la…

Google JAX — фреймворк машинного обучения для преобразования числовых функций.[2][3][4] Представляет объединение измененной версии autograd (автоматическое получение градиентной функции через дифференцирование функции) и TensorFlow's XLA (Ускоренная линейная алгебра (Accelerated Linear Algebra)). Спроектирован таким образом, чтобы максимально соответствовать структуре и рабочему процессу NumPy для работы с различными существующими фреймворками, такими как TensorFlow и PyTorch.[5][6] Основными функциями JAX являются:[2]

  1. grad: автоматическое дифференцирование
  2. jit: компиляция
  3. vmap: автоматическая векторизация
  4. pmap: SPMD программирование

grad

Код представленный ниже демонстрирует функцию автоматического дифференцирования пакета grad.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1 
grad_log_out = grad_logistic(1.0)   
print(grad_log_out)

Код должен напечатать:

0.19661194

jit

Код представленный ниже демонстрирует функцию оптимизации через слияние пакета jit.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

Вычислительное время для jit_cube (строка 17) должно быть заметно короче, чем для cube (строка 16). Увеличение значения в строке 7, будет увеличивать разницу.

vmap

Код представленный ниже демонстрирует функцию векторизации пакета vmap.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# define function
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

Изображение в правой части раздела иллюстрирует идея векторизованного сложения.

Иллюстрационное видео векторизованного сложения

pmap

Код представленный ниже демонстрирует распараллеливание для умножения матриц пакета pmap.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

Последняя строка должна напечатать значенияː

[1.1566595 1.1805978]

Библиотеки, использующие Jax

Несколько библиотек Python используют Jax в качестве бэкенда, включая:

  • Flax — высокоуровневая библиотека для нейронных сетей изначально разработанная Google Brain.[7]
  • Haiku — объектно-ориентированная библиотека для нейронных сетей разработанная DeepMind.[8]
  • Equinox — библиотека, основанная на идеи представления параметризованных функций (включая нейронные сети) как PyTrees. Она была создана Патриком Кидгером.[9]
  • Optax — библиотека для градиентной обработки и оптимизации разработанная DeepMind.[10]
  • RLax — библиотека для разработки агентов для обучения с подкреплением, разработанная DeepMind.[11]

См. также

Примечания

  1. https://github.com/google/jax/releases/tag/jax-v0.4.24
  2. 1 2 Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library, Google, Bibcode:2021ascl.soft11002B, Архивировано 18 июня 2022, Дата обращения: 18 июня 2022
  3. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1—3. Архивировано (PDF) 21 июня 2022.{{cite journal}}: Википедия:Обслуживание CS1 (дата и год) (ссылка)
  4. Using JAX to accelerate our research (англ.). www.deepmind.com. Дата обращения: 18 июня 2022. Архивировано 18 июня 2022 года.
  5. Lynley, Matthew Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta (амер. англ.). Business Insider. Дата обращения: 21 июня 2022. Архивировано 21 июня 2022 года.
  6. Why is Google's JAX so popular? (амер. англ.). Analytics India Magazine (25 апреля 2022). Дата обращения: 18 июня 2022. Архивировано 18 июня 2022 года.
  7. Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, Архивировано 3 сентября 2022, Дата обращения: 29 июля 2022
  8. Haiku: Sonnet for JAX, DeepMind, 2022-07-29, Архивировано 29 июля 2022, Дата обращения: 29 июля 2022
  9. Kidger, Patrick (2022-07-29), Equinox, Архивировано 19 сентября 2023, Дата обращения: 29 июля 2022
  10. Optax, DeepMind, 2022-07-28, Архивировано 7 июня 2023, Дата обращения: 29 июля 2022
  11. RLax, DeepMind, 2022-07-29, Архивировано 26 апреля 2023, Дата обращения: 29 июля 2022

Ссылки