Background

In the previous article we learned how to train a neural network with JAX + Flax + Optax. However, each Cloud TPU has 8 cores/devices, and we only used a single one of them, leaving the rest idle.

Fortunately, our model is not so large that it cannot fit on a single device, so the natural choice is to train with data parallelism (DP).

Data parallelism: suppose there are \( N \) devices, each holding a full copy of the model. At every iteration (step), the batch is split into \( N \) equal-sized micro-batches; each device independently computes gradients on its micro-batch, an AllReduce then averages the gradients across devices, and each device applies the parameter update independently.

(Figure: the data-parallel workflow)

Note: the data parallelism in this post is limited to the single-host, multi-device setting; I will update it if I get access to multi-host resources in the future.
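To make the workflow above concrete, here is a minimal sketch of a single data-parallel step, simulated on one device with a plain Python loop standing in for the \( N \) cards (a toy linear model with a squared loss; this is purely illustrative and not the training code used later in the post):

import jax
import jax.numpy as jnp

# Toy model: linear regression with a squared loss (illustrative only).
def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

def data_parallel_step(w, xs, ys, lr=0.1):
    # xs, ys: the batch pre-split into N micro-batches, shape (N, micro_bs, ...)
    grads = [jax.grad(loss_fn)(w, x, y) for x, y in zip(xs, ys)]  # each "card" computes its own gradient
    mean_grad = sum(grads) / len(grads)                           # the AllReduce-mean step
    return w - lr * mean_grad                                     # every card applies the same update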

pmap + jax.lax.p*: easy data parallelism on a single host with multiple devices

pmap

JAX's pmap (parallel map) makes data parallelism remarkably simple to implement. Let's start with a small pmap example:

import jax
from jax import pmap, vmap, numpy as jnp

key = jax.random.PRNGKey(0)

# Define a function that computes a vector dot product
def f(x, y):
    return jnp.dot(x, y)

# Create two vectors x and y
key, init_key1, init_key2 = jax.random.split(key, 3)
x = jax.random.normal(init_key1, (10, ))
y = jax.random.normal(init_key2, (10, ))

x.shape, y.shape
# ((10,), (10,))

# Use pmap to get a parallel version of f that runs across devices
p_f = pmap(f)

# Note: the inputs to p_f are no longer plain vectors; they gain an extra
# leading dimension, giving shape (N, 10).
# N is determined by the hardware: N <= number of devices.
# For example, on a TPU v3-8, N can be at most 8.
key, init_key1, init_key2 = jax.random.split(key, 3)
xs = jax.random.normal(init_key1, (jax.local_device_count(), 10))
ys = jax.random.normal(init_key2, (jax.local_device_count(), 10))
xs.shape, ys.shape
# ((8, 10), (8, 10))

p_f(xs, ys)
# ShardedDeviceArray([-0.2600838 ,  4.726631  ,  3.7643652 ,  1.5107703 ,
#                     -0.64313316, -1.0984898 , -1.3667903 ,  6.053646  ], dtype=float32)

# Since xs and ys both have shape (8, 10), at execution time the local x and
# local y on each device have shape (10,).
# Let's verify:
jnp.dot(xs[0], ys[0])  # matches the first element of p_f(xs, ys)
# DeviceArray(-0.2600838, dtype=float32)

jnp.dot(xs[1], ys[1])  # matches the second element of p_f(xs, ys)
# DeviceArray(4.726631, dtype=float32)

The example above is simple, but it already shows what pmap does: it automatically partitions (shards) the function's inputs so that each device computes on its own shard. This alone is not yet data parallelism, because we still need an AllReduce over the per-device gradients. jax.lax provides the collective communication primitives for cross-device computation: pmean, psum, all_gather, all_to_all, and so on.

jax.lax.p* collective communication functions

Here is a simple example:

def mean_f(x, y):
    z = jnp.dot(x, y)
    return jax.lax.pmean(z, axis_name="batch")   # pmean: all-reduce mean over all devices

p_mean_f = pmap(mean_f, axis_name="batch")  # note the axis_name

key, init_key1, init_key2 = jax.random.split(key, 3)
xs = jax.random.normal(init_key1, (jax.local_device_count(), 10))
ys = jax.random.normal(init_key2, (jax.local_device_count(), 10))
p_mean_f(xs, ys)
# ShardedDeviceArray([-0.568686, -0.568686, -0.568686, -0.568686, -0.568686,
#                     -0.568686, -0.568686, -0.568686], dtype=float32)
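The other collectives mentioned above work the same way. A quick sketch, assuming 8 local devices: psum sums the per-device values, while all_gather hands every device the full collection.

n = jax.local_device_count()
vals = jnp.arange(n, dtype=jnp.float32)  # after pmap, each device holds one scalar

pmap(lambda v: jax.lax.psum(v, "d"), axis_name="d")(vals)
# every entry equals 0 + 1 + ... + 7 = 28

pmap(lambda v: jax.lax.all_gather(v, "d"), axis_name="d")(vals)
# shape (8, 8): each device now holds the full vector [0., 1., ..., 7.]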

Note that in the pmap-transformed functions above, the data is always mapped along axis 0; in other words, the inputs gain an extra leading dimension, e.g. x's shape goes from (10,) to (8, 10). In fact you can choose which axis gets mapped, and in more complex scenarios you may compose vmap with pmap, with each transformation mapping a different axis. The function body then needs to know which mapped axis a pmean/psum should reduce over, and that is exactly what axis_name is for: it is simply a label you give to a mapped axis so you can refer to it.

The axis_name argument

Here is an example:

xxx = jax.random.normal(init_key1, (4, 2, 3))
yyy = jax.random.normal(init_key2, (4, 2, 3))

def g(x, y):
    z = jnp.dot(x, y)
    return jax.lax.pmean(z, "i")  # average over the vmap axis, i.e. within each device

pmap(vmap(g, axis_name="i"), axis_name="j")(xxx, yyy)

"""
ShardedDeviceArray([[-1.8138683 , -1.8138683 ],
                    [ 4.119129  ,  4.119129  ],
                    [-0.32894212, -0.32894212],
                    [-1.3789278 , -1.3789278 ]], dtype=float32)

""" 

def h(x, y):
    z = jnp.dot(x, y)
    return jax.lax.pmean(z, "j")

pmap(vmap(h, axis_name="i"), axis_name="j")(xxx, yyy)  # average over the pmap axis, i.e. each column across devices

"""
ShardedDeviceArray([[ 0.9534667 , -0.65477127],
                    [ 0.9534667 , -0.65477127],
                    [ 0.9534667 , -0.65477127],
                    [ 0.9534667 , -0.65477127]], dtype=float32)

"""

So even when pmap is used on its own (without vmap), you still need to specify an axis_name whenever the function body uses one of the jax.lax.p* collectives.
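A classic illustration of why the name matters: psum over the pmapped axis can be used to count the participating devices, while calling the same collective without a matching axis_name fails with an unbound axis name error. A small sketch, again assuming 8 devices:

pmap(lambda _: jax.lax.psum(1, axis_name="i"), axis_name="i")(
    jnp.arange(jax.local_device_count()))
# ShardedDeviceArray([8, 8, 8, 8, 8, 8, 8, 8], dtype=int32)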

Data parallelism

With pmap + jax.lax.p* + axis_name in hand, we can now rewrite the train_step function:

def train_step(state, x, y, dropout_key):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits, new_state = state.apply_fn({"params": params, "batch_stats": state.batch_stats},
                                           x, train=True, rngs={"dropout": dropout_key}, mutable=["batch_stats"])
        
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, new_state

    grad_fn = value_and_grad(loss_fn, has_aux=True)  # value_and_grad returns the loss together with the gradients
    (loss, new_state), grads = grad_fn(state.params)
    grads = jax.lax.pmean(grads, "batch")  # average the gradients over all devices
    loss = jax.lax.pmean(loss, "batch")
    batch_stats = jax.lax.pmean(new_state["batch_stats"], "batch")
    new_state = state.apply_gradients(grads=grads, batch_stats=batch_stats)
    
    return new_state, loss

p_train_step = pmap(train_step, "batch", donate_argnums=(0,))  # set the axis_name


def apply_model(state, x):
    """Computes gradients and loss for a single batch."""
    
    logits = state.apply_fn({"params":state.params, "batch_stats": state.batch_stats},
                            x, train=False)
    return jnp.argmax(logits, -1)


def eval_model(state, loader):
    total_acc = 0.
    total_num = 0.
    for xs, ys in loader:
        xs = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
        ys = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), ys)
        y_pred = pmap(apply_model)(state, xs)  # no cross-device communication during evaluation, so no axis_name is needed
        total_num += ys.size
        total_acc += jnp.sum(y_pred == ys)
    return total_acc / total_num
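The tree_map reshape in eval_model adds a leading device dimension to xs and ys; the training loop below uses exactly the same pattern, so it could be factored into a small helper (a hypothetical shard function, not part of the original code; the loops here keep the inline form):

def shard(pytree):
    """Adds a leading device dimension to every array leaf: (B, ...) -> (n_dev, B // n_dev, ...)."""
    n = jax.local_device_count()
    return jax.tree_map(lambda x: x.reshape((n, -1) + x.shape[1:]), pytree)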

Adding the device dimension to the data

The next step is to give every input of the pmap-transformed functions an extra leading device dimension. For example, if x used to have shape (128, 784), with data parallelism we can grow the batch to (128 * 8, 784) and reshape it to (8, 128, 784). The TrainState and the RNG keys all get replicated:

# replicate states
devices = jax.local_devices()
state = jax.device_put_replicated(state, devices)  # or: state = flax.jax_utils.replicate(state)

for epoch in range(5):
    for idx, (xs, ys) in enumerate(train_loader):
        xs = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
        ys = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), ys)
        
        key, dropout_key = jax.random.split(key)
        dropout_key = jax.random.split(dropout_key, jax.local_device_count())  # one dropout key per device
        state, loss = p_train_step(state, xs, ys, dropout_key)
        
        if idx % 100 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
              epoch, idx, jax.tree_map(lambda x: x[0], loss), train_acc, eval_acc))
# some logs
Epoch 0 - batch_idx 0, loss 2.666304588317871, Training set acc 0.4732118844985962, eval set accuracy 0.4770999848842621
Epoch 1 - batch_idx 0, loss 0.20269665122032166, Training set acc 0.9528387784957886, eval set accuracy 0.9491999745368958
Epoch 2 - batch_idx 0, loss 0.1484561264514923, Training set acc 0.9685648083686829, eval set accuracy 0.964199960231781
Epoch 3 - batch_idx 0, loss 0.12777751684188843, Training set acc 0.9755185842514038, eval set accuracy 0.9678999781608582
Epoch 4 - batch_idx 0, loss 0.10267762094736099, Training set acc 0.9769666194915771, eval set accuracy 0.9679999947547913

jax.tree_util.tree_map(lambda x: x[0], loss) retrieves the loss value from the first device.
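The same index-0 trick works for the whole replicated state, for example before saving a checkpoint on the host. flax.jax_utils.unreplicate (the counterpart of the replicate call above) does exactly this; a usage sketch, not from the original post:

from flax import jax_utils

single_state = jax_utils.unreplicate(state)  # take the device-0 copy from every leaf
# equivalent to: jax.tree_util.tree_map(lambda x: x[0], state)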

Summary + source code

To summarize, the key points for data parallelism with pmap:

  1. Grow the batch size to the original batch_size * number of devices
  2. Use jax.lax.pmean to average the gradients over all devices
  3. Don't forget to set axis_name for pmap; any meaningful string will do
  4. When calling the pmap-transformed function, reshape x and y to (device_num, batch_size / device_num, *), replicate the state, and split the rngs into device_num keys

That may look like a lot of steps, but in practice it is very straightforward, and it is all boilerplate you can reuse as a template.

As usual, the full source code is on GitHub:

jax-tutorials-for-pytorchers

References

[1] pmap API docs: jax.pmap - JAX documentation

[2] Parallel Evaluation in JAX

[3] Collective communication functions in jax.lax: jax.lax package - JAX documentation

[4] NVIDIA's introduction to collective operations: Collective Operations

[5] Reference example: Distributed Inference with JAX