jax<=0.5.0,>=0.3.2
numpy<2.0,>=1.20.0
