numpy>=1.20.0
jax>=0.3.22
optax>=0.1.0
jaxlib>=0.3.22
dm-haiku>=0.0.9

[dev]
pytest>=7.1.2
