jax[cpu]>=0.4.3
numpy>=1.22
scipy>=1.8
opt_einsum>=3.3
optax>=0.1.5
