jax>=0.7.0
jaxlib>=0.7.0
multipledispatch
numpy
tqdm

[cpu]
jax[cpu]>=0.7.0

[cuda12]
jax[cuda12]>=0.7.0

[cuda13]
jax[cuda13]>=0.7.0

[tpu]
jax[tpu]>=0.7.0
