Simple JAX/Flax implementation