Background
In the previous post we learned how to train a neural network with JAX + Flax + Optax. However, each Cloud TPU host has 8 cores/devices, and we only used one of them.
Fortunately, our model is not so large that it cannot fit on a single device, so data parallelism (DP) is the natural way to put the remaining devices to work.
Data parallelism: given \( N \) devices, each device holds a full copy of the model. At every iteration (step), the batch is split into \( N \) equally sized micro-batches; each device independently computes gradients on its micro-batch, an AllReduce then averages the gradients across devices, and each device applies the parameter update independently.
Note: the data parallelism in this post is limited to a single host with multiple devices; I will update it if I ever get multi-host resources.
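Before touching pmap, it helps to see why this recipe is mathematically sound. Below is a minimal sketch (a toy linear model with a hypothetical loss_fn, not the network from the previous post) showing that, for equally sized micro-batches, averaging the per-micro-batch gradients reproduces the full-batch gradient, which is exactly what the AllReduce-mean step computes.
import jax
from jax import numpy as jnp

def loss_fn(w, x, y):
    # mean squared error of a toy linear model
    return jnp.mean((x @ w - y) ** 2)

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
w = jax.random.normal(k1, (4,))
x = jax.random.normal(k2, (8, 4))  # a "full batch" of 8 samples
y = jax.random.normal(k3, (8,))

full_grad = jax.grad(loss_fn)(w, x, y)      # gradient on the full batch
g1 = jax.grad(loss_fn)(w, x[:4], y[:4])     # gradient on micro-batch 1
g2 = jax.grad(loss_fn)(w, x[4:], y[4:])     # gradient on micro-batch 2
print(jnp.allclose(full_grad, (g1 + g2) / 2, atol=1e-5))  # True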
pmap + jax.lax.p*: easy data parallelism on a single multi-device host
pmap
JAX's pmap (parallel map) makes data parallelism remarkably easy to implement. Let's start with a simple pmap example:
import jax
from jax import pmap, vmap, numpy as jnp

key = jax.random.PRNGKey(0)

# Define a function that computes a vector dot product
def f(x, y):
    return jnp.dot(x, y)

# Create two vectors x and y
key, init_key1, init_key2 = jax.random.split(key, 3)
x = jax.random.normal(init_key1, (10, ))
y = jax.random.normal(init_key2, (10, ))
x.shape, y.shape
# ((10,), (10,))

# pmap gives us a parallel version of f -- one that runs across devices
p_f = pmap(f)
# Note: the inputs x and y of p_f are no longer plain vectors; they gain a leading
# dimension and have shape (N, 10). N is bounded by the hardware: N <= number of devices.
# On a TPU v3-8, for example, N <= 8.
key, init_key1, init_key2 = jax.random.split(key, 3)
xs = jax.random.normal(init_key1, (jax.local_device_count(), 10))
ys = jax.random.normal(init_key2, (jax.local_device_count(), 10))
xs.shape, ys.shape
# ((8, 10), (8, 10))
p_f(xs, ys)
# ShardedDeviceArray([-0.2600838 ,  4.726631  ,  3.7643652 ,  1.5107703 ,
#                     -0.64313316, -1.0984898 , -1.3667903 ,  6.053646  ], dtype=float32)

# Since xs and ys both have shape (8, 10), each device sees a local xs / ys slice of
# shape (10,) during execution. Let's verify:
jnp.dot(xs[0], ys[0])  # matches the first value of p_f(xs, ys)
# DeviceArray(-0.2600838, dtype=float32)
jnp.dot(xs[1], ys[1])  # matches the second value of p_f(xs, ys)
# DeviceArray(4.726631, dtype=float32)
The example above is simple, but it already shows what pmap does: it automatically partitions (shards) the function's inputs, and each device computes on its own shard. That alone is not yet data parallelism, because we still need to AllReduce the gradients across devices. jax.lax provides the necessary collective communication functions for cross-device computation: pmean, psum, all_gather, all_to_all, and so on.
jax.lax.p* collective communication functions
Let's look at a simple example:
def mean_f(x, y):
    z = jnp.dot(x, y)
    return jax.lax.pmean(z, axis_name="batch")  # pmean: all-reduce mean across devices

p_mean_f = pmap(mean_f, axis_name="batch")  # note the axis_name

key, init_key1, init_key2 = jax.random.split(key, 3)
xs = jax.random.normal(init_key1, (jax.local_device_count(), 10))
ys = jax.random.normal(init_key2, (jax.local_device_count(), 10))
p_mean_f(xs, ys)
# ShardedDeviceArray([-0.568686, -0.568686, -0.568686, -0.568686, -0.568686,
#                     -0.568686, -0.568686, -0.568686], dtype=float32)
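pmean is the collective we need for data parallelism, but the others work the same way. Here is a small hedged sketch of psum and all_gather (assuming the same 8-device host as above; the function collect is just an illustration, not part of the training code):
def collect(x):
    total = jax.lax.psum(x, axis_name="devices")            # sum of x over all devices
    gathered = jax.lax.all_gather(x, axis_name="devices")   # every device receives all values
    return total, gathered

xs = jnp.arange(jax.local_device_count(), dtype=jnp.float32)  # one scalar per device: 0., 1., ..., 7.
totals, gathered = pmap(collect, axis_name="devices")(xs)
# totals: shape (8,), every entry equals 0 + 1 + ... + 7 = 28.0
# gathered: shape (8, 8), each device holds a copy of the full array [0., 1., ..., 7.]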
Note that the pmap-transformed functions above all map over axis 0; in other words, the inputs gained an extra leading dimension, e.g. x went from shape (10,) to (8, 10). You can in fact specify which axis gets mapped. In more complicated settings, where vmap and pmap are composed and each transformation maps over a different axis, the function body needs to know which mapped axis a pmean/psum should reduce over. That is what axis_name is for: it is simply a label you give to a mapped axis so you can refer to it.
The axis_name parameter
Here is an example:
xxx = jax.random.normal(init_key1, (4, 2, 3))
yyy = jax.random.normal(init_key2, (4, 2, 3))

def g(x, y):
    z = jnp.dot(x, y)
    return jax.lax.pmean(z, "i")  # average over the vmap axis, i.e. within each device

pmap(vmap(g, axis_name="i"), axis_name="j")(xxx, yyy)
"""
ShardedDeviceArray([[-1.8138683 , -1.8138683 ],
                    [ 4.119129  ,  4.119129  ],
                    [-0.32894212, -0.32894212],
                    [-1.3789278 , -1.3789278 ]], dtype=float32)
"""

def h(x, y):
    z = jnp.dot(x, y)
    return jax.lax.pmean(z, "j")  # average over the pmap axis, i.e. across devices, column-wise

pmap(vmap(h, axis_name="i"), axis_name="j")(xxx, yyy)
"""
ShardedDeviceArray([[ 0.9534667 , -0.65477127],
                    [ 0.9534667 , -0.65477127],
                    [ 0.9534667 , -0.65477127],
                    [ 0.9534667 , -0.65477127]], dtype=float32)
"""
So even when pmap is used on its own (without vmap), you still need to give it an axis_name whenever the function uses these collectives, so that they know which mapped axis to reduce over.
Data parallelism
With pmap + jax.lax.p* + axis_name covered, we can now rework the train_step function:
def train_step(state, x, y, dropout_key):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits, new_state = state.apply_fn({"params": params, "batch_stats": state.batch_stats},
                                           x, train=True, rngs={"dropout": dropout_key},
                                           mutable=["batch_stats"])
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, new_state
    grad_fn = value_and_grad(loss_fn, has_aux=True)  # `value_and_grad` returns the loss along with the grads
    (loss, new_state), grads = grad_fn(state.params)
    grads = jax.lax.pmean(grads, "batch")  # pmean: average the gradients across all devices
    loss = jax.lax.pmean(loss, "batch")
    batch_stats = jax.lax.pmean(new_state["batch_stats"], "batch")
    new_state = state.apply_gradients(grads=grads, batch_stats=batch_stats)
    return new_state, loss

p_train_step = pmap(train_step, "batch", donate_argnums=(0,))  # set the axis_name
def apply_model(state, x):
    """Computes predictions for a single batch."""
    logits = state.apply_fn({"params": state.params, "batch_stats": state.batch_stats},
                            x, train=False)
    return jnp.argmax(logits, -1)

def eval_model(state, loader):
    total_acc = 0.
    total_num = 0.
    for xs, ys in loader:
        xs = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
        ys = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), ys)
        y_pred = pmap(apply_model)(state, xs)  # no cross-device communication during eval, so no axis_name needed
        total_num += ys.size
        total_acc += jnp.sum(y_pred == ys)
    return total_acc / total_num
Expanding the data dimensions
The last step is to give every input of the pmap-transformed function an extra leading dimension. For example, x originally had shape (128, 784); with data parallelism the global batch grows to (128 * 8, 784) and is then reshaped to (8, 128, 784) (so the global batch size must be divisible by the number of devices). The TrainState and the rngs are all replicated:
# replicate the train state
devices = jax.local_devices()
state = jax.device_put_replicated(state, devices)  # or: state = flax.jax_utils.replicate(state)

for epoch in range(5):
    for idx, (xs, ys) in enumerate(train_loader):
        xs = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
        ys = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), ys)
        key, dropout_key = jax.random.split(key)
        dropout_key = jax.random.split(dropout_key, jax.local_device_count())  # replicate rngs, one per device
        state, loss = p_train_step(state, xs, ys, dropout_key)
        if idx % 100 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
                epoch, idx, jax.tree_map(lambda x: x[0], loss), train_acc, eval_acc))
# some logs
Epoch 0 - batch_idx 0, loss 2.666304588317871, Training set acc 0.4732118844985962, eval set accuracy 0.4770999848842621
Epoch 1 - batch_idx 0, loss 0.20269665122032166, Training set acc 0.9528387784957886, eval set accuracy 0.9491999745368958
Epoch 2 - batch_idx 0, loss 0.1484561264514923, Training set acc 0.9685648083686829, eval set accuracy 0.964199960231781
Epoch 3 - batch_idx 0, loss 0.12777751684188843, Training set acc 0.9755185842514038, eval set accuracy 0.9678999781608582
Epoch 4 - batch_idx 0, loss 0.10267762094736099, Training set acc 0.9769666194915771, eval set accuracy 0.9679999947547913
jax.tree_util.tree_map(lambda x: x[0], loss) extracts the loss value from the first device.
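Equivalently (assuming flax is available, which it is since flax.jax_utils.replicate was an option above), flax.jax_utils.unreplicate takes the first-device slice of every leaf in a replicated pytree; a small sketch for pulling values or the whole TrainState back onto a single device, e.g. before checkpointing:
from flax import jax_utils

loss_value = jax_utils.unreplicate(loss)     # same as jax.tree_map(lambda x: x[0], loss)
single_state = jax_utils.unreplicate(state)  # an un-replicated TrainState, e.g. for saving a checkpoint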
Summary + source code
To recap, the key points of data parallelism with pmap:
- Scale the batch size up to the original batch_size * number of devices.
- Use jax.lax.pmean to average the gradients across all devices.
- Don't forget to set axis_name for pmap; any meaningful string will do.
- When calling the pmap-transformed function: reshape x and y to (device_num, batch_size / device_num, *), replicate the state, and split the rngs into device_num copies.
It may look like a lot of steps, but each one is straightforward, and the whole thing is a template you can reuse as-is; a compact sketch follows.
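Here is a compact, hedged skeleton of the recipe above (the helper shard_batch is illustrative and not part of the original code; state, train_step, train_loader, and key come from the earlier sections):
n_dev = jax.local_device_count()

def shard_batch(batch):
    # (n_dev * per_device_batch, ...) -> (n_dev, per_device_batch, ...)
    return jax.tree_map(lambda t: t.reshape((n_dev, -1) + t.shape[1:]), batch)

state = jax.device_put_replicated(state, jax.local_devices())  # 1. replicate the train state
p_train_step = pmap(train_step, axis_name="batch")             # 2. pmap with the axis_name used by pmean

for xs, ys in train_loader:
    xs, ys = shard_batch(xs), shard_batch(ys)                  # 3. add the leading device axis
    key, dropout_key = jax.random.split(key)
    dropout_keys = jax.random.split(dropout_key, n_dev)        # 4. one dropout rng per device
    state, loss = p_train_step(state, xs, ys, dropout_keys)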
As usual, the full source code is on GitHub.