面向PyTorch用户的JAX简易教程[2]: 如何训练一个神经网络

背景 上一篇文章我们学习了JAX的基本知识,主要是几个关键词: NumPy API、transformations、XLA。这一篇来点实际的,看下如何训练一个神经网络,我们先回忆下用PyTorch训练神经网络,都需要哪几步: 实现网络模型 实现数据读取流程 使用优化器/调度器更新模型参数/学习率 实现模型训练和验证流程 下面我们就以在MNIST数据集上训练一个MLP为例,看下在JAX中如何实现上面的流程。 NumPy API实现网络模型 MNIST是一个10分类问题,每张图片大小是 28 * 28=784 ,我们设计一个简单的MLP网络, 一个四层MLP (包含输入层) import jax from jax import numpy as jnp from jax import grad, jit, vmap # 创建 PRNGKey (PRNG State) key = jax.random.PRNGKey(0) ## 创建模型参数, 去除输入层,实际上三层Linear,每层都包含一组(w, b),共三组参数 def random_layer_params(m, n, key, scale=1e-2): """ A helper function to randomly initialize weights and biases for a dense neural network layer """ w_key, b_key = jax....

July 23, 2022 · 9 min

面向PyTorch用户的JAX简易教程[1]: JAX介绍

背景 前几天申请参加了Google TRC项目,TPU VM的配置相当可以,但是PyTorch/XLA做数据并行时的体验却并不那么丝滑,考虑到Google一直力推TPU+JAX的组合,所以决定学习下JAX。 JAX简介 什么是JAX? 官方在GitHub README中是这么介绍的: JAX is Autograd and XLA, brought together for high-performance machine learning research. 在Description中写的是: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more. 在JAX官方文档又是这么介绍的: JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research. 总结一下,有几个关键词:Autograd、XLA、NumPy和composable transformations。 XLA 先来说XLA,这个我了解的最少,所以介绍起来最简单,XLA (Accelerated Linear Algebra)是Google为TensorFlow设计的一款编译器,主打JIT (Just-in-Time)编译和跨设备(CPU/GPU/TPU)执行,所以JAX介绍中凡是涉及到JIT、high-performance、CPU/GPU/TPU,都指的是XLA。 NumPy NumPy就不用提了,Python生态下只要涉及到数据分析/机器学习/数值计算中对数组/tensor进行操作,都离不开它,不夸张的说,NumPy API已经成为了数组/tensor操作的半个工业标准,包括各家深度学习框架中对tensor操作的函数接口也都是尽量靠近NumPy,JAX则更夸张,jax.numpy重新实现一套了NumPy API ,让用户从NumPy无缝切入JAX: from jax import numpy as jnp Autograd 这里的Autograd是哈佛大学HIPS实验室在14年开始开发的一款自动微分框架,特点是可以对Python/NumPy函数进行高阶求导,直接看个例子,一个简单的函数 f(x) ,顺便求一下一阶、二阶、三阶导函数:...

July 21, 2022 · 4 min

如何申请Google TRC项目,领取免费的Cloud TPU计算资源

背景 当下,深度学习要想搞的好,算力、数据和模型哪一样都少不了。后两者相对来说比较容易解决,各种模型开源实现和公开数据集还是挺充足的,但是算力≈金钱,有时候脑海中蹦出来一些新鲜idea,可是看着手里的老旧显卡,只能在夜晚暗自神伤 额,画风好像有点不对,但是大概就是这么个意思 一种方法是租显卡,比如某A、某B或者某C,如果是个人使用,不必像公司做项目那样顾虑太多,可以挑一些小公司的产品,有的还是挺便宜的。 另一种就是找免费的计算资源了 比如Google Colab或者Kaggle,以及今天要介绍的Google Research的TPU Research Cloud (TRC)项目。 TRC项目 TRC项目中的T,指的是Google自家的加速卡TPU,和GPU不同,Google并不公开出售TPU设备,而是集成在Google Cloud中,提供挂载了TPU的云计算服务。 TRC项目就是Google免费赠送给我们一段时间的TPU服务器,比如这是我申请成功后的结果,5台Cloud TPU v2-8和5台Cloud TPU v3-8,以及100台抢占式的Cloud TPU v2-8。免费使用时间是60天。 随便看一下其中一台的配置, CPU 96核,内存335GB,而且还挂载了TPU v2-8或者v3-8。我只能说, Cloud TPU TPU是Google推出的专用于机器学习的加速设备,可以类比NVIDIA的GPU,目前TPU已经更新到了第四代,也就是TPU v4,TRC项目提供的是前两代,TPU v2和TPU v3,对于大部分场景已经绰绰有余了。 简单说一下TPU,每块TPU上面有4块芯片(chip),每块芯片有两个核(core),所以这就是为什么叫做v2-8/v3-8,以v3为例,每个核有2个独立的矩阵计算单元(Matrix Multiply Unit, MXU)、1个向量处理单元(Vector Processing Unit, VPU)以及1个标量单元,每块芯片有32GB的高速存储(HBM)。 所以,可以把v3-8简单理解为就是8张GPU卡。 除此之外,还有算力恐怖的TPU Pod,也就是几百上千块TPU组成的算力单元,当然,这个并没有免费开放给所有人。 TRC申请 由于Cloud TPU是以Google Cloud中的一种服务提供出来,所以先要有一个Google Cloud账号,然后再申请TRC项目。 Google Cloud开通 Google Cloud申请流程,网上有很多资料,这里就不细讲了,必要条件是有一个gmail邮箱、一个手机号以及一张支持VISA或MasterCard的信用卡。 // 貌似这一步是最难的:( 注册成功后,Google Cloud会赠送300$,有效期90天。 TRC申请 如果你开通了Google Cloud,接下来就可以申请TRC项目了,申请流程非常简单,就是填写一个小问卷: 问题很简单,如实填写就行了,貌似就没有被拒的:)...

July 7, 2022 · 1 min