numpy
scipy
matplotlib
seaborn
jaxlib>=0.4.1
jax>=0.4.1
optax
flax
torch
ml_collections
tqdm
absl-py
wandb
