torch>=1.12
functorch>=0.2
optree
numpy
graphviz
typing-extensions

[lint]
isort
black>=22.6.0
pylint
mypy
flake8
flake8-bugbear
doc8<1.0.0a0
pydocstyle
pyenchant
cpplint
pre-commit

[test]
pytest
pytest-cov
pytest-xdist
jax[cpu]>=0.3
jaxopt
optax
