面向PyTorch用户的JAX简易教程[2]: 如何训练一个神经网络

背景 上一篇文章我们学习了JAX的基本知识,主要是几个关键词: NumPy API、transformations、XLA。这一篇来点实际的,看下如何训练一个神经网络,我们先回忆下用PyTorch训练神经网络,都需要哪几步: 实现网络模型 实现数据读取流程 使用优化器/调度器更新模型参数/学习率 实现模型训练和验证流程 下面我们就以在MNIST数据集上训练一个MLP为例,看下在JAX中如何实现上面的流程。 NumPy API实现网络模型 MNIST是一个10分类问题,每张图片大小是 28 * 28=784 ,我们设计一个简单的MLP网络, 一个四层MLP (包含输入层) import jax from jax import numpy as jnp from jax import grad, jit, vmap # 创建 PRNGKey (PRNG State) key = jax.random.PRNGKey(0) ## 创建模型参数, 去除输入层,实际上三层Linear,每层都包含一组(w, b),共三组参数 def random_layer_params(m, n, key, scale=1e-2): """ A helper function to randomly initialize weights and biases for a dense neural network layer """ w_key, b_key = jax....

July 23, 2022 · 9 min