matplotlib
scikit-learn
pandas
nbdev<2,>=1.0.10
dm-haiku
test_tube
jax[cpu]
torch>=1.7.0
tqdm
optax
pydantic<2,>=1.9.0
