jax>=0.4.28
numpy
scipy
