jaxlib==0.4.28
jax==0.4.28
chex==0.1.86
distrax==0.1.5
flax==0.8.4
flashbax==0.1.2
jaxtyping==0.2.29
gymnax==0.0.6
jax-tqdm==0.2.1
optax==0.2.2
orbax-checkpoint==0.5.14
numpy==1.26.4
scipy==1.13.1
pandas==2.2.2
matplotlib==3.9.0
pytest==8.2.1

[dev]
