numpy>=1.25
pandas>=2.0
jax>=0.4.19
jaxlib>=0.4.19
flax>=0.7.4
optax>=0.1.7
chex>=0.1.83
tensorflow-probability>=0.22.0
wandb>=0.13
