jax[cpu]
torch==1.11
graphviz
