jax==0.4.35
flax==0.10.0
ml_dtypes==0.4.0
optax==0.2.3
orbax-checkpoint==0.7.0
orbax-export==0.0.5

[dev]
pytest
pytest-xdist

[grain]
grain==0.2.2

[tfds]
tensorflow==2.18.0
tensorflow_datasets==4.9.6
