Metadata-Version: 2.1
Name: tjax
Version: 0.15.2
Summary: Tools for JAX.
Home-page: https://github.com/NeilGirdhar/tjax
License: MIT
Author: Neil Girdhar
Author-email: mistersheik@gmail.com
Requires-Python: >=3.7.1,<3.10
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Classifier: Topic :: Utilities
Classifier: Typing :: Typed
Requires-Dist: colorful (>=0.5.4,<0.6.0)
Requires-Dist: jax (>=0.2.21,<0.3.0)
Requires-Dist: matplotlib (>=3.3,<4.0)
Requires-Dist: networkx (>=2.4,<3.0)
Requires-Dist: numpy (>=1.21)
Requires-Dist: optax (>=0.1)
Requires-Dist: yapf (>=0.31)
Project-URL: Repository, https://github.com/NeilGirdhar/tjax
Description-Content-Type: text/x-rst

=============
Tools for JAX
=============

.. role:: bash(code)
    :language: bash

.. role:: python(code)
   :language: python

This repository implements a variety of tools for the differentiable programming library
`JAX <https://github.com/google/jax>`_.

----------------
Major components
----------------

Tjax's major components are:

- A `dataclass <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/dataclasses>`_ decorator :python:`dataclass` that facilitates defining structured JAX objects (so-called "pytrees").  It benefits from:

  - the ability to mark fields as static (not available in :python:`chex.dataclass`),
  - a `MyPy plugin <https://github.com/NeilGirdhar/tjax/blob/master/tjax/mypy_plugin.py>`_, and
  - a display method that produces formatted text according to the tree structure.

- A `fixed_point <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/fixed_point>`_ finding library heavily based on `fax <https://github.com/gehring/fax>`_.  Our library

  - supports stochastic iterated functions, and
  - uses dataclasses instead of closures to avoid leaking JAX tracers.

- A `shim <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/gradient>`_ for the gradient transformation library `optax <https://github.com/deepmind/optax>`_ that supports:

  - easy differentiation and vectorization of "gradient transformation" (learning rule) parameters,
  - gradient transformation objects that can be passed *dynamically* to jitted functions, and
  - generic type annotations.
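
The idea behind static fields can be illustrated in plain JAX.  The following is a minimal sketch, not Tjax's implementation: the class :python:`Linear` and the helpers :python:`_flatten` and :python:`_unflatten` are hypothetical, and the only registration API assumed is :python:`jax.tree_util.register_pytree_node`.  A frozen standard-library dataclass is registered as a pytree whose :python:`activation` field is kept in the tree's auxiliary data rather than among its leaves, so it remains an ordinary Python string under :python:`jax.jit` and passes through :python:`jax.grad` unchanged:

.. code:: python

    import dataclasses
    from typing import Any, List, Tuple

    import jax
    import jax.numpy as jnp


    @dataclasses.dataclass(frozen=True)
    class Linear:
        """Hypothetical pytree: weight and bias are leaves; activation is static."""

        weight: Any
        bias: Any
        activation: str = "tanh"  # static: stored in the tree structure, not as a leaf

        def __call__(self, x):
            y = self.weight @ x + self.bias
            # Ordinary Python branching is only legal under jit because
            # activation is static rather than traced.
            return jnp.tanh(y) if self.activation == "tanh" else y


    def _flatten(obj: Linear) -> Tuple[List[Any], Tuple[str]]:
        # Children are traced by jit, grad, and vmap; aux data is not.
        return [obj.weight, obj.bias], (obj.activation,)


    def _unflatten(aux: Tuple[str], children: List[Any]) -> Linear:
        return Linear(*children, *aux)


    jax.tree_util.register_pytree_node(Linear, _flatten, _unflatten)


    @jax.jit
    def loss(layer: Linear, x):
        return jnp.sum(layer(x) ** 2)


    layer = Linear(weight=jnp.eye(2), bias=jnp.zeros(2), activation="identity")
    grads = jax.grad(loss)(layer, jnp.array([1.0, 2.0]))
    # grads is itself a Linear: weight and bias hold gradients, and
    # activation is carried over unchanged from the static aux data.

The point of a decorator like Tjax's :python:`dataclass` is to avoid writing such flatten and unflatten functions by hand.  Making :python:`activation` a leaf instead would fail under :python:`jax.jit`, because strings are not valid JAX types.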
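
For readers new to fixed-point finding, the basic forward iteration can be sketched with :python:`jax.lax.while_loop`.  The function :python:`fixed_point` below is hypothetical and deliberately simpler than Tjax's library: it captures the iterated function in closures (the pattern Tjax avoids), handles only deterministic functions, and does not differentiate through the fixed point.

.. code:: python

    import jax
    import jax.numpy as jnp


    def fixed_point(f, x0, *, tolerance=1e-6, max_iterations=1000):
        """Iterate x <- f(x) until successive iterates differ by at most tolerance.

        Returns the final iterate and the number of applications of f.
        """

        def cond(state):
            iteration, x, x_previous = state
            not_converged = jnp.max(jnp.abs(x - x_previous)) > tolerance
            return jnp.logical_and(iteration < max_iterations, not_converged)

        def body(state):
            iteration, x, _ = state
            return iteration + 1, f(x), x

        # Apply f once up front so that cond always compares two real iterates.
        initial_state = (jnp.asarray(1), f(x0), x0)
        iterations, x, _ = jax.lax.while_loop(cond, body, initial_state)
        return x, iterations


    # The fixed point of cos is approximately 0.739085.
    x_star, iterations = fixed_point(jnp.cos, jnp.array(1.0))

Because :python:`while_loop` requires a carry of fixed structure and dtype, the state holds both the current and the previous iterate so that convergence can be tested inside :python:`cond`.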

----------------
Minor components
----------------

Tjax also includes:

- A pretty printer :python:`print_generic` for aggregate and vector types, including dataclasses.  (See
  `display <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/display.py>`_.)

- Versions of :python:`custom_vjp` and :python:`custom_jvp` that support being used on methods.
  (See `shims <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/shims.py>`_.)

- Tools for working with cotangents.  (See
  `cotangent_tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/cotangent_tools.py>`_.)

- A random number generator class :python:`Generator`.  (See `generator <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/generator.py>`_.)

- JAX tree registration for `NetworkX <https://networkx.github.io/>`_ graph types.  (See
  `graph <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/graph.py>`_.)

- Leaky integration :python:`leaky_integrate` and Ornstein-Uhlenbeck process iteration
  :python:`diffused_leaky_integrate`.  (See `leaky_integral <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/leaky_integral.py>`_.)

- An improved version of :python:`jax.tree_util.Partial`.  (See `partial <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/partial.py>`_.)

- A Matplotlib trajectory plotter :python:`PlottableTrajectory`.  (See `plottable_trajectory <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/plottable_trajectory.py>`_.)

- A testing function :python:`assert_tree_allclose` that automatically produces testing code, and a related
  function :python:`tree_allclose`.  (See `testing <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/testing.py>`_.)

- Basic tools like :python:`divide_where`.  (See `tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/tools.py>`_.)
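
Guarded division, which :python:`divide_where` presumably addresses, is a common source of :python:`nan` gradients in JAX: even when :python:`jnp.where` discards the result of dividing by zero, the discarded branch can still inject :python:`nan` into the gradient (for example, through :python:`0 * inf`).  The usual remedy is the "double where" pattern.  The function :python:`safe_divide` below is a hypothetical sketch of that pattern, not Tjax's :python:`divide_where`, whose signature and semantics may differ:

.. code:: python

    import jax
    import jax.numpy as jnp


    def safe_divide(dividend, divisor, *, otherwise=0.0):
        """Return dividend / divisor where divisor != 0, and otherwise elsewhere."""
        nonzero = divisor != 0
        # Inner where: swap zero divisors for 1 so the division never yields inf
        # or nan, not even in elements that the outer where will discard.
        safe_divisor = jnp.where(nonzero, divisor, 1.0)
        # Outer where: select the quotient only where the divisor was nonzero.
        return jnp.where(nonzero, dividend / safe_divisor, otherwise)


    x = jnp.array([1.0, 2.0])
    y = jnp.array([2.0, 0.0])
    quotient = safe_divide(x, y)  # elementwise: 0.5 and 0.0
    gradient = jax.grad(lambda d: jnp.sum(safe_divide(x, d)))(y)  # -0.25 and 0.0, no nan

With a single :python:`jnp.where(y != 0, x / y, 0.0)`, the second element of the gradient would be :python:`nan`, because the derivative of the discarded quotient is infinite there.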

Also, see the `documentation <https://neilgirdhar.github.io/tjax/tjax/index.html>`_.

-----------------------
Contribution guidelines
-----------------------

- Conventions: PEP 8.

- How to run tests: :bash:`pytest .`

- How to clean the source:

  - :bash:`isort tjax`
  - :bash:`pylint tjax`
  - :bash:`mypy tjax`
  - :bash:`flake8 tjax`

