Background

In the previous article we covered the basics of JAX, essentially three keywords: the NumPy API, transformations, and XLA. This time let's get practical and train a neural network. First, recall the steps needed to train a network in PyTorch:

  1. Implement the network model
  2. Implement the data-loading pipeline
  3. Update model parameters / learning rates with an optimizer / scheduler
  4. Implement the training and evaluation loops

Below, we'll train an MLP on the MNIST dataset and see how each of these steps is done in JAX.

Implementing the Model with the NumPy API

MNIST is a 10-way classification problem and each image is 28 * 28 = 784 pixels, so let's design a simple MLP.

A four-layer MLP (input layer included):
import jax
from jax import numpy as jnp
from jax import grad, jit, vmap

# Create a PRNGKey (PRNG state)
key = jax.random.PRNGKey(0)


# Create the model parameters. Ignoring the input layer, there are really three Linear layers, each with a (w, b) pair, so three parameter groups in total

def random_layer_params(m, n, key, scale=1e-2):
    """
    A helper function to randomly initialize weights and biases
    for a dense neural network layer
    """
    w_key, b_key = jax.random.split(key)  # explicitly advance the PRNG state
    return scale * jax.random.normal(w_key, (n, m)), scale * jax.random.normal(b_key, (n,))


def init_network_params(sizes, key):
    """Initialize all layers for a fully-connected neural network with sizes "sizes"
    """
    keys = jax.random.split(key, len(sizes))  # split can produce several keys at once (the last one is simply unused here)
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]


layer_sizes = [784, 512, 512, 10]

key, init_key = jax.random.split(key)  # init_key used for initialization
params = init_network_params(layer_sizes, init_key)

print(len(params), len(params[0]), len(params[1]), len(params[2]))
# 3, 2, 2, 2

print(params[0][0].shape, params[0][1].shape)
# (512, 784), (512,)


# Build the network, which really just means writing the forward pass
def relu(x):
    return jnp.maximum(0, x)

# Note: x below is a single image; we don't need to hand-write a batched version
def model_forward(params, x):
    # per-example predictions
    for w, b in params[:-1]:
        x = jnp.dot(w, x) + b
        x = relu(x)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, x) + final_b
    return logits


# The forward pass is done; let's give it a test
key, test_key = jax.random.split(key)
random_flattened_image = jax.random.normal(test_key, (784, ))
preds = model_forward(params, random_flattened_image)
print(preds.shape)
# (10,)

Real network inputs come in batches, so let's use vmap to obtain a batched model_forward:

# Create a random batch, shape=(32, 784)
random_batched_flattened_images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
model_forward(params, random_batched_flattened_images)  # error
# TypeError: Incompatible shapes for dot: got (512, 784) and (32, 784).


# A batch-capable model_forward: with vmap it's easy
batched_forward = vmap(model_forward, in_axes=(None, 0), out_axes=0)

batched_preds = batched_forward(params, random_batched_flattened_images)
print(batched_preds.shape)
# (32, 10)
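To make the in_axes=(None, 0) pattern concrete, here is a tiny self-contained sketch (the affine toy function is made up for illustration): None means "don't map over this argument, broadcast it", while 0 means "map over axis 0".

```python
import jax.numpy as jnp
from jax import vmap

def affine(w, x):
    # single-example computation: w @ x, with w (3, 4) and x (4,)
    return jnp.dot(w, x)

w = jnp.ones((3, 4))   # shared weights (not mapped)
xs = jnp.ones((8, 4))  # a batch of 8 inputs (mapped over axis 0)

batched_affine = vmap(affine, in_axes=(None, 0), out_axes=0)
print(batched_affine(w, xs).shape)  # (8, 3)
```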

Data Loading via PyTorch

Strictly speaking, JAX is not a framework designed for deep learning, and it ships no dataset-handling functions or classes. But with NumPy ndarrays as the bridge, we can connect PyTorch's Dataset/DataLoader to JAX DeviceArrays:

PyTorch preprocessing --> numpy.ndarray --> jax.numpy array

The trick is simple: pass a custom collate_fn when creating the DataLoader so that it returns numpy arrays instead of torch Tensors.

One more caveat: as discussed in the previous article, JAX encourages explicit random-generator state (PRNG state), so instead of the DataLoader's built-in shuffle we define a custom Sampler:

import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch.utils.data import Sampler, SequentialSampler


class FlattenAndCast(object):
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))

# Make the DataLoader return numpy arrays instead of torch Tensors
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class JAXRandomSampler(Sampler):
    def __init__(self, data_source, rng_key):
        self.data_source = data_source
        self.rng_key = rng_key
        
    def __len__(self):
        return len(self.data_source)
    
    def __iter__(self):
        self.rng_key, current_rng = jax.random.split(self.rng_key)
        return iter(jax.random.permutation(current_rng, jnp.arange(len(self))).tolist())


class NumpyLoader(DataLoader):
    def __init__(self, dataset, rng_key=None, batch_size=1,
                 shuffle=False, **kwargs):
        if shuffle:
            sampler = JAXRandomSampler(dataset, rng_key)
        else:
            sampler = SequentialSampler(dataset)
        
        super().__init__(dataset, batch_size, sampler=sampler, **kwargs)


# Powered by torchvision and NumpyLoader
mnist_dataset_train = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
key, loader_key = jax.random.split(key)
train_loader = NumpyLoader(mnist_dataset_train, loader_key, batch_size=32, shuffle=True,
                           num_workers=0, collate_fn=numpy_collate, drop_last=True)

mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False, transform=FlattenAndCast())
eval_loader = NumpyLoader(mnist_dataset_test, batch_size=128, shuffle=False, num_workers=0,
                          collate_fn=numpy_collate, drop_last=False)
Note that num_workers is set to 0 here.
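The split-then-permute pattern inside JAXRandomSampler can be sanity-checked in isolation; a minimal jax-only sketch (no DataLoader involved):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# epoch 1: consume a fresh subkey, carry `key` forward to the next epoch
key, subkey = jax.random.split(key)
order1 = jax.random.permutation(subkey, jnp.arange(6))

# epoch 2: a new subkey gives an independent shuffle
key, subkey = jax.random.split(key)
order2 = jax.random.permutation(subkey, jnp.arange(6))

# the same subkey always reproduces the same order
assert (order2 == jax.random.permutation(subkey, jnp.arange(6))).all()
print(order1, order2)
```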

Updating Model Parameters with an Optimizer

Let's implement a simple SGD:

from jax.scipy.special import logsumexp


def loss(params, images, targets):
    logits = batched_forward(params, images)
    # log-softmax per example: normalize over the class axis, not across the whole batch
    preds = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(preds * targets)

@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(w - lr * dw, b - lr * db)
            for (w, b), (dw, db) in zip(params, grads)]

The sgd_update above works fine, but consider adding one more layer to the model, one that carries three parameters: \( W_{2} \cdot (W_{1}\cdot x) + b \). Then

\[ params = [(w_{1}, b_{1}), (w_{2}, b_{2}), (w_{3}, b_{3}), (w_{4}, w_{5}, b_{4})] \]

the final list comprehension in sgd_update gets awkward; it now needs an if/else, something like this:


@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(p[0] - lr * g[0], p[1] - lr * g[1]) if len(p) == 2 else
            (p[0] - lr * g[0], p[1] - lr * g[1], p[2] - lr * g[2])
            for p, g in zip(params, grads)]

If the network gets any more complex, say a Transformer with a dozen-odd layers of all shapes and sizes, imagine how many if/else branches that would take.

PyTree

JAX has a data structure called "PyTree", together with the built-in jax.tree_util.tree_* module, which is full of functions operating on pytree structures and makes parameter management elegant.

Note that PyTree does not refer to one particular data type; it is a concept, an umbrella term for a family of types. Python's list, tuple, and dict are all pytrees, which makes sense: a pytree represents a tree structure, and a linear sequence is just a simple tree. Plain values (int, float, and the like), ndarrays, strings, and (unregistered) user-defined classes are not pytrees; they are called leaves.

A pytree may nest other pytrees as well as leaves: a list can hold a few floats, or even other lists, whereas a leaf is a lone number or array and cannot contain lists or tuples.

A few examples:


[1, "a", object()]  # this list is a pytree with 3 leaves: 1, "a", object()

(1, (2, 3), ())  # this tuple is a pytree with three leaves: 1, 2, 3

[1, {"k1": 2, "k2": (3, 4)}, 5]  # this list has 5 leaves: 1, 2, 3, 4, 5

JAX also lets you register your own classes as PyTrees; that deserves its own article later, because PyTree is the core data structure of JAX: the jaxpr machinery covered last time accepts only PyTrees as inputs and returns PyTrees as outputs.
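Just to give a flavor of registration without getting ahead of that article, here is a minimal sketch; the Point class and its flatten/unflatten helpers are made up for illustration:

```python
import jax
from jax.tree_util import register_pytree_node

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

def point_flatten(p):
    # return the children (the leaves) plus static auxiliary data (none here)
    return (p.x, p.y), None

def point_unflatten(aux, children):
    return Point(*children)

register_pytree_node(Point, point_flatten, point_unflatten)

p = Point(1.0, 2.0)
print(jax.tree_util.tree_leaves(p))  # the two coordinates are now leaves: [1.0, 2.0]

moved = jax.tree_util.tree_map(lambda v: v + 10, p)
print(moved.x, moved.y)  # 11.0 12.0
```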

Now let's look at the functions in the jax.tree_util.tree_* module, tree_map for instance, a map designed for pytree data:


ptree = (1, (2, 3), (), [(2, 3, 4), 5], {"key": 2})
jax.tree_util.tree_map(lambda x: x + 2, ptree)
# (3, (4, 5), (), [(4, 5, 6), 7], {'key': 4})

Oho, pretty handy, no?
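tree_map has siblings that are just as useful: tree_leaves collects all leaves, and tree_flatten / tree_unflatten round-trip a pytree through a flat list:

```python
import jax

ptree = (1, (2, 3), (), [(2, 3, 4), 5], {"key": 2})

# collect every leaf in depth-first order
leaves = jax.tree_util.tree_leaves(ptree)
print(leaves)  # [1, 2, 3, 2, 3, 4, 5, 2]

# flatten into (leaves, treedef), transform, then rebuild the same structure
flat, treedef = jax.tree_util.tree_flatten(ptree)
rebuilt = jax.tree_util.tree_unflatten(treedef, [x * 10 for x in flat])
print(rebuilt)  # (10, (20, 30), (), [(20, 30, 40), 50], {'key': 20})
```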

Now let's upgrade sgd_update:

@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

Please put "elegant" on the big screen, thank Q.

Training and Evaluation Loops

OK, next let's string the code together into a training loop and an evaluation loop.


def one_hot(x, k=10, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, loader):
    total_acc = 0
    total_num = 0
    for x, y in loader:
        predicted_class = jnp.argmax(batched_forward(params, x), axis=1)
        total_num += len(x)
        total_acc += jnp.sum(predicted_class == y)
    return total_acc / total_num


lr = 0.01
n_classes = 10
for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        y = one_hot(y, n_classes)
        params = sgd_update(params, x, y, lr)
        lr = lr * 0.999 if lr > 1e-3 else 1e-3  # very simple lr scheduler
        if idx % 100 == 0:  # evaluation
            train_acc = accuracy(params, train_loader)
            eval_acc = accuracy(params, eval_loader)
            print("Epoch {} - batch_idx {}, Training set acc {}, eval set accuracy {}".format(
                  epoch, idx, train_acc, eval_acc))
# Training logs
Epoch 0 - batch_idx 0, Training set acc 0.09814999997615814, eval set accuracy 0.09950000047683716
Epoch 0 - batch_idx 100, Training set acc 0.8302666544914246, eval set accuracy 0.8362999558448792
Epoch 0 - batch_idx 200, Training set acc 0.8892666697502136, eval set accuracy 0.8940999507904053
Epoch 0 - batch_idx 300, Training set acc 0.8997166752815247, eval set accuracy 0.9006999731063843
Epoch 0 - batch_idx 400, Training set acc 0.9085167050361633, eval set accuracy 0.9128999710083008
Epoch 0 - batch_idx 500, Training set acc 0.9076499938964844, eval set accuracy 0.911300003528595
Epoch 0 - batch_idx 600, Training set acc 0.9230999946594238, eval set accuracy 0.9253000020980835
Epoch 0 - batch_idx 700, Training set acc 0.9269000291824341, eval set accuracy 0.9298999905586243
Epoch 0 - batch_idx 800, Training set acc 0.9295666813850403, eval set accuracy 0.9334999918937683
Epoch 0 - batch_idx 900, Training set acc 0.9290666580200195, eval set accuracy 0.9296999573707581
Epoch 0 - batch_idx 1000, Training set acc 0.9342833161354065, eval set accuracy 0.9357999563217163

# After about 3 epochs, accuracy reaches roughly 95%

That is the whole training pipeline implemented with raw JAX NumPy APIs. It feels like a trip back to the last century.

A hallmark of PyTorch is its clean, elegant API design, doing the most work with the fewest classes, such as the crowd favorite nn.Module-centric model building. Can we imitate that?

JAX && Flax && Optax

The JAX NumPy API feels quite low-level next to torch.nn.Module, so a number of JAX-based deep-learning frameworks have sprung up (Flax, Haiku, Equinox, …). It is a bit reminiscent of the high-level-API wars of the TensorFlow 1 era, though not as dramatic; by now most people have settled on the "JAX + Flax + Optax" trio:

  • jax.numpy provides the array operations, similar to torch.*
  • from flax import linen as nn, the counterpart of torch.nn.*
  • optax.* lines up with torch.optim.* and also ships various loss functions
  • the jit, grad, vmap, pmap transformations serve as the glue

Let's now refactor the training pipeline from before.

First, build the model with nn.Module:

  1. A Module is a dataclass: hyperparameters (e.g., the sizes of Dense layers) become fields, and don't forget the type annotations.
  2. Because a Module is a dataclass, we cannot write our own __init__; instead we create the layers (i.e., sub-Modules) the model needs in the setup method, the equivalent of torch.nn.Module.__init__.
  3. The __call__ method implements the forward computation, the equivalent of torch.nn.Module.forward.

Also note that a Flax Module does not hold its own parameters; they must be created explicitly via the init method. init is essentially a single call to __call__ with one extra PRNGKey argument: if __call__ takes (self, *args, **kwargs), then init takes (rngkey, *args, **kwargs).

init has to create the model parameters, i.e., perform parameter initialization, which as we know requires random number generation (something like some_params = jax.random.normal(key, (2, 2))), hence the PRNGKey.

Once the model and its parameters exist, calling the model also works differently: not model(x) but the apply method, model.apply(). By default apply invokes __call__, again with extra arguments for the parameters: if __call__ takes (self, *args, **kwargs), then apply takes (params, *args, rngs=<RNGS>, mutable=<MUTABLEKINDS>, **kwargs), where rngs serves layers that need randomness (e.g., dropout) and mutable serves layers that carry state variables (e.g., BatchNorm).

import jax
from jax import numpy as jnp
from jax import grad, jit, vmap
from flax import linen as nn
from typing import Sequence

# Create a PRNGKey (PRNG state)
key = jax.random.PRNGKey(0)


class MLP(nn.Module):
    layer_sizes: Sequence[int] = None  # type annotation: Sequence[int]
    
    def setup(self):
        # Dense() only sets the output size; Flax infers the input size
        self.layers = [nn.Dense(features=size) for size in self.layer_sizes[1:]]
    
    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            x = nn.relu(x)
        return self.layers[-1](x)

layer_sizes = [784, 512, 512, 10]


# Create the model
model = MLP(layer_sizes)

# Use `init` and a dummy input to create the model parameters.
# Note that Dense() was created without an input dimension:
# `init` essentially calls `__call__` once, using dummy_x to infer the parameter shapes
key, dummy_key = jax.random.split(key)
dummy_x = jax.random.uniform(dummy_key, (784, ))

key, init_key = jax.random.split(key)  # init_key used for initialization
params = model.init(init_key, dummy_x)

# Create a random batch, shape=(32, 784)
# Like a PyTorch Module, the model supports batched inputs out of the box, so no manual vmap needed
random_batched_flattened_images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
model.apply(params, random_batched_flattened_images).shape
# (32, 10)
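By the way, the params returned by init is itself just a pytree (a nested dict), so the tree_util machinery from earlier applies directly. A sketch with a hand-built dict shaped like what Flax produces for this MLP (the exact key names here are illustrative):

```python
import jax
import jax.numpy as jnp

# a toy params pytree mimicking the structure Flax produces
params = {"params": {
    "layers_0": {"kernel": jnp.zeros((784, 512)), "bias": jnp.zeros((512,))},
    "layers_1": {"kernel": jnp.zeros((512, 512)), "bias": jnp.zeros((512,))},
    "layers_2": {"kernel": jnp.zeros((512, 10)),  "bias": jnp.zeros((10,))},
}}

# map every leaf to its shape for a quick structural summary
shapes = jax.tree_util.tree_map(lambda p: p.shape, params)
print(shapes["params"]["layers_0"]["kernel"])  # (784, 512)

# total parameter count
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(n_params)  # 669706
```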

Next we use Optax to create the optimizer and the learning-rate schedule:


import optax

lr = 1e-3

# learning-rate schedule
lr_decay_fn = optax.linear_schedule(
        init_value=lr,
        end_value=1e-5,
        transition_steps=200,
)

# go straight for Adam
optimizer = optax.adam(
            learning_rate=lr_decay_fn,
)

Never mind whether Adam is the right choice here; after the hard days of hand-rolling SGD, Optax feels like pure luxury!

The data-loading code carries over unchanged.

TrainState

What about checkpointing the model? In PyTorch we save the state_dict of the model, the optimizer, and the lr_scheduler. Flax groups all of this training-time state into a TrainState class, which makes checkpointing convenient.

from flax.training import train_state

state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

The forward pass now goes through state.apply_fn instead of model.apply, and the gradient update is simply state.apply_gradients(grads=grads). The training loop then simplifies to:

def train_step(state, x, y):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits = state.apply_fn(params, x)
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss

    grad_fn = jax.value_and_grad(loss_fn)  # value_and_grad returns the loss together with the grads
    loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

# donate_argnums enables buffer donation: the input state's buffers are reused for the output state
jit_train_step = jit(train_step, donate_argnums=(0,))  


@jax.jit
def apply_model(state, x):
    """Computes predictions for a single batch."""
    
    logits = state.apply_fn(state.params, x)
    return jnp.argmax(logits, -1)

def eval_model(state, loader):
    total_acc = 0.
    total_num = 0.
    for x, y in loader:
        y_pred = apply_model(state, x)
        total_num += len(x)
        total_acc += jnp.sum(y_pred == y)
    return total_acc / total_num

for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        state, loss = jit_train_step(state, x, y)
        if idx % 20 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
              epoch, idx, loss, train_acc, eval_acc))

Dropout and BatchNorm

As mentioned earlier, every source of randomness in JAX needs an explicit PRNGKey, so how do we handle a network containing Dropout? And unlike a PyTorch Module, a Flax Module stores no model weights, every apply call receives the parameters, so how do we handle the running statistics of BatchNorm?

First, let's extend the MLP above with Dropout and BatchNorm:

class MLP(nn.Module):
    
    def setup(self):
        self.layer1 = nn.Dense(features=512)
        self.dropout1 = nn.Dropout(rate=0.3)
        self.norm1 = nn.BatchNorm()
        
        self.layer2 = nn.Dense(features=512)
        self.dropout2 = nn.Dropout(rate=0.4)
        self.norm2 = nn.BatchNorm()
        
        self.layer3 = nn.Dense(features=10)
        
    def __call__(self, x, train: bool = True):
        """`train` switches between train mode and eval mode"""

        x = nn.relu(self.layer1(x))
        x = self.dropout1(x, deterministic=not train)
        x = self.norm1(x, use_running_average=not train)
        x = nn.relu(self.layer2(x))
        x = self.dropout2(x, deterministic=not train)
        x = self.norm2(x, use_running_average=not train)
        
        x = self.layer3(x)

        return x

init

Now parameter initialization needs extra care: Dropout is a no-op during evaluation and inference, but during training every forward pass draws random numbers, so init must be given a dedicated PRNGKey under the name "dropout".

# Create the model
model = MLP()

# Use `init` and a dummy input to create the model parameters
key, dummy_key = jax.random.split(key)
dummy_x = jax.random.uniform(dummy_key, (784, ))

key, init_key, drop_key = jax.random.split(key, 3)  # split into 3 keys at once
# the name "dropout" is fixed by Flax
variables = model.init({"params": init_key, "dropout": drop_key}, dummy_x, train=True)

Looking at the resulting variables, there is a new "batch_stats" collection: the moving_mean and moving_var statistics of the BatchNorm layers:

variables.keys()
# frozen_dict_keys(['params', 'batch_stats'])

variables['batch_stats'].keys()
# frozen_dict_keys(['norm1', 'norm2'])

variables['batch_stats']['norm1'].keys()
# frozen_dict_keys(['mean', 'var'])

apply

Calling apply for the forward pass also needs attention:

  1. pass a PRNGKey for Dropout via rngs
  2. mark "batch_stats" as mutable, since it must be updated during the forward pass

Besides y, apply now also returns the updated "batch_stats":

key, drop_key = jax.random.split(key)

y, non_trainable_params = model.apply(variables, dummy_x, train=True, rngs={"dropout": drop_key},
                                      mutable=['batch_stats']) 

non_trainable_params.keys()
# frozen_dict_keys(['batch_stats'])

Parameter Updates

Since variables now contains "batch_stats", we first subclass TrainState to carry batch_stats. And because batch_stats are not model weights, they must not go through the optimizer's parameter update, so the training step changes as well:

import flax
from typing import Any
from flax.training import train_state


class CustomTrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict[str, Any]


state = CustomTrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer,
    batch_stats=variables['batch_stats'],
)
def train_step(state, x, y, dropout_key):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits, new_model_state = state.apply_fn({"params": params, "batch_stats": state.batch_stats},
                                                 x, train=True, rngs={"dropout": dropout_key}, mutable=["batch_stats"])

        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, new_model_state

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)  # has_aux=True: loss_fn also returns the mutated collections
    (loss, new_model_state), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state["batch_stats"])
    return new_state, loss

jit_train_step = jit(train_step, donate_argnums=(0,))  # donate_argnums enables buffer donation: the input state's buffers are reused for the output


@jax.jit
def apply_model(state, x):
    """Computes predictions for a single batch (eval mode)."""
    
    logits = state.apply_fn({"params":state.params, "batch_stats": state.batch_stats},
                            x, train=False)  # train=False, i.e. eval mode
    return jnp.argmax(logits, -1)

for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        key, dropout_key = jax.random.split(key)
        state, loss = jit_train_step(state, x, y, dropout_key)
        if idx % 100 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
              epoch, idx, loss, train_acc, eval_acc))
# some logs
Epoch 0 - batch_idx 0, loss 2.559518337249756, Training set acc 0.3179420530796051, eval set accuracy 0.31070002913475037
Epoch 0 - batch_idx 100, loss 0.3981797695159912, Training set acc 0.9382011890411377, eval set accuracy 0.9367000460624695
Epoch 0 - batch_idx 200, loss 0.29799991846084595, Training set acc 0.9520065784454346, eval set accuracy 0.9492000341415405
Epoch 0 - batch_idx 300, loss 0.22030052542686462, Training set acc 0.9536759257316589, eval set accuracy 0.9513000249862671
Epoch 0 - batch_idx 400, loss 0.22531506419181824, Training set acc 0.9540432095527649, eval set accuracy 0.950700044631958
Epoch 1 - batch_idx 0, loss 0.2441655695438385, Training set acc 0.954594075679779, eval set accuracy 0.9508000612258911
Epoch 1 - batch_idx 100, loss 0.14692620933055878, Training set acc 0.9552618265151978, eval set accuracy 0.9508000612258911

Source Code

That wraps up this simple example of training a neural network with JAX + Flax + Optax. The code above is available on GitHub :)

jax-tutorials-for-pytorchers

References

[1] JAX reference documentation

[2] Flax documentation

[3] Optax documentation, https://optax.readthedocs.io/en/latest/

[4] PyTorch Dataloaders for JAX, https://colab.research.google.com/github/kk1694/blog/blob/master/_notebooks/2021-05-03-Pytorch_Dataloaders_for_Jax.ipynb

[5] Flax basics notebook, https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/flax_basics.ipynb

[6] Annotated MNIST in the Flax documentation

[7] The Dropout and BatchNorm examples draw on "Machine Learning with Flax - From Zero to Hero" and the HuggingFace BERT Flax implementation