Metadata-Version: 2.1
Name: tjax
Version: 0.4.4
Summary: Tools for JAX.
Home-page: https://github.com/NeilGirdhar/cmm
License: MIT
Author: Neil Girdhar
Author-email: mistersheik@gmail.com
Requires-Python: >=3.7,<4.0
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Dist: chex (>=0.0.2,<0.0.3)
Requires-Dist: colorful (>=0.5.4,<0.6.0)
Requires-Dist: cooperative_dataclasses (>=0.2.0,<0.3.0)
Requires-Dist: jax (>=0.1.72,<0.2.0)
Requires-Dist: jaxlib (>=0.1.51,<0.2.0)
Requires-Dist: matplotlib (>=3.3,<4.0)
Requires-Dist: networkx (>=2.4,<3.0)
Requires-Dist: numpy (>=1.19,<2.0)
Project-URL: Repository, https://github.com/NeilGirdhar/cmm
Description-Content-Type: text/x-rst

=============
Tools for JAX
=============

.. role:: bash(code)
   :language: bash

.. role:: python(code)
   :language: python

This repository implements a variety of tools for the differentiable programming library
`JAX <https://github.com/google/jax>`_. It includes:

- A dataclass decorator that simplifies defining JAX trees (pytrees), provides a convenient text
  display, and comes with a mypy plugin

- A custom VJP decorator that supports both static and non-differentiable arguments

- A random number generator class

- JAX tree registration for `NetworkX <https://networkx.github.io/>`_ graph types

- Testing tools that automatically generate test code

See the `documentation <https://neilgirdhar.github.io/tjax/tjax/index.html>`_.
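
To make the first two items in the list above concrete, the following sketch shows what they
involve when done by hand. It uses only plain JAX (``jax.tree_util.register_pytree_node`` and
``jax.custom_vjp``) and does not use or reproduce the API of this package. The names ``Gaussian``
and ``scaled_square`` are hypothetical and chosen purely for illustration.

.. code-block:: python

   from dataclasses import dataclass
   from functools import partial

   import jax
   import jax.numpy as jnp
   from jax import tree_util


   # A hypothetical dataclass with two array fields and one static field.
   @dataclass
   class Gaussian:
       mean: jnp.ndarray
       variance: jnp.ndarray
       name: str


   def _flatten(g):
       # Array fields are the children (leaves); the name is auxiliary data.
       return (g.mean, g.variance), g.name


   def _unflatten(name, children):
       mean, variance = children
       return Gaussian(mean, variance, name)


   # Manual registration, which a dataclass decorator for JAX trees can automate.
   tree_util.register_pytree_node(Gaussian, _flatten, _unflatten)


   # A custom VJP whose first argument is non-differentiable.
   @partial(jax.custom_vjp, nondiff_argnums=(0,))
   def scaled_square(scale, x):
       return scale * x ** 2


   def _fwd(scale, x):
       # Return the primal output and the residuals needed by the backward pass.
       return scaled_square(scale, x), x


   def _bwd(scale, x, g):
       # One cotangent per differentiable argument (here only x).
       return (2.0 * scale * x * g,)


   scaled_square.defvjp(_fwd, _bwd)

   shifted = tree_util.tree_map(lambda v: v + 1.0, Gaussian(jnp.zeros(2), jnp.ones(2), "g"))
   gradient = jax.grad(scaled_square, argnums=1)(3.0, 2.0)

Here ``tree_map`` touches only the two array fields and carries ``name`` through unchanged, and
``gradient`` is the derivative of ``scale * x ** 2`` with respect to ``x``, computed by ``_bwd``.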

Contribution guidelines
=======================

- Conventions: PEP 8.

- How to run tests: :bash:`pytest .`

- How to clean the source:

  - :bash:`isort tjax`
  - :bash:`pylint tjax`
  - :bash:`mypy tjax`
  - :bash:`flake8 tjax`

