面向PyTorch用户的JAX简易教程[3]: 如何通过pmap轻松实现数据并行
背景 在上一篇文章中,我们学习了如何使用JAX+Flax+Optax训练神经网络。但是考虑到每块Cloud TPU上有8个core/device,而我们只用了一个device, 好在我们的模型规模没有夸张到一张卡放不下,很自然的想到使用数据并行 (data parallelism, DP) 的方式来训练模型。 数据并行 :假设有\( N \)张卡,每张卡都保存一个模型,每一次迭代(iteration/step)都将batch数据分割成\( N \)个等大小的micro-batch,每张卡根据拿到的micro-batch数据独立计算梯度,然后调用AllReduce计算梯度均值,每张卡再独立进行参数更新。 数据并行流程 注意,本文的数据并行仅限在单机多卡环境,以后如果有多机资源会进行update。 pmap+jax.lax.p* 在单机多卡上轻松实现数据并行 pmap JAX中的pmap (parallel map) 让数据并行的实现方式异常简单,先来看一个简单的pmap示例, import jax from jax import pmap, numpy as jnp key = jax.random.PRNGKey(0) # 定义一个函数,做向量点积 def f(x, y): return jnp.dot(x, y) # 创建两个向量x, y key, init_key1, init_key2 = jax.random.split(key, 3) x = jax.random.normal(init_key1, (10, )) y = jax.random.normal(init_key2, (10, )) x....