Metadata-Version: 2.1
Name: scikit-jax
Version: 0.0.2.dev0
Summary: Classical machine learning algorithms on the GPU/TPU.
Home-page: https://github.com/LiibanMo/scikit-jax
Author: Liiban Mohamud
Author-email: liibanmohamud12@gmail.com
Keywords: jax classical machine learning
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Education
Classifier: Topic :: Software Development :: Build Tools
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE.txt
Requires-Dist: jax
Requires-Dist: pandas
Requires-Dist: numpy
Requires-Dist: matplotlib
Requires-Dist: seaborn
Provides-Extra: dev
Requires-Dist: pytest>=6.0; extra == "dev"
Provides-Extra: test
Requires-Dist: pytest>=6.0; extra == "test"

<p align="center">
  <img src="assets/logo.png" alt="Scikit-JAX logo"/>
</p>

# Scikit-JAX: Classical Machine Learning on the GPU

Welcome to **Scikit-JAX**, a library of classical machine learning algorithms implemented in JAX, so they can run on GPUs and TPUs as well as CPUs. It provides implementations of a variety of classical machine learning techniques, designed with performance and ease of use in mind.

## Features

- **Linear Regression**: Offers a choice of weight initialization methods and dropout regularization.
- **KMeans**: Clustering algorithm that partitions data points into a user-specified number of clusters.
- **Principal Component Analysis (PCA)**: Dimensionality reduction technique that projects data onto a smaller set of directions while preserving as much of its variance as possible.
- **Multinomial Naive Bayes**: Classifier for discrete count data, such as word counts in text classification.
- **Gaussian Naive Bayes**: Classifier for continuous data that assumes each feature is normally distributed within each class.
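To make the PCA entry concrete: PCA centres the data and projects it onto the directions of greatest variance, which are the right singular vectors of the centred data matrix. The sketch below computes this with `jnp.linalg.svd` in plain JAX. The helper `pca_fit_transform` is hypothetical and is not Scikit-JAX's PCA API, which may differ in naming and in how the components are computed.

```python
import jax
import jax.numpy as jnp


def pca_fit_transform(X: jax.Array, num_components: int):
    """Hypothetical helper (not Scikit-JAX's API): project X onto its top principal components.

    Returns (projected data, components, explained variance per component).
    """
    # Centre each feature so the SVD captures variance rather than the mean offset.
    X_centred = X - jnp.mean(X, axis=0)
    # Rows of Vt are the principal directions, ordered by decreasing singular value.
    _, S, Vt = jnp.linalg.svd(X_centred, full_matrices=False)
    components = Vt[:num_components]
    # Sample variance explained by each kept component: s^2 / (n - 1).
    explained_variance = S[:num_components] ** 2 / (X.shape[0] - 1)
    return X_centred @ components.T, components, explained_variance


X = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
Z, components, var = pca_fit_transform(X, num_components=2)
print(Z.shape)  # (100, 2)
```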

## Installation

To install Scikit-JAX, you can use pip. The package is available on PyPI:

```bash
pip install scikit-jax
```
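Because Scikit-JAX runs on top of JAX, the accelerator it uses is whichever device JAX detects. Installing `jax` from PyPI on its own typically gives the CPU-only build; GPU or TPU support requires the matching JAX build described in JAX's installation guide. A quick way to check what JAX will use after installation, using only standard JAX calls:

```python
import jax

# Lists the devices JAX can see (CPU, GPU or TPU devices).
print(jax.devices())
# Name of the platform JAX dispatches computations to by default, e.g. "cpu" or "gpu".
print(jax.default_backend())
```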

## Usage

Here is a quick guide on how to use the key components of Scikit-JAX.

### Linear Regression
```python
from skjax.linear_model import LinearRegression

# Initialize the model
model = LinearRegression(weights_init='xavier', epochs=100, learning_rate=0.01)

# Fit the model
model.fit(X_train, y_train)

# Make predictions
predictions = model.predict(X_test)

# Plot losses
model.plot_losses()
```
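For readers curious about what a model like this does under the hood, here is a minimal sketch of linear regression trained by full-batch gradient descent on the mean squared error, with Xavier (Glorot) uniform weight initialisation, written in plain JAX. It is not Scikit-JAX's implementation: dropout is omitted, and the helpers `xavier_init`, `gd_step` and `fit_linear_regression` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def xavier_init(key, fan_in: int, fan_out: int = 1) -> jax.Array:
    # Hypothetical helper. Xavier/Glorot uniform: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out)).
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, (fan_in,), minval=-limit, maxval=limit)


def mse_loss(params, X, y):
    w, b = params
    return jnp.mean((X @ w + b - y) ** 2)


@jax.jit
def gd_step(params, X, y, learning_rate):
    # One full-batch gradient-descent update; also returns the loss before the update.
    loss, grads = jax.value_and_grad(mse_loss)(params, X, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss


def fit_linear_regression(X, y, key, epochs=100, learning_rate=0.01):
    # Hypothetical helper: weights from Xavier init, bias starts at zero.
    params = (xavier_init(key, X.shape[1]), jnp.zeros(()))
    losses = []
    for _ in range(epochs):
        params, loss = gd_step(params, X, y, learning_rate)
        losses.append(float(loss))
    return params, losses


# Synthetic data: y = X @ [2, -1, 0.5] + 3 + small noise.
key_x, key_noise, key_init = jax.random.split(jax.random.PRNGKey(0), 3)
X_train = jax.random.normal(key_x, (200, 3))
true_w = jnp.array([2.0, -1.0, 0.5])
y_train = X_train @ true_w + 3.0 + 0.1 * jax.random.normal(key_noise, (200,))

(w, b), losses = fit_linear_regression(X_train, y_train, key_init, epochs=500, learning_rate=0.1)
print(w, b, losses[-1])
```

The recorded `losses` list is the kind of per-epoch history a loss plot would display; whether `plot_losses()` uses exactly this quantity is an assumption.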

### K-Means
```python
from skjax.clustering import KMeans

# Initialize the model
kmeans = KMeans(num_clusters=3)

# Fit the model
kmeans.fit(X_train)
```
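K-Means is commonly fitted with Lloyd's algorithm: assign every point to its nearest centroid, move each centroid to the mean of its assigned points, and repeat until the centroids stop moving. The sketch below implements that loop in plain JAX, initialising centroids from randomly chosen data points. It is an illustration only; `assign_clusters`, `update_centroids` and `kmeans` are hypothetical names, and Scikit-JAX's own `KMeans` may use a different initialisation or stopping rule.

```python
import jax
import jax.numpy as jnp


def assign_clusters(X, centroids):
    # Squared Euclidean distance from every point to every centroid, shape (n, k).
    dists = jnp.sum((X[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(dists, axis=1)


def update_centroids(X, labels, centroids):
    # Mean of the points assigned to each cluster, computed via a one-hot matrix.
    one_hot = jax.nn.one_hot(labels, centroids.shape[0])  # (n, k)
    counts = one_hot.sum(axis=0)[:, None]  # (k, 1)
    means = (one_hot.T @ X) / jnp.maximum(counts, 1.0)
    # Keep the previous centroid for any cluster that received no points.
    return jnp.where(counts > 0, means, centroids)


def kmeans(X, num_clusters, key, max_iters=100, tol=1e-6):
    # Hypothetical helper: initialise centroids as distinct, randomly chosen data points.
    idx = jax.random.choice(key, X.shape[0], (num_clusters,), replace=False)
    centroids = X[idx]
    for _ in range(max_iters):
        labels = assign_clusters(X, centroids)
        new_centroids = update_centroids(X, labels, centroids)
        shift = float(jnp.max(jnp.abs(new_centroids - centroids)))
        centroids = new_centroids
        if shift < tol:  # centroids have stopped moving
            break
    return centroids, assign_clusters(X, centroids)


# Synthetic data: three blobs of 50 points each.
key_data, key_init = jax.random.split(jax.random.PRNGKey(42))
blob_centres = jnp.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
X_train = jnp.repeat(blob_centres, 50, axis=0) + 0.3 * jax.random.normal(key_data, (150, 2))

centroids, labels = kmeans(X_train, num_clusters=3, key=key_init)
print(centroids)
```

Lloyd's algorithm only finds a local optimum, so the result depends on the initial centroids; running it from several keys and keeping the solution with the lowest within-cluster sum of squares is a common safeguard.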

### Gaussian Naive Bayes
```python
from skjax.naive_bayes import GaussianNaiveBayes

# Initialize the model
nb = GaussianNaiveBayes()

# Fit the model
nb.fit(X_train, y_train)

# Make predictions
predictions = nb.predict(X_test)
```
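Gaussian Naive Bayes estimates, for each class, a prior and a per-feature mean and variance, then predicts the class with the highest log-posterior, treating features as independent given the class. A minimal JAX sketch of that computation follows. The helpers `gnb_fit` and `gnb_predict` are hypothetical, and the `var_smoothing` term (a small variance floor for numerical stability, similar in spirit to scikit-learn's parameter of the same name) is an assumption rather than part of Scikit-JAX's API.

```python
import jax
import jax.numpy as jnp


def gnb_fit(X, y, num_classes, var_smoothing=1e-9):
    # Hypothetical helper: per-class priors, means and variances.
    one_hot = jax.nn.one_hot(y, num_classes)  # (n, c)
    counts = one_hot.sum(axis=0)  # (c,)
    priors = counts / X.shape[0]
    means = (one_hot.T @ X) / counts[:, None]  # (c, d)
    # Per-class, per-feature maximum-likelihood variance.
    sq_dev = (X[:, None, :] - means[None, :, :]) ** 2  # (n, c, d)
    variances = jnp.einsum("nc,ncd->cd", one_hot, sq_dev) / counts[:, None]
    # Assumed variance floor, scaled by the largest feature variance.
    variances = variances + var_smoothing * jnp.max(jnp.var(X, axis=0))
    return priors, means, variances


def gnb_predict(X, priors, means, variances):
    # Log-density of each sample under each class's diagonal Gaussian, shape (n, c).
    log_likelihood = -0.5 * jnp.sum(
        jnp.log(2.0 * jnp.pi * variances)[None, :, :]
        + (X[:, None, :] - means[None, :, :]) ** 2 / variances[None, :, :],
        axis=-1,
    )
    # Predicted class maximises log prior + log likelihood.
    return jnp.argmax(jnp.log(priors)[None, :] + log_likelihood, axis=1)


# Synthetic data: two Gaussian classes centred at x = -2 and x = +2.
key0, key1 = jax.random.split(jax.random.PRNGKey(0))
X_train = jnp.concatenate([
    jax.random.normal(key0, (100, 2)) + jnp.array([-2.0, 0.0]),
    jax.random.normal(key1, (100, 2)) + jnp.array([2.0, 0.0]),
])
y_train = jnp.concatenate([jnp.zeros(100, dtype=jnp.int32), jnp.ones(100, dtype=jnp.int32)])

priors, means, variances = gnb_fit(X_train, y_train, num_classes=2)
predictions = gnb_predict(X_train, priors, means, variances)
print(jnp.mean(predictions == y_train))  # training accuracy
```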

## License

Scikit-JAX is licensed under the [MIT License](LICENSE.txt).
