Background
In the previous post we covered the basics of JAX, centered on a few keywords: the NumPy API, transformations, and XLA. This time let's do something practical and train a neural network. First, recall the steps needed to train a network in PyTorch:
- implement the network model
- implement the data loading pipeline
- update model parameters / learning rate with an optimizer / scheduler
- implement the training and evaluation loops
Below we walk through each of these steps in JAX, using an MLP trained on MNIST as the example.
Implementing the model with the NumPy API
MNIST is a 10-class problem and each image has 28 * 28 = 784 pixels; we design a simple MLP:
```python
import jax
from jax import numpy as jnp
from jax import grad, jit, vmap

# create the PRNGKey (PRNG state)
key = jax.random.PRNGKey(0)

# Create the model parameters. Excluding the input layer there are three
# Linear layers; each owns a (w, b) pair, so three pairs in total.
def random_layer_params(m, n, key, scale=1e-2):
    """A helper function to randomly initialize weights and biases
    for a dense neural network layer."""
    w_key, b_key = jax.random.split(key)  # explicitly update the PRNG state
    return scale * jax.random.normal(w_key, (n, m)), scale * jax.random.normal(b_key, (n,))

def init_network_params(sizes, key):
    """Initialize all layers for a fully-connected neural network with sizes "sizes"."""
    keys = jax.random.split(key, len(sizes))  # split can create several keys at once
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
key, init_key = jax.random.split(key)  # init_key used for initialization
params = init_network_params(layer_sizes, init_key)
print(len(params), len(params[0]), len(params[1]), len(params[2]))
# 3, 2, 2, 2
print(params[0][0].shape, params[0][1].shape)
# (512, 784), (512,)

# Creating the network really just means writing the forward pass
def relu(x):
    return jnp.maximum(0, x)

# Note: x below is a single image; we do not need to hand-write a batched version
def model_forward(params, x):
    # per-example predictions
    for w, b in params[:-1]:
        x = jnp.dot(w, x) + b
        x = relu(x)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, x) + final_b
    return logits

# the forward pass is done; let's test it
key, test_key = jax.random.split(key)
random_flattened_image = jax.random.normal(test_key, (784,))
preds = model_forward(params, random_flattened_image)
print(preds.shape)
# (10,)
```
As we know, the network's input comes in batches, so let's use vmap to derive a batch-capable model_forward:
```python
# create a random batch, shape=(32, 784)
random_batched_flattened_images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
model_forward(params, random_batched_flattened_images)  # error
# TypeError: Incompatible shapes for dot: got (512, 784) and (32, 784).

# a batch-capable model_forward, with vmap it's easy
batched_forward = vmap(model_forward, in_axes=(None, 0), out_axes=0)
batched_preds = batched_forward(params, random_batched_flattened_images)
print(batched_preds.shape)
# (32, 10)
```
Data loading with help from PyTorch
Strictly speaking, JAX is not a framework designed for deep learning and contains no dataset-handling functions or classes, but with NumPy ndarrays as the bridge we can connect PyTorch's Dataset/DataLoader to JAX DeviceArrays.
The method is simple: pass a custom collate_fn when creating the DataLoader so that it returns numpy arrays rather than torch Tensors.
One more thing to note: as explained in the previous post, JAX encourages explicit random number generator state (PRNG state), so rather than use the DataLoader's built-in shuffle, we define our own Sampler.
```python
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch.utils.data import Sampler, SequentialSampler

class FlattenAndCast(object):
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))

# make the DataLoader return numpy arrays instead of torch Tensors
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class JAXRandomSampler(Sampler):
    def __init__(self, data_source, rng_key):
        self.data_source = data_source
        self.rng_key = rng_key

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        self.rng_key, current_rng = jax.random.split(self.rng_key)
        return iter(jax.random.permutation(current_rng, jnp.arange(len(self))).tolist())

class NumpyLoader(DataLoader):
    def __init__(self, dataset, rng_key=None, batch_size=1,
                 shuffle=False, **kwargs):
        if shuffle:
            sampler = JAXRandomSampler(dataset, rng_key)
        else:
            sampler = SequentialSampler(dataset)
        super().__init__(dataset, batch_size, sampler=sampler, **kwargs)

# put it together with torchvision and NumpyLoader; note that we set num_workers=0
mnist_dataset_train = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
key, loader_key = jax.random.split(key)
train_loader = NumpyLoader(mnist_dataset_train, loader_key, batch_size=32, shuffle=True,
                           num_workers=0, collate_fn=numpy_collate, drop_last=True)
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False, transform=FlattenAndCast())
eval_loader = NumpyLoader(mnist_dataset_test, batch_size=128, shuffle=False, num_workers=0,
                          collate_fn=numpy_collate, drop_last=False)
```
Updating model parameters with an optimizer
Let's implement a simple SGD:
```python
from jax.scipy.special import logsumexp

def loss(params, images, targets):
    logits = batched_forward(params, images)
    # normalize per example, not over the whole batch
    preds = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(preds * targets)

@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(w - lr * dw, b - lr * db)
            for (w, b), (dw, db) in zip(params, grads)]
```
The sgd_update above is fine as far as it goes, but suppose we add another layer to the model, one holding three parameters, \( W_{2} \cdot (W_{1}\cdot x) + b \); then
\[ params = [(w_{1}, b_{1}), (w_{2}, b_{2}), (w_{3}, b_{3}), (w_{4}, w_{5}, b_{4})] \]
and the list comprehension on the last line of sgd_update becomes awkward to write, now requiring an if/else, for example:
```python
@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return [(p[0] - lr * g[0], p[1] - lr * g[1]) if len(p) == 2 else
            (p[0] - lr * g[0], p[1] - lr * g[1], p[2] - lr * g[2])
            for p, g in zip(params, grads)]
```
And if the network were any more complex, say a Transformer with a dozen or more layers of all shapes and sizes, how many if/elses would that take?
PyTree
JAX has a data structure called a "PyTree", plus the built-in jax.tree_util.tree_* module with a large set of functions operating on pytrees, which allows elegant parameter management.
To be clear, "PyTree" is not one particular data type but a concept, an umbrella term for a family of types: Python's list, tuple, and dict are all pytrees. That is easy to picture, since a pytree denotes a tree structure and a linear sequence is still a tree. In contrast, plain values (int, float, and so on), ndarrays, strings, and our own unregistered classes are not pytrees; they are called leaves.
A pytree can nest other pytrees as well as leaves: a list may contain a few floats or even other lists, whereas a leaf is an isolated number or array with no list/tuple nesting of its own.
A few examples:
```python
[1, "a", object()]               # a pytree with 3 leaves: 1, "a", object()
(1, (2, 3), ())                  # a pytree with 3 leaves: 1, 2, 3
[1, {"k1": 2, "k2": (3, 4)}, 5]  # a pytree with 5 leaves: 1, 2, 3, 4, 5
```
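These leaf counts can be checked programmatically with jax.tree_util.tree_leaves, which flattens a pytree into a flat list of its leaves; a quick sketch:

```python
import jax

# tree_leaves flattens a pytree into a flat list of its leaves
leaves1 = jax.tree_util.tree_leaves([1, "a", object()])
leaves2 = jax.tree_util.tree_leaves((1, (2, 3), ()))
leaves3 = jax.tree_util.tree_leaves([1, {"k1": 2, "k2": (3, 4)}, 5])
print(len(leaves1), leaves2, leaves3)
# 3 [1, 2, 3] [1, 2, 3, 4, 5]
```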
JAX also lets users register their own classes as pytrees; that deserves a separate post, because PyTree is the most fundamental data structure in JAX: the jaxpr machinery covered in the previous post only accepts pytrees as input and returns pytrees as output.
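As a small taste of that registration mechanism (the Point class below is purely illustrative, not part of this article's model code), a custom class becomes visible to the tree functions once registered via jax.tree_util.register_pytree_node:

```python
import jax

class Point:
    """A plain class; JAX treats it as an opaque leaf until we register it."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

def flatten_point(p):
    return (p.x, p.y), None  # (children, static aux data)

def unflatten_point(aux, children):
    return Point(*children)

jax.tree_util.register_pytree_node(Point, flatten_point, unflatten_point)

# after registration, tree_map reaches inside Point
moved = jax.tree_util.tree_map(lambda v: v + 10, Point(1, 2))
print(moved.x, moved.y)
# 11 12
```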
Now let's look at the functions in the jax.tree_util.tree_* module, for example tree_map, a map function designed for pytrees:
```python
ptree = (1, (2, 3), (), [(2, 3, 4), 5], {"key": 2})
jax.tree_util.tree_map(lambda x: x + 2, ptree)
# (3, (4, 5), (), [(4, 5, 6), 7], {'key': 4})
```
See? Pretty handy, right? Let's update sgd_update accordingly:
```python
@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```
Please type "elegant" in the chat, thank you.
Training and evaluation loops
OK, next we wire the code together into a training loop and an evaluation loop:
```python
def one_hot(x, k=10, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, loader):
    total_acc = 0
    total_num = 0
    for x, y in loader:
        predicted_class = jnp.argmax(batched_forward(params, x), axis=1)
        total_num += len(x)
        total_acc += jnp.sum(predicted_class == y)
    return total_acc / total_num

lr = 0.01
n_classes = 10
for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        y = one_hot(y, n_classes)
        params = sgd_update(params, x, y, lr)
        lr = lr * 0.999 if lr > 1e-3 else 1e-3  # very simple lr scheduler
        if idx % 100 == 0:  # evaluation
            train_acc = accuracy(params, train_loader)
            eval_acc = accuracy(params, eval_loader)
            print("Epoch {} - batch_idx {}, Training set acc {}, eval set accuracy {}".format(
                epoch, idx, train_acc, eval_acc))
```
Training logs:
```
Epoch 0 - batch_idx 0, Training set acc 0.09814999997615814, eval set accuracy 0.09950000047683716
Epoch 0 - batch_idx 100, Training set acc 0.8302666544914246, eval set accuracy 0.8362999558448792
Epoch 0 - batch_idx 200, Training set acc 0.8892666697502136, eval set accuracy 0.8940999507904053
Epoch 0 - batch_idx 300, Training set acc 0.8997166752815247, eval set accuracy 0.9006999731063843
Epoch 0 - batch_idx 400, Training set acc 0.9085167050361633, eval set accuracy 0.9128999710083008
Epoch 0 - batch_idx 500, Training set acc 0.9076499938964844, eval set accuracy 0.911300003528595
Epoch 0 - batch_idx 600, Training set acc 0.9230999946594238, eval set accuracy 0.9253000020980835
Epoch 0 - batch_idx 700, Training set acc 0.9269000291824341, eval set accuracy 0.9298999905586243
Epoch 0 - batch_idx 800, Training set acc 0.9295666813850403, eval set accuracy 0.9334999918937683
Epoch 0 - batch_idx 900, Training set acc 0.9290666580200195, eval set accuracy 0.9296999573707581
Epoch 0 - batch_idx 1000, Training set acc 0.9342833161354065, eval set accuracy 0.9357999563217163
```
After roughly 3 epochs, accuracy reaches about 95%.
That completes a network training pipeline built purely on the JAX NumPy API; it feels like stepping back into the last century. A signature strength of PyTorch is its concise, elegant API design, doing the most work with the fewest classes, the crowd favorite being model building centered on nn.Module. Can we imitate that?
JAX && Flax && Optax
The JAX NumPy API looks rather low-level next to torch.nn.Module, so quite a few deep learning frameworks have grown on top of JAX (Flax, Haiku, Equinox, ...). It is a little reminiscent of the high-level API wars of the TensorFlow 1 era, though not as chaotic; by now the community has largely settled on the "JAX + Flax + Optax" trio:
- jax.numpy provides the array functions, analogous to torch.*
- from flax import linen as nn lines up with torch.nn.*
- optax.* lines up with torch.optim.* and also provides the common loss functions
- jit, grad, vmap, pmap and the other transformations act as the glue
Let's now refactor the training pipeline above with this stack.
First, build the model with nn.Module:
- The Module type is a dataclass, so hyperparameters (such as the size of each Dense layer) are declared as fields; don't forget the type annotations.
- Because Module is a dataclass, we cannot define our own __init__; instead, the layers (i.e. sub-Modules) the model needs are created in the setup method, the counterpart of torch.nn.Module.__init__.
- The forward computation is implemented in the __call__ method, the counterpart of torch.nn.Module.forward.
Also note that a Flax Module does not hold its parameters; they must be created explicitly with the init method. init essentially calls __call__ once with one extra PRNGKey argument: if __call__ has the signature (self, *args, **kwargs), then init has the signature (rngkey, *args, **kwargs).
init has to create the model parameters, that is, run parameter initialization, and as we know that step involves random number generation (something like some_params = jax.random.normal(key, (2, 2))), hence the PRNGKey.
Once the model and parameters are created, calling the model also changes: instead of model(x), use the apply method, model.apply(). By default apply invokes __call__ with the model parameters passed in explicitly: if __call__ has the signature (self, *args, **kwargs), then apply has the signature (params, *args, rngs=<RNGS>, mutable=<MUTABLEKINDS>, **kwargs), where rngs serves layers that need randomness (such as dropout) and mutable serves layers that carry state variables (such as BatchNorm).
```python
import jax
from jax import numpy as jnp
from jax import grad, jit, vmap
from flax import linen as nn
from typing import Sequence

# create the PRNGKey (PRNG state)
key = jax.random.PRNGKey(0)

class MLP(nn.Module):
    layer_sizes: Sequence[int] = None  # the type annotation Sequence[int] is required

    def setup(self):
        # Dense() only sets the output size; the input size is inferred by Flax
        self.layers = [nn.Dense(features=size) for size in self.layer_sizes[1:]]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
            x = nn.relu(x)
        return self.layers[-1](x)

layer_sizes = [784, 512, 512, 10]
# create the model
model = MLP(layer_sizes)

# Create the parameters with `init` and a dummy_x.
# Note that Dense() never saw the input size: `init` essentially calls
# `__call__` once and uses dummy_x to infer the parameter shapes.
key, init_key = jax.random.split(key)  # init_key used for initialization
dummy_x = jax.random.uniform(init_key, (784,))
key, init_key = jax.random.split(key)
params = model.init(init_key, dummy_x)

# create a random batch, shape=(32, 784)
# Like a PyTorch Module, the model handles batches automatically, so no manual vmap is needed
random_batched_flattened_images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
model.apply(params, random_batched_flattened_images).shape
# (32, 10)
```
Next, we use Optax to create the optimizer and the learning-rate schedule:
```python
import optax

lr = 1e-3
# learning-rate schedule
lr_decay_fn = optax.linear_schedule(
    init_value=lr,
    end_value=1e-5,
    transition_steps=200,
)
# go straight to Adam
optimizer = optax.adam(
    learning_rate=lr_decay_fn,
)
```
Never mind for now whether Adam is the right choice; after the hard days of hand-rolled SGD, being able to just grab an optimizer from Optax feels great!
The data loading code can be reused unchanged.
TrainState
What if we want to checkpoint the model? In PyTorch we would save the state_dicts of the model, the optimizer, and the lr_scheduler. Flax groups all of these under the training state, which can be wrapped in the TrainState class for convenient checkpointing:
```python
from flax.training import train_state

state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
```
The forward pass then goes through state.apply_fn instead of model.apply, and the gradient update is simply state.apply_gradients(grads=grads). With that, the training loop simplifies to:
```python
from jax import value_and_grad

def train_step(state, x, y):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits = state.apply_fn(params, x)
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss
    grad_fn = value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

# donate_argnums enables buffer reuse, here between the input and output state
jit_train_step = jit(train_step, donate_argnums=(0,))

@jax.jit
def apply_model(state, x):
    """Computes predictions for a single batch."""
    logits = state.apply_fn(state.params, x)
    return jnp.argmax(logits, -1)

def eval_model(state, loader):
    total_acc = 0.
    total_num = 0.
    for x, y in loader:
        y_pred = apply_model(state, x)
        total_num += len(x)
        total_acc += jnp.sum(y_pred == y)
    return total_acc / total_num

for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        state, loss = jit_train_step(state, x, y)
        if idx % 20 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
                epoch, idx, loss, train_acc, eval_acc))
```
Dropout and BatchNorm
As mentioned before, every piece of random number generation in JAX requires an explicit PRNGKey, so what do we do when the network contains Dropout? And unlike a PyTorch Module, a Flax Module stores no weights, with the parameters passed into every apply call, so how do we handle the running statistics when the network contains BatchNorm?
First, let's extend the MLP above with Dropout and BatchNorm:
```python
class MLP(nn.Module):
    def setup(self):
        self.layer1 = nn.Dense(features=512)
        self.dropout1 = nn.Dropout(rate=0.3)
        self.norm1 = nn.BatchNorm()
        self.layer2 = nn.Dense(features=512)
        self.dropout2 = nn.Dropout(rate=0.4)
        self.norm2 = nn.BatchNorm()
        self.layer3 = nn.Dense(features=10)

    def __call__(self, x, train: bool = True):
        """`train` switches between train mode and eval mode"""
        x = nn.relu(self.layer1(x))
        x = self.dropout1(x, deterministic=not train)
        x = self.norm1(x, use_running_average=not train)
        x = nn.relu(self.layer2(x))
        x = self.dropout2(x, deterministic=not train)
        x = self.norm2(x, use_running_average=not train)
        x = self.layer3(x)
        return x
```
init
Parameter initialization now needs extra care: Dropout draws random numbers on every forward pass during training (while being disabled during evaluation and inference), so init must be given a dedicated PRNGKey under the name "dropout".
```python
# create the model
model = MLP()
# create the parameters with `init` and a dummy_x
key, init_key = jax.random.split(key)
dummy_x = jax.random.uniform(init_key, (784,))
key, init_key, drop_key = jax.random.split(key, 3)  # split into 3 keys at once
# the name "dropout" is fixed by Flax
variables = model.init({"params": init_key, "dropout": drop_key}, dummy_x, train=True)
```
Looking at the resulting variables, there is now an extra "batch_stats" collection, holding the moving_mean and moving_var statistics of the BatchNorm layers:
```python
variables.keys()
# frozen_dict_keys(['params', 'batch_stats'])
variables['batch_stats'].keys()
# frozen_dict_keys(['norm1', 'norm2'])
variables['batch_stats']['norm1'].keys()
# frozen_dict_keys(['mean', 'var'])
```
apply
Calling apply for the forward pass also changes:
- pass a PRNGKey for Dropout through rngs
- declare "batch_stats" as mutable so it gets updated during the forward pass
Besides y, apply then also returns the updated "batch_stats":
```python
key, drop_key = jax.random.split(key)
y, non_trainable_params = model.apply(variables, dummy_x, train=True, rngs={"dropout": drop_key},
                                      mutable=['batch_stats'])
non_trainable_params.keys()
# frozen_dict_keys(['batch_stats'])
```
Updating the parameters
Now that variables contains "batch_stats", we first subclass TrainState so it carries batch_stats. Also, since batch_stats is not a model weight, it must not take part in the optimizer's parameter update, so the training step changes as well:
```python
import flax
from typing import Any

class CustomTrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict[str, Any]

state = CustomTrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer,
    batch_stats=variables['batch_stats'],
)

def train_step(state, x, y, dropout_key):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits, updates = state.apply_fn({"params": params, "batch_stats": state.batch_stats},
                                         x, train=True, rngs={"dropout": dropout_key},
                                         mutable=["batch_stats"])
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, updates
    grad_fn = value_and_grad(loss_fn, has_aux=True)  # `value_and_grad` returns the loss alongside the grads
    (loss, updates), grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads, batch_stats=updates["batch_stats"])
    return new_state, loss

# donate_argnums enables buffer reuse, here between the input and output state
jit_train_step = jit(train_step, donate_argnums=(0,))

@jax.jit
def apply_model(state, x):
    """Computes predictions for a single batch."""
    logits = state.apply_fn({"params": state.params, "batch_stats": state.batch_stats},
                            x, train=False)  # train=False selects eval mode
    return jnp.argmax(logits, -1)

for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        key, dropout_key = jax.random.split(key)
        state, loss = jit_train_step(state, x, y, dropout_key)
        if idx % 100 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
                epoch, idx, loss, train_acc, eval_acc))
```
Some logs:
```
Epoch 0 - batch_idx 0, loss 2.559518337249756, Training set acc 0.3179420530796051, eval set accuracy 0.31070002913475037
Epoch 0 - batch_idx 100, loss 0.3981797695159912, Training set acc 0.9382011890411377, eval set accuracy 0.9367000460624695
Epoch 0 - batch_idx 200, loss 0.29799991846084595, Training set acc 0.9520065784454346, eval set accuracy 0.9492000341415405
Epoch 0 - batch_idx 300, loss 0.22030052542686462, Training set acc 0.9536759257316589, eval set accuracy 0.9513000249862671
Epoch 0 - batch_idx 400, loss 0.22531506419181824, Training set acc 0.9540432095527649, eval set accuracy 0.950700044631958
Epoch 1 - batch_idx 0, loss 0.2441655695438385, Training set acc 0.954594075679779, eval set accuracy 0.9508000612258911
Epoch 1 - batch_idx 100, loss 0.14692620933055878, Training set acc 0.9552618265151978, eval set accuracy 0.9508000612258911
```
Source code
That concludes this simple example of training a neural network with JAX + Flax + Optax; I have put the code above on GitHub :)
References
[1] JAX reference documentation
[2] Flax documentation
[3] Optax documentation, https://optax.readthedocs.io/en/latest/
[4] PyTorch Dataloaders for JAX, https://colab.research.google.com/github/kk1694/blog/blob/master/_notebooks/2021-05-03-Pytorch_Dataloaders_for_Jax.ipynb
[5] Flax basics, https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/flax_basics.ipynb
[6] The Annotated MNIST example in the Flax documentation
[7] The Dropout and BatchNorm examples draw on "Machine Learning with Flax - From Zero to Hero" and the HuggingFace BERT Flax implementation