numpy
scipy
matplotlib
seaborn
jaxlib>=0.4.1
jax>=0.4.1
optax
flax
