jax>=0.4.0
jaxlib>=0.4.0
fmmax>=0.8.0

[test]
pre-commit
pytest-cov
ruff
optax
mypy
