背景

前几天申请参加了Google TRC项目,TPU VM的配置相当可以,但是PyTorch/XLA做数据并行时的体验却并不那么丝滑,考虑到Google一直力推TPU+JAX的组合,所以决定学习下JAX。

JAX简介

什么是JAX?

官方在GitHub README中是这么介绍的: JAX is Autograd and XLA, brought together for high-performance machine learning research.

在Description中写的是: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more.

在JAX官方文档又是这么介绍的: JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

总结一下,有几个关键词:Autograd、XLA、NumPy和composable transformations。

XLA

先来说XLA,这个我了解的最少,所以介绍起来最简单,XLA (Accelerated Linear Algebra)是Google为TensorFlow设计的一款编译器,主打JIT (Just-in-Time)编译和跨设备(CPU/GPU/TPU)执行,所以JAX介绍中凡是涉及到JIT、high-performance、CPU/GPU/TPU,都指的是XLA。

NumPy

NumPy就不用提了,Python生态下只要涉及到数据分析/机器学习/数值计算中对数组/tensor进行操作,都离不开它,不夸张的说,NumPy API已经成为了数组/tensor操作的半个工业标准,包括各家深度学习框架中对tensor操作的函数接口也都是尽量靠近NumPy,JAX则更夸张,jax.numpy重新实现一套了NumPy API ,让用户从NumPy无缝切入JAX:

from jax import numpy as jnp

Autograd

这里的Autograd是哈佛大学HIPS实验室在14年开始开发的一款自动微分框架,特点是可以对Python/NumPy函数进行高阶求导,直接看个例子,一个简单的函数 f(x) ,顺便求一下一阶、二阶、三阶导函数:

\[f(x) = x^3 + 2 \cdot x \]

\[f’(x) = 3 \cdot x^2 + 2 \]

\[f’’(x) = 6 \cdot x \]

\[f’’’(x) = 6 \]

如果\(x = 2 \) ,甚至可以口算出 \(f’(2) = 14, f’’(2)=12, f’’’(2)=6 \) 。我们可以用autograd来实现求导:

from autograd import grad

def f(x):
    return x**3 + 2*x

grad_f = grad(f)  # 一阶导函数
grad_grad_f = grad(grad_f)  # 两次grad组合,就是二阶导函数
grad_grad_grad_f = grad(grad_grad_f)  # 三次grad组合,就是三阶导函数
print(grad_f(2.), grad_grad_f(2.), grad_grad_grad_f(2.))
# 14.0 12.0 6.0

自动微分框架除了可以应用于数值计算,它还是深度学习框架的核心,可惜的是,由于性能(纯Python,只有CPU版本)以及其他原因,autograd库并没有推广起来,但是它却实实在在启发到了后续的torch-autograd、Chainer以及PyTorch中的autograd模块:

注:Adam毕业后加入了JAX团队,PyTorch在1.10版本也推出了functorch (JAX-like composable function transforms for PyTorch), 他们都有光明的未来:)

估计是Matthew一直对autograd性能耿耿于怀,当他在Google内部听到XLA的分享后,便和同事产生了JAX的最初想法:

Autograd + XLA ===> JAX

前者负责微分功能,后者实现高性能。

注:可以将JAX中的算子(operation,操作)看做是对XLA算子的Python封装:jax.numpy中的操作/算子是对更底层的jax.lax的封装,而jax.lax中的算子是XLA算子的Python封装。

Composable (function) transformations (可组合的函数转换)

composable transformations是JAX的核心,真正体现了JAX的特性/差异/优势。 // 标题都改成一级标题了。

什么是transformation (function transformations, transforms)?其实就是高阶函数 (Higher-order function),高阶函数是至少满足下列一个条件的函数:

  • 接受一个或多个函数作为输入
  • 输出一个函数

Python中常见的高阶函数比如map

transformation的输入是Python函数,输出也是函数。JAX中经常用到的transformation主要有四个:

  • grad: reverse mode自动微分,用在深度学习中足够了
  • jit: JIT编译,调用XLA进行JIT编译,用于优化代码
  • vmap: vectorization/batching,将函数扩展为支持批处理
  • pmap: parallelization,轻松实现数据并行 (data parallelism),类似PyTorch的DistributedDataParallel

不知道看到这里,你是不是会很疑惑,JAX的核心就是这么几个高阶函数?能干啥?

我们来看下这四个transformation到底能干啥?

grad

grad在Autograd那里已经介绍过了,

from jax import numpy as jnp
from jax import grad

def f(x):
    return jnp.sum(x * x)  # 函数输出只能是标量

grad_f = grad(f)
grad_f(jnp.array([1, 2, 3.]))
# DeviceArray([2., 4., 6.], dtype=float32)

grad只是JAX自动微分机制中最基本的一个transform,实际上JAX支持前向(forward-mode)和后向(reverse-mode)自动微分以及二者的任意组合, 感兴趣的同学可以去查看jvp和vjp 的文档。考虑到常见的深度学习任务,grad戳戳有余, 其他transform这里就不介绍了,实际上是我没用过,压根没那个能力介绍。

grad不但好用,而且数学上更直观,如果我们不局限在深度学习领域,从优化 (optimization)的角度看,大多数机器学习模型的学习都可以表示为\(\hat{y} =f(x)\) 、\(\max_{{y}}\ p(y|x)\) 、\(\max_{y} \ \frac{p(x, y)}{p(x)}\) 的一种。

LR可以表示为\(f(x)\),神经网络也可以表示为\(f(x)\),损失函数是\(loss = g(f(x), y)\),如果用SGD算法来解决,需要计算参数的梯度,想一下高数课上我们是怎么做的,直接对损失函数求导函数\(grad(g)\) 然后代入\(x\) ,现在grad用的就是这种方式。并且这种方式在数学上可以自然的泛化到高阶导数优化求解问题上。

jit

jit 是用户显式的调用XLA对代码进行优化(包括算子融合、内存优化等),执行时间可能缩短很多,


import numpy as np
from jax import numpy as jnp
from jax import jit

def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jit(norm)
X = jnp.array(np.random.rand(10000, 100))

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
# 585 µs ± 85.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 216 µs ± 12.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# 好像提升不是很显著,再来看一个例子
from jax import random

key = random.PRNGKey(0)

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()
# 1.06 ms ± 26.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 187 µs ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# 哦豁,效果还不错

vmap

vmap可以自动让函数支持batching,看个例子,原始函数表示的是向量-向量乘法,使用vmap可以得到矩阵-向量乘法的函数:

from jax import numpy as jnp
from jax import vmap


def vec_vec_dot(x, y):
    """vector-vector dot, ([a], [a]) -> []
    """
    return jnp.dot(x, y)

x = jnp.array([1,1,2])
y = jnp.array([2,1,1,])
vec_vec_dot(x, y)
# DeviceArray(5, dtype=int32)

mat_vec = vmap(vec_vec_dot, in_axes=(0, None), out_axes=0)  # ([b,a], [a]) -> [b]      (b is the mapped axis)
xx = jnp.array([[1,1,2], [1,1,2]])
mat_vec(xx, y)
# DeviceArray([5, 5], dtype=int32)

解释下vmap中的in_axesout_axees两个参数,前者表示对输入参数中哪一个的哪一维度进行batch扩充,这里(0, None)表示对x的第0维扩充,由原来的[a] -> [b,a]。后者表示对返回结果的哪一维度进行扩充,这里表示由原来的[] - > [b]

pmap

pmap让并行编程变的非常丝滑,可以用于数据并行训练,注意pmap包含了jit操作,下面我就在TPU v3-8 VM演示下:

import jax
from jax import numpy as jnp
from jax import pmap

jax.device_count()  # 8个core
# 8

jax.devices()
"""
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
"""

x = jnp.arange(8)
y = jnp.arange(8)

vmap(jnp.add)(x, y)
# DeviceArray([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

pmap(jnp.add)(x, y)
# ShardedDeviceArray([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

看到上面vmap和pmap执行后的区别没,一个返回数据类型是DeviceArray,一个则是SharedDeviceArray,后者表示数据分散在多个device中。

组合

上面介绍的transformation不仅仅可以单兵作战,最重要的是可以任意组合,比如

pmap(vamp(some_func))
jit(grad(grad(vmap(some_func))))

纯函数约束

transformation很好用,但是只能作用于纯函数 (pure function)。

或者反过来理解,正因为函数都是纯函数,才可以实现composable transformations这样灵活强大的功能。

什么是纯函数?

  1. 只要函数的传参不变,函数返回结果就要相同
  2. 函数不会改变函数外的状态

我们直接来看反例吧,第一个反例:


x = 3.

def not_pure_function_case1(a):
    return x + a

print(not_pure_function_case1(1.))
# 4.

x = 5.
print(not_pure_function_case1(1.))
# 6.

我们使用相同的传参(1.)调用了两次,可是函数结果不同,所以违背了第一条原则。这是因为函数内部使用了全局变量,虽然仅仅是read value,但是只要全局变量的值改变,函数返回结果就变了。

Tip: 纯函数内部不要读或写函数外的变量。

第二个反例:

import numpy as np

np.random.seed(123)  # 设置随机数种子

def not_pure_function_case2(n):
    return np.random.randn(n)

not_pure_function_case2(5)
# array([-1.0856306 ,  0.99734545,  0.2829785 , -1.50629471, -0.57860025])

not_pure_function_case2(5)
# array([ 1.65143654, -2.42667924, -0.42891263,  1.26593626, -0.8667404 ])

随机数在机器学习中太常见了,你看,为了结果可复现,我们还设置了随机数种子,但是,这却不是一个纯函数。

在NumPy中,随机数生成器状态(RNG State)是一个全局变量,只要我们调用了随机数生成算法(比如上面的np.random.rand()),都会导致RNG State发生变化,这样,连续两次的随机数生成结果就不相同,又违背了纯函数第一条原则。

为此,jax.numpynumpy的第一个不同之处出现了,JAX没有隐含的全局RNG State,凡是涉及到随机数生成的地方,都需要用户显式的使用RNG State。

import jax

key = jax.random.PRNGKey(0)  # 显式的创建PRNGKey,可以表示RNG State

x = jax.random.normal(key, (1000000,))  # 传入key,进行随机数生成

key, subkey = jax.random.split(key)  # 更新RNG State
xx = jax.random.normal(subkey, shape=(1,))

key, subkey = jax.random.split(key)  # 更新RNG State
xxx = jax.random.normal(subkey, shape=(1,))

第三个反例:

xs = [1,2,3]

def not_pure_function_case3(xs):
    xs.append(1.)
    return xs

not_pure_function_case3(xs)
# [1, 2, 3, 1.0]

not_pure_function_case3(xs)
# [1, 2, 3, 1.0, 1.0]

函数内部修改了xs,违背了第二条原则。

第四个反例:

def not_pure_function_case4(x):
    print("oops, not pure")
    return x

这个反例是因为print属于IO操作,违背了第二条。

Note: 如果我们不小心写出了non-pure function,然后进行transformation怎么办?你肯定指望JAX抛出一个异常,可惜的是,JAX内部并没有检查函数是否pure的机制,对于non-pure,transformation的行为属于undefined,有点像C语言中的野指针,此时函数的执行结果不可预测。

jaxpr

稍微聊一下transformation背后的故事,JAX中定义了一种中间表示语言(jaxpr),每个transformation的执行都分两步:

  1. 先将原Python函数翻译为jaxpr,这个过程被称为"tracing"
  2. 再对jaxpr进行transform (转换),可以将每个transformation看作一个独立的jaxpr interpreter,对于JAX中每个原子操作 (primitive)都有相应的转换规则

jaxpr的优势是语法简单,相比于直接对Python函数transform,对jaxpr进行transform容易得多。

如何实现NN model

有了jax.numpy、grad、pmap、jit,现在就可以编写网络,实现训练过程了,但是想象下用NumPy实现一个ResNet,实现一个Transformer,能做,但是也太复杂了,

下一篇会介绍Flax,一个基于JAX的NN library,如何基于Flax+JAX来轻松实现网络训练流程。

参考资料

[1] JAX文档,JAX reference documentation

[2] JAX的两位creator Roy和Matt 对JAX项目的介绍,强烈推荐,Stanford MLSys Seminar Episode 6: Roy Frostig on JAX

[3] JAX团队的Skye Wanderman-Milne 对JAX项目的介绍,0:42:49 Marc van Zee (Google Brain): Introduction to Flax

[4] JHU的Sabrina J. Mielke对JAX的介绍,0:26:16 Sabrina J. Mielke (Johns Hopkins University & Hugging Face): From stateful code to purified JAX: how to build your neural net framework

[5] JHU的Sabrina J. Mielke一篇博客,上一个链接的文字版,From PyTorch to JAX: towards neural net frameworks that purify stateful code

[6] 前JAX团队Mat Kelcey对JAX的介绍,“High performance machine learning with JAX” - Mat Kelcey (PyConline AU 2021)

[7] Matt对JAX的介绍,内容有点雷同,JAX: accelerated machine learning research via composable function transformations in Python

[8] DeepMind团队介绍JAX在内部使用情况,NeurIPS 2020: JAX Ecosystem Meetup

[9] Autograd项目,https://github.com/HIPS/autograd

[10] XLA项目,https://www.tensorflow.org/xla

[11] 如果想了解下随机数生成,强烈推荐该领域大牛 Melissa E. O’Neill写的 Random Number Generation Basics