# Python dependencies for this project, in pip requirements-file format.
# Install with: pip install -r <this file>

# Lower bound only: any torch release at or above 2.0.0 is accepted.
torch>=2.0.0
# Compatible-release constraint: equivalent to ">=2.0, ==2.*" (any 2.x release).
pytorch-lightning~=2.0
# No version constraint on the next two entries: pip picks the newest
# release compatible with the rest of the environment.
# NOTE(review): unpinned entries can make installs non-reproducible over
# time; consider adding bounds or a lock/constraints file.
scikit-learn
matplotlib
# Exact pin: only torchdiffeq 0.2.3 satisfies this line.
torchdiffeq==0.2.3
# No version constraint on the remaining entries (see note above).
UMNN
wandb
