jax >= 0.3.2, <= 0.5.0
numpy >= 1.20.0, < 2.0
