Metadata-Version: 2.1
Name: fusions
Version: 0.4.3
Summary: Diffusion meets sampling
Author: David Yallup
Author-email: david.yallup@gmail.com
Requires-Python: >=3.9,<4.0
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Requires-Dist: anesthetic (>=2.4.2,<3.0.0)
Requires-Dist: diffrax (>=0.4.0,<0.5.0)
Requires-Dist: distrax (>=0.1.5,<0.2.0)
Requires-Dist: flax (>=0.8.2,<0.9.0)
Requires-Dist: optax (>=0.2.2,<0.3.0)
Requires-Dist: pytest (>=7.3,<8.0)
Requires-Dist: torch (>=2.1.0,<3.0.0)
Requires-Dist: tqdm (>=4.62.0,<5.0.0)
Description-Content-Type: text/markdown

# fusions

[![tests](https://github.com/yallup/fusions/actions/workflows/tests.yml/badge.svg)](https://github.com/yallup/fusions/actions/workflows/tests.yml)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![PyPI version](https://badge.fury.io/py/fusions.svg)](https://badge.fury.io/py/fusions)

Diffusion meets (nested) sampling

A minimal implementation of diffusion models in JAX (Flax), tuned for building emulators of scientific models, particularly where MCMC sampling is tractable and already in use.


## Quickstart

Install `fusions` and `lsbi` from PyPI:
```
pip install lsbi fusions
```

Create a 5D sampling problem, then train a flow-matching model to approximate the posterior:

```python
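from fusions.cfm import CFM
from lsbi.model import MixtureModel
from anesthetic import MCMCSamples
import matplotlib.pyplot as plt
import numpy as np


dims = 5
# Two-component mixture model: model matrices +I and -I, data covariance 0.1 * I
Model = MixtureModel(
    M=np.stack([np.eye(dims), -np.eye(dims)]),
    C=np.eye(dims)*0.1,
)

# Draw a single simulated dataset from the evidence
data = Model.evidence().rvs()

# Initialise the flow-matching model from the prior
diffusion = CFM(Model.prior())
# diffusion = CFM(dims)  # alternative: construct from the dimensionality alone

# Train on samples drawn from the analytic posterior
diffusion.train(Model.posterior(data).rvs(1000))

# Overlay reference posterior samples and samples from the trained model
a = MCMCSamples(Model.posterior(data).rvs(500)).plot_2d(np.arange(dims))
MCMCSamples(diffusion.rvs(500)).plot_2d(a)
plt.show()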
```
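The quickstart trains a conditional flow-matching (CFM) model. As a rough, self-contained illustration of the idea, and not the `fusions` implementation or API, the NumPy sketch below assumes straight-line paths `x_t = (1 - t) * x0 + t * x1` between base noise `x0 ~ N(0, I)` and training samples `x1`, for which the regression target at every `t` is the velocity `x1 - x0`. The helpers `fit_affine_velocity` and `sample` are hypothetical names, and the velocity model is a per-time affine least-squares fit rather than a neural network. That fit is exact only when the target is Gaussian, so a Gaussian stand-in replaces the bimodal posterior here.

```python
import numpy as np


def fit_affine_velocity(x0, x1, times):
    """Hypothetical helper: fit v(x, t) = [x, 1] @ W_t separately at each t.

    Flow-matching regression on the straight path x_t = (1 - t) x0 + t x1,
    where the per-pair target velocity is u = x1 - x0.
    """
    u = x1 - x0                                        # per-pair target velocity
    weights = []
    for t in times:
        xt = (1.0 - t) * x0 + t * x1                   # point on the interpolating path
        design = np.hstack([xt, np.ones((len(xt), 1))])  # affine features [x, 1]
        W, *_ = np.linalg.lstsq(design, u, rcond=None)   # least-squares regression
        weights.append(W)
    return weights


def sample(weights, n, dims, rng):
    """Hypothetical helper: integrate dx/dt = v(x, t) from t=0 to t=1 with Euler steps.

    Each step uses the weights fitted at that step's midpoint time.
    """
    dt = 1.0 / len(weights)
    x = rng.standard_normal((n, dims))                 # start from base noise N(0, I)
    for W in weights:
        x = x + dt * (np.hstack([x, np.ones((n, 1))]) @ W)
    return x


rng = np.random.default_rng(0)
dims, n, steps = 2, 20_000, 50
mu, sigma = np.array([3.0, -1.0]), np.array([0.5, 2.0])

x1 = mu + sigma * rng.standard_normal((n, dims))       # Gaussian stand-in "posterior" samples
x0 = rng.standard_normal((n, dims))                    # paired base samples
times = (np.arange(steps) + 0.5) / steps               # midpoint time of each Euler step

weights = fit_affine_velocity(x0, x1, times)
samples = sample(weights, n, dims, rng)
print(samples.mean(axis=0), samples.std(axis=0))
```

Swapping the affine fit for a flexible regressor (for example a neural network, as the Flax dependency suggests) is what would let the same recipe handle non-Gaussian targets such as the two-component mixture in the quickstart.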

