Metadata-Version: 2.1
Name: jax-random-projections
Version: 1.0.1
Summary: sklearn's random projection with JAX to run on a GPU
Home-page: https://github.com/Baschdl/jax-random_projections
Author: Sebastian Bischoff
Author-email: sebastian@salzreute.de
License: MIT
Download-URL: https://github.com/Baschdl/jax-random_projections/archive/v1.0.1.tar.gz
Keywords: random projections,jax,GPU
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: Topic :: Software Development
Classifier: Topic :: Scientific/Engineering
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Environment :: GPU
Description-Content-Type: text/markdown
Requires-Dist: sklearn
Requires-Dist: jaxlib
Requires-Dist: jax

# JAX Random Projection Transformers
Using JAX to speed up sklearn's random projection transformers

## Installation

**Note: Installing with pip installs only the CPU version of JAX.**

To run on a GPU, follow [JAX's installation guide](https://github.com/google/jax#installation) before installing `jax-random_projections`.
```shell
pip install jax-random_projections
```
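After following JAX's installation guide, you can confirm that JAX sees a GPU and not only the CPU by inspecting its devices. The following is a minimal sketch that uses JAX's own `jax.devices()` and `jax.default_backend()`. The helper name `gpu_available` is hypothetical and not part of `jax-random_projections`.

```python
import jax


def gpu_available() -> bool:
    """Hypothetical helper: True if JAX reports at least one GPU device."""
    return any(device.platform == "gpu" for device in jax.devices())


# On a CPU-only install this prints "cpu" and False.
print("Default backend:", jax.default_backend())
print("GPU available:", gpu_available())
```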

## Usage
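```python
import jax
from jax_random_projections.sparse import SparseRandomProjectionJAX

# Example data: 100 samples with 10,000 features, created directly as a JAX array
X = jax.random.normal(jax.random.PRNGKey(0), (100, 10000))

transformer = SparseRandomProjectionJAX()
X_new = transformer.fit_transform(X)  # projected data with fewer columns
```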

For the API documentation, refer to [sklearn's SparseRandomProjection documentation](https://scikit-learn.org/stable/modules/generated/sklearn.random_projection.SparseRandomProjection.html).
The only differences are that `jax-random_projections` currently accepts only JAX arrays (`xla.DeviceArray`) as input, and it supports neither `dense_output=False` nor the `y` argument of `fit()`.
This library currently includes only `SparseRandomProjection`, but a future release will also include `GaussianRandomProjection`.
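Because the JAX transformer expects a JAX array, inputs that sklearn would accept directly, such as NumPy arrays or SciPy sparse matrices, need converting first. The following is a minimal sketch that assumes densifying your data is acceptable for its size. `to_jax_input` is a hypothetical helper and is not part of `jax-random_projections`.

```python
import numpy as np
import scipy.sparse as sp
import jax.numpy as jnp


def to_jax_input(X):
    """Hypothetical helper: convert NumPy or SciPy sparse input to a dense JAX array."""
    if sp.issparse(X):
        # Sparse input is not supported by the JAX transformer, so densify it.
        # This is fine for moderate sizes but can use a lot of memory for huge data.
        X = X.toarray()
    return jnp.asarray(X)
```

The converted array can then be passed to `SparseRandomProjectionJAX().fit_transform(...)` as in the usage example.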

`jax-random_projections` also includes `SparseRandomProjectionJAXCached`, which stores the generated random matrix in an LRU cache (`maxsize=5`). Repeated calls on data with the same input dimension then reuse the cached matrix instead of generating a new one.
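You can sketch the caching idea with Python's `functools.lru_cache`. This is an illustration under stated assumptions and not the library's implementation. The helper names `sparse_random_matrix` and `project` are hypothetical. The matrix follows the distribution sklearn documents for `SparseRandomProjection` with `density='auto'` (density = 1 / sqrt(n_features)), and it is stored densely for simplicity. Here the cache key also includes `n_components` and the seed, not only the input dimension.

```python
import functools
import math

import jax
import jax.numpy as jnp


@functools.lru_cache(maxsize=5)
def sparse_random_matrix(n_components, n_features, seed=0):
    """Hypothetical helper: random matrix following sklearn's sparse scheme.

    Each entry is +scale or -scale with probability density / 2 each, and 0
    otherwise, where density = 1 / sqrt(n_features) and
    scale = sqrt(1 / density) / sqrt(n_components). Stored densely here.
    """
    density = 1.0 / math.sqrt(n_features)
    key_mask, key_sign = jax.random.split(jax.random.PRNGKey(seed))
    shape = (n_components, n_features)
    nonzero = jax.random.bernoulli(key_mask, density, shape)  # which entries are non-zero
    signs = jnp.where(jax.random.bernoulli(key_sign, 0.5, shape), 1.0, -1.0)
    scale = math.sqrt(1.0 / density) / math.sqrt(n_components)
    return jnp.where(nonzero, signs * scale, 0.0)


def project(X, n_components, seed=0):
    """Hypothetical helper: project X, reusing a cached matrix for the same width."""
    # X.shape[1] is the input dimension; together with n_components and seed it
    # forms the cache key, so same-width data hits the cache on later calls.
    components = sparse_random_matrix(n_components, X.shape[1], seed)
    return X @ components.T
```

With `maxsize=5`, the five most recently used matrices stay in memory. Calling `project` with a sixth distinct configuration evicts the least recently used one.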


