面向PyTorch用户的JAX简易教程[3]: 如何通过pmap轻松实现数据并行

背景 在上一篇文章中,我们学习了如何使用JAX+Flax+Optax训练神经网络。但是考虑到每块Cloud TPU上有8个core/device,而我们只用了一个device, 好在我们的模型规模没有夸张到一张卡放不下,很自然的想到使用数据并行 (data parallelism, DP) 的方式来训练模型。 数据并行 :假设有\( N \)张卡,每张卡都保存一个模型,每一次迭代(iteration/step)都将batch数据分割成\( N \)个等大小的micro-batch,每张卡根据拿到的micro-batch数据独立计算梯度,然后调用AllReduce计算梯度均值,每张卡再独立进行参数更新。 数据并行流程 注意,本文的数据并行仅限在单机多卡环境,以后如果有多机资源会进行update。 pmap+jax.lax.p* 在单机多卡上轻松实现数据并行 pmap JAX中的pmap (parallel map) 让数据并行的实现方式异常简单,先来看一个简单的pmap示例, import jax from jax import pmap, numpy as jnp key = jax.random.PRNGKey(0) # 定义一个函数,做向量点积 def f(x, y): return jnp.dot(x, y) # 创建两个向量x, y key, init_key1, init_key2 = jax.random.split(key, 3) x = jax.random.normal(init_key1, (10, )) y = jax.random.normal(init_key2, (10, )) x....

July 26, 2022 · 4 min

面向PyTorch用户的JAX简易教程[2]: 如何训练一个神经网络

背景 上一篇文章我们学习了JAX的基本知识,主要是几个关键词: NumPy API、transformations、XLA。这一篇来点实际的,看下如何训练一个神经网络,我们先回忆下用PyTorch训练神经网络,都需要哪几步: 实现网络模型 实现数据读取流程 使用优化器/调度器更新模型参数/学习率 实现模型训练和验证流程 下面我们就以在MNIST数据集上训练一个MLP为例,看下在JAX中如何实现上面的流程。 NumPy API实现网络模型 MNIST是一个10分类问题,每张图片大小是 28 * 28=784 ,我们设计一个简单的MLP网络, 一个四层MLP (包含输入层) import jax from jax import numpy as jnp from jax import grad, jit, vmap # 创建 PRNGKey (PRNG State) key = jax.random.PRNGKey(0) ## 创建模型参数, 去除输入层,实际上三层Linear,每层都包含一组(w, b),共三组参数 def random_layer_params(m, n, key, scale=1e-2): """ A helper function to randomly initialize weights and biases for a dense neural network layer """ w_key, b_key = jax....

July 23, 2022 · 9 min

面向PyTorch用户的JAX简易教程[1]: JAX介绍

背景 前几天申请参加了Google TRC项目,TPU VM的配置相当可以,但是PyTorch/XLA做数据并行时的体验却并不那么丝滑,考虑到Google一直力推TPU+JAX的组合,所以决定学习下JAX。 JAX简介 什么是JAX? 官方在GitHub README中是这么介绍的: JAX is Autograd and XLA, brought together for high-performance machine learning research. 在Description中写的是: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more. 在JAX官方文档又是这么介绍的: JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research. 总结一下,有几个关键词:Autograd、XLA、NumPy和composable transformations。 XLA 先来说XLA,这个我了解的最少,所以介绍起来最简单,XLA (Accelerated Linear Algebra)是Google为TensorFlow设计的一款编译器,主打JIT (Just-in-Time)编译和跨设备(CPU/GPU/TPU)执行,所以JAX介绍中凡是涉及到JIT、high-performance、CPU/GPU/TPU,都指的是XLA。 NumPy NumPy就不用提了,Python生态下只要涉及到数据分析/机器学习/数值计算中对数组/tensor进行操作,都离不开它,不夸张的说,NumPy API已经成为了数组/tensor操作的半个工业标准,包括各家深度学习框架中对tensor操作的函数接口也都是尽量靠近NumPy,JAX则更夸张,jax.numpy重新实现一套了NumPy API ,让用户从NumPy无缝切入JAX: from jax import numpy as jnp Autograd 这里的Autograd是哈佛大学HIPS实验室在14年开始开发的一款自动微分框架,特点是可以对Python/NumPy函数进行高阶求导,直接看个例子,一个简单的函数 f(x) ,顺便求一下一阶、二阶、三阶导函数:...

July 21, 2022 · 4 min