jax==0.4.31
flax==0.8.5
ml_dtypes==0.4.0
optax==0.2.3
orbax==0.1.9

[dev]
pytest
pytest-xdist

[grain]
grain==0.2.0

[tfds]
tensorflow==2.17.0
tensorflow_datasets==4.9.6
