jax==0.4.33
flax==0.9.0
ml_dtypes==0.4.0
optax==0.2.3
orbax-checkpoint==0.6.4
orbax-export==0.0.5

[dev]
pytest
pytest-xdist

[grain]
grain==0.2.1

[tfds]
tensorflow==2.17.0
tensorflow_datasets==4.9.6
