Warning: this project was built in 2022 with older versions of JAX and FLAX. There may be simpler ways to implement the same thing now. I will update the code in the future.
What is JAX?
JAX is a library developed by Google for scientific computing. It supports various features including parallel computation, GPU acceleration, and automatic differentiation.
For general scientific computation, the developers of JAX imitated the API of a popular library called NumPy. You can quite literally use JAX as a drop-in replacement for NumPy.
Furthermore, it provides a compiler (even though it is a Python library!) so that the runtime of repeated computations can be reduced significantly. In other words, it supports just-in-time (JIT) compilation of functions.
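For example, a plain jax.numpy function can be compiled with jax.jit and differentiated with jax.grad (a minimal sketch; the function and array shapes are made up for illustration):

import jax
import jax.numpy as jnp

# A NumPy-style function written with jax.numpy
def mse(x, y):
    return jnp.mean((x - y) ** 2)

fast_mse = jax.jit(mse)      # compiled on the first call, reused afterwards
grad_mse = jax.grad(mse)     # gradient with respect to the first argument

x = jnp.ones((1024, 1024))
y = jnp.zeros((1024, 1024))
print(fast_mse(x, y))        # 1.0
print(grad_mse(x, y).shape)  # (1024, 1024)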
There are lots of applications for JAX, but the most attractive area (to me) is machine learning. In particular, it provides a very simple way to parallelize computation without a deep understanding of the underlying machinery.
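As a toy illustration (the per-example function below is just made up), jax.vmap turns a function written for a single example into a batched one without any manual loop:

import jax
import jax.numpy as jnp

def squared_error(pred, label):          # written for a single example
    return jnp.sum((pred - label) ** 2)

batched_error = jax.vmap(squared_error)  # mapped over the leading batch axis

preds = jnp.ones((8, 10))
labels = jnp.zeros((8, 10))
print(batched_error(preds, labels).shape)  # (8,)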
What is FLAX?
However, quite complicated code is required to build a neural network model using JAX alone.
You would need to implement a large number of modules out of basic NumPy-style operations such as matrix multiplication and convolution.
To mitigate this, Google also developed a library called FLAX, which provides pre-defined layer modules such as fully-connected or conv2d layers. Of course, their underlying computations rely on JAX, so all of its useful features are inherited.
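For instance, a small model can be written with the flax.linen API roughly like this (a sketch; the layer sizes are arbitrary):

import jax
import jax.numpy as jnp
import flax.linen as nn

class SmallCNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)  # pre-defined conv2d layer
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))                  # flatten
        x = nn.Dense(features=10)(x)                     # pre-defined fully-connected layer
        return x

# Parameters are created by init and returned to the caller,
# not stored inside the module instance.
variables = SmallCNN().init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 1)))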
Then, is everything OK?
Unfortunately, some issues remain when building the entire training and inference pipeline of current deep-learning models.
In particular, it differs from PyTorch in many ways. For example, every PyTorch module class contains its own parameters, while FLAX module classes work as placeholders. This means that parameters are handled independently of the FLAX model, which leads to more complex update rules than PyTorch's optimizer.step().
In the following, I simply show what kind of code is required to
initialize model and optimizer
apply one optimization step
for PyTorch and JAX-FLAX, under the following assumptions:
the neural network contains batch normalization
the loss and prediction are returned after one optimization step
the input is 32x32x1 images and the output is images of the same size
You don't need to understand the meaning of every line. Just look at how different the two libraries are:
# PyTorch
# Initialize model and optimizer
import torch

net = Model().train()
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
...
# Apply one optimization step
def Loss(prediction, label):
    # squared L2 norm of the difference
    return torch.linalg.norm((prediction - label).flatten()) ** 2

def train_step(in_, label):
    optim.zero_grad()
    output = net(in_)
    loss = Loss(output, label)
    loss.backward()
    optim.step()
    return loss, output
# JAX & FLAX
# Initialize model and optimizer
import jax
import jax.numpy as jnp
import optax
from jax import random

rng = random.PRNGKey(seed=123)
net = Model(train=True)
dummy_x = jnp.ones((1, 32, 32, 1), dtype=jnp.float32)
state = net.init(rng, dummy_x)  # builds the 'params' and 'batch_stats' collections
optim = optax.adam(learning_rate=1e-4)
train_state = TrainState.create(  # assumes a custom TrainState with a batch_stats field (sketch below)
    apply_fn=net.apply,
    params=state["params"],
    tx=optim,
    batch_stats=state["batch_stats"]
)
...
# Apply one optimization step
@jax.vmap
def Loss(prediction, label):
    return ((prediction - label) ** 2).sum() ** 0.5

@jax.jit
def train_step(state, in_):
    # The model is an autoencoder, so the input also serves as the target.
    def loss_fn(params, variables):
        output, variables = state.apply_fn(
            {'params': params,
             'batch_stats': variables['batch_stats']},
            in_,
            mutable=['batch_stats']  # allow BatchNorm statistics to be updated
        )
        loss = Loss(output, in_).mean()
        return loss, (variables, output)
    grad_fn = jax.value_and_grad(
        loss_fn,
        argnums=0,
        has_aux=True
    )
    aux, grads = grad_fn(
        state.params,
        {'batch_stats': state.batch_stats}
    )
    loss, (variables, output) = aux
    state = state.apply_gradients(
        grads=grads,
        batch_stats=variables["batch_stats"]
    )
    return state, loss, output
Due to the functional nature of JAX, we need to handle parameters through a separate object, here called 'state'. Also, running statistics in layers such as batch normalization have to be handled separately.
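For reference, the TrainState used above is assumed to be a custom subclass of Flax's TrainState that carries the batch_stats collection next to the parameters; a minimal sketch could look like this:

from typing import Any
from flax.training import train_state

class TrainState(train_state.TrainState):
    # Extra field for BatchNorm running statistics, kept alongside params.
    batch_stats: Any

Since apply_gradients forwards extra keyword arguments to replace, the call state.apply_gradients(grads=grads, batch_stats=...) above updates the parameters and the running statistics in one step.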
Yes, it is obviously longer than the PyTorch version, but once you get used to it, it offers strong performance and lets you implement complex models in a simple way, thanks to characteristic features such as jit and vmap.
Here are some simple examples
As a result, I have published a GitHub repository that provides several examples of image processing tasks, including autoencoding, classification, denoising, and generation (WGAN-GP).
Actually, it was created to deepen my own understanding of JAX and FLAX, but I hope it also helps you understand how to define neural networks and loss functions with FLAX, and how to update parameters.