jax>=0.4.25
jaxlib>=0.4.25
multipledispatch
numpy
tqdm

[cpu]
jax[cpu]>=0.4.25

[cuda]
jax[cuda]>=0.4.25

[dev]
dm-haiku
flax
funsor>=0.4.1
graphviz
jaxns==2.6.3
matplotlib
optax>=0.0.6
pylab-sdk
pytest-cov
pyyaml
requests
tensorflow_probability>=0.18.0

[doc]
ipython
nbsphinx>=0.8.9
readthedocs-sphinx-search>=0.3.2
sphinx>=5
sphinx_rtd_theme
sphinx-gallery

[examples]
arviz
jupyter
matplotlib
pandas
seaborn
scikit-learn
wordcloud

[test]
importlib-metadata<5.0
ruff>=0.1.8
mypy>=1.13
pytest>=4.1
pyro-api>=0.1.1
scikit-learn
scipy>=1.9

[tpu]
jax[tpu]>=0.4.25
