optax>=0.1.4
tree-math>0.1.0
matplotlib

[:python_version == "3.8"]
jax==0.4.13
jaxlib==0.4.13

[:python_version > "3.8"]
jax>=0.4.13
jaxlib>=0.4.13
