JAX
JAX,是用于变换数值函数的Python机器学习框架,它由Google开发并具有来自Nvidia的一些贡献[4][5][6]。它结合了修改版本的Autograd(自动通过函数的微分获得其梯度函数)[7],和OpenXLA的XLA(加速线性代数)[8]。它被设计为尽可能的遵从NumPy的结构和工作流程,并协同工作于各种现存的框架如TensorFlow和PyTorch[9][10]。 主要功能JAX的主要功能是[4]:
grad下面的代码演示 # 导入库
from jax import grad
import jax.numpy as jnp
# 定义logistic函数
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# 获得logistic函数的梯度函数
grad_logistic = grad(logistic)
# 求值logistic函数在x = 1处的梯度
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
最终的输出为: 0.19661194
jit下面的代码演示 # 导入库
from jax import jit
import jax.numpy as jnp
# 定义cube函数
def cube(x):
return x * x * x
# 生成数据
x = jnp.ones((10000, 10000))
# 创建cube函数的jit版本
jit_cube = jit(cube)
# 应用cube函数和jit_cube函数于相同数据来比较其速度
cube(x)
jit_cube(x)
可见 vmap下面的代码展示 # 导入库
from functools import partial
from jax import vmap
import jax.numpy as jnp
# 定义函数
def grads(self, inputs):
in_grad_partial = partial(self._net_grads, self._net_params)
grad_vmap = vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
pmap下面的代码展示 # 从JAX导入pmap和random;导入JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# 生成2个维度为5000 x 6000的随机数矩阵,每设备一个
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# 没有数据传输,并行的在每个CPU/GPU上进行局部矩阵乘法
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# 没有数据传输,并行的在每个CPU/GPU上分别求取这两个矩阵的均值
means = pmap(jnp.mean)(outputs)
print(means)
最终的输出为: [1.1566595 1.1805978]
使用JAX的库一些Python库使用JAX作为后端,这包括:
参见引用
外部链接
|
Index:
pl ar de en es fr it arz nl ja pt ceb sv uk vi war zh ru af ast az bg zh-min-nan bn be ca cs cy da et el eo eu fa gl ko hi hr id he ka la lv lt hu mk ms min no nn ce uz kk ro simple sk sl sr sh fi ta tt th tg azb tr ur zh-yue hy my ace als am an hyw ban bjn map-bms ba be-tarask bcl bpy bar bs br cv nv eml hif fo fy ga gd gu hak ha hsb io ig ilo ia ie os is jv kn ht ku ckb ky mrj lb lij li lmo mai mg ml zh-classical mr xmf mzn cdo mn nap new ne frr oc mhr or as pa pnb ps pms nds crh qu sa sah sco sq scn si sd szl su sw tl shn te bug vec vo wa wuu yi yo diq bat-smg zu lad kbd ang smn ab roa-rup frp arc gn av ay bh bi bo bxr cbk-zam co za dag ary se pdc dv dsb myv ext fur gv gag inh ki glk gan guw xal haw rw kbp pam csb kw km kv koi kg gom ks gcr lo lbe ltg lez nia ln jbo lg mt mi tw mwl mdf mnw nqo fj nah na nds-nl nrm nov om pi pag pap pfl pcd krc kaa ksh rm rue sm sat sc trv stq nso sn cu so srn kab roa-tara tet tpi to chr tum tk tyv udm ug vep fiu-vro vls wo xh zea ty ak bm ch ny ee ff got iu ik kl mad cr pih ami pwn pnt dz rmy rn sg st tn ss ti din chy ts kcg ve