面向PyTorch用户的JAX简易教程[3]: 如何通过pmap轻松实现数据并行

背景 在上一篇文章中,我们学习了如何使用JAX+Flax+Optax训练神经网络。但是考虑到每块Cloud TPU上有8个core/device,而我们只用了一个device, 好在我们的模型规模没有夸张到一张卡放不下,很自然的想到使用数据并行 (data parallelism, DP) 的方式来训练模型。 数据并行 :假设有\( N \)张卡,每张卡都保存一个模型,每一次迭代(iteration/step)都将batch数据分割成\( N \)个等大小的micro-batch,每张卡根据拿到的micro-batch数据独立计算梯度,然后调用AllReduce计算梯度均值,每张卡再独立进行参数更新。 数据并行流程 注意,本文的数据并行仅限在单机多卡环境,以后如果有多机资源会进行update。 pmap+jax.lax.p* 在单机多卡上轻松实现数据并行 pmap JAX中的pmap (parallel map) 让数据并行的实现方式异常简单,先来看一个简单的pmap示例, import jax from jax import pmap, numpy as jnp key = jax.random.PRNGKey(0) # 定义一个函数,做向量点积 def f(x, y): return jnp.dot(x, y) # 创建两个向量x, y key, init_key1, init_key2 = jax.random.split(key, 3) x = jax.random.normal(init_key1, (10, )) y = jax.random.normal(init_key2, (10, )) x....

July 26, 2022 · 4 min

面向PyTorch用户的JAX简易教程[2]: 如何训练一个神经网络

背景 上一篇文章我们学习了JAX的基本知识,主要是几个关键词: NumPy API、transformations、XLA。这一篇来点实际的,看下如何训练一个神经网络,我们先回忆下用PyTorch训练神经网络,都需要哪几步: 实现网络模型 实现数据读取流程 使用优化器/调度器更新模型参数/学习率 实现模型训练和验证流程 下面我们就以在MNIST数据集上训练一个MLP为例,看下在JAX中如何实现上面的流程。 NumPy API实现网络模型 MNIST是一个10分类问题,每张图片大小是 28 * 28=784 ,我们设计一个简单的MLP网络, 一个四层MLP (包含输入层) import jax from jax import numpy as jnp from jax import grad, jit, vmap # 创建 PRNGKey (PRNG State) key = jax.random.PRNGKey(0) ## 创建模型参数, 去除输入层,实际上三层Linear,每层都包含一组(w, b),共三组参数 def random_layer_params(m, n, key, scale=1e-2): """ A helper function to randomly initialize weights and biases for a dense neural network layer """ w_key, b_key = jax....

July 23, 2022 · 9 min

面向PyTorch用户的JAX简易教程[1]: JAX介绍

背景 前几天申请参加了Google TRC项目,TPU VM的配置相当可以,但是PyTorch/XLA做数据并行时的体验却并不那么丝滑,考虑到Google一直力推TPU+JAX的组合,所以决定学习下JAX。 JAX简介 什么是JAX? 官方在GitHub README中是这么介绍的: JAX is Autograd and XLA, brought together for high-performance machine learning research. 在Description中写的是: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more. 在JAX官方文档又是这么介绍的: JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research. 总结一下,有几个关键词:Autograd、XLA、NumPy和composable transformations。 XLA 先来说XLA,这个我了解的最少,所以介绍起来最简单,XLA (Accelerated Linear Algebra)是Google为TensorFlow设计的一款编译器,主打JIT (Just-in-Time)编译和跨设备(CPU/GPU/TPU)执行,所以JAX介绍中凡是涉及到JIT、high-performance、CPU/GPU/TPU,都指的是XLA。 NumPy NumPy就不用提了,Python生态下只要涉及到数据分析/机器学习/数值计算中对数组/tensor进行操作,都离不开它,不夸张的说,NumPy API已经成为了数组/tensor操作的半个工业标准,包括各家深度学习框架中对tensor操作的函数接口也都是尽量靠近NumPy,JAX则更夸张,jax.numpy重新实现一套了NumPy API ,让用户从NumPy无缝切入JAX: from jax import numpy as jnp Autograd 这里的Autograd是哈佛大学HIPS实验室在14年开始开发的一款自动微分框架,特点是可以对Python/NumPy函数进行高阶求导,直接看个例子,一个简单的函数 f(x) ,顺便求一下一阶、二阶、三阶导函数:...

July 21, 2022 · 4 min

Sentence-BERT: 如何通过对比学习得到更好的句子向量表示

背景 Sentence-BERT是对句子进行向量表示的一项经典工作,论文延伸出来的sentence-transformers 项目,在GitHub上已经收获了8.1k个star,今天重读下论文。 Introduction 句子的向量表示,也就是sentence embedding,是利用神经网络对句子进行编码,得到的固定长度向量,我们希望这个向量包含了句子的”语义信息“: 句子向量表示 句子向量可以应用于NLP领域的方方面面,我们暂时将目光聚焦到文本语义相似度检测 (semantic textual similarity, STS )任务上:给定两个句子,判断两个句子在语义层面的相似程度,相似程度可以是连续值([0, 1])也可以是离散值 (0-5)。 前BERT时代有不少出彩的工作,咱们就先略去不表了,直接看BERT是怎么做的,本身BERT模型的输入就包含两个序列,所以天然适合处理STS任务,将两个句子拼接: [CLS] sentence 1 [SEP] sentence 2 [SEP] 直接作为BERT的输入,然后取最后一层的[CLS]向量或者所有token向量的mean/max啥的,再接一个简单的MLP即可。剩下的就是找个数据集进行fine-tune吧。 我们将这种方式称为"Cross-Encoder",因为两个句子的token可以交互,有利于学习到句子对之间的相似性。 如果你的的任务也像STS这样,句子对(sentence pair)的关系已经固定了,只需要判断句子对的关系(比如相似程度),那么Cross-Encoder非常适合你,但是,如果你的任务是从\( N \)个句子中找出最相似的两个句子,或者找出和句子\( q \)最相似的句子,那么Cross-Encoder就面临一个计算量的问题。 \( N \) 个句子两两组合,有多少种情况? \( \frac{N\cdot (N-1)}{2} \) 如果\( N = 10\),结果是45 如果\( N = 100\),结果是4950 如果\( N = 1000\),结果是49995000 …… 实际的业务场景中,几十万上百万的句子都算少的,好家伙,这计算量着实有点难顶啊。 can you 顶得住? 还是让我们回到sentence embedding这个更泛化的问题上来,如果现在有一个NBERT模型,能够得到高质量的句子向量表示,那么面对STS任务,我们就可以先将"sentence 1"作为BERT输入,得到向量"vector 1",再将"sentence 2"作为BERT输入,得到向量"vector 2",然后计算两个向量之间的相似度。...

July 16, 2022 · 2 min

如何申请Google TRC项目,领取免费的Cloud TPU计算资源

背景 当下,深度学习要想搞的好,算力、数据和模型哪一样都少不了。后两者相对来说比较容易解决,各种模型开源实现和公开数据集还是挺充足的,但是算力≈金钱,有时候脑海中蹦出来一些新鲜idea,可是看着手里的老旧显卡,只能在夜晚暗自神伤 额,画风好像有点不对,但是大概就是这么个意思 一种方法是租显卡,比如某A、某B或者某C,如果是个人使用,不必像公司做项目那样顾虑太多,可以挑一些小公司的产品,有的还是挺便宜的。 另一种就是找免费的计算资源了 比如Google Colab或者Kaggle,以及今天要介绍的Google Research的TPU Research Cloud (TRC)项目。 TRC项目 TRC项目中的T,指的是Google自家的加速卡TPU,和GPU不同,Google并不公开出售TPU设备,而是集成在Google Cloud中,提供挂载了TPU的云计算服务。 TRC项目就是Google免费赠送给我们一段时间的TPU服务器,比如这是我申请成功后的结果,5台Cloud TPU v2-8和5台Cloud TPU v3-8,以及100台抢占式的Cloud TPU v2-8。免费使用时间是60天。 随便看一下其中一台的配置, CPU 96核,内存335GB,而且还挂载了TPU v2-8或者v3-8。我只能说, Cloud TPU TPU是Google推出的专用于机器学习的加速设备,可以类比NVIDIA的GPU,目前TPU已经更新到了第四代,也就是TPU v4,TRC项目提供的是前两代,TPU v2和TPU v3,对于大部分场景已经绰绰有余了。 简单说一下TPU,每块TPU上面有4块芯片(chip),每块芯片有两个核(core),所以这就是为什么叫做v2-8/v3-8,以v3为例,每个核有2个独立的矩阵计算单元(Matrix Multiply Unit, MXU)、1个向量处理单元(Vector Processing Unit, VPU)以及1个标量单元,每块芯片有32GB的高速存储(HBM)。 所以,可以把v3-8简单理解为就是8张GPU卡。 除此之外,还有算力恐怖的TPU Pod,也就是几百上千块TPU组成的算力单元,当然,这个并没有免费开放给所有人。 TRC申请 由于Cloud TPU是以Google Cloud中的一种服务提供出来,所以先要有一个Google Cloud账号,然后再申请TRC项目。 Google Cloud开通 Google Cloud申请流程,网上有很多资料,这里就不细讲了,必要条件是有一个gmail邮箱、一个手机号以及一张支持VISA或MasterCard的信用卡。 // 貌似这一步是最难的:( 注册成功后,Google Cloud会赠送300$,有效期90天。 TRC申请 如果你开通了Google Cloud,接下来就可以申请TRC项目了,申请流程非常简单,就是填写一个小问卷: 问题很简单,如实填写就行了,貌似就没有被拒的:)...

July 7, 2022 · 1 min