Metadata-Version: 2.3
Name: statedict2pytree
Version: 0.5.2
Summary: Converts torch models into PyTrees for Equinox
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
Requires-Python: ~=3.10
Requires-Dist: beartype
Requires-Dist: equinox
Requires-Dist: flask
Requires-Dist: jax
Requires-Dist: jaxlib
Requires-Dist: jaxonmodels
Requires-Dist: jaxtyping
Requires-Dist: loguru
Requires-Dist: penzai
Requires-Dist: pydantic
Requires-Dist: torch
Requires-Dist: torchvision
Requires-Dist: typing-extensions
Provides-Extra: dev
Requires-Dist: mkdocs; extra == 'dev'
Requires-Dist: nox; extra == 'dev'
Requires-Dist: pre-commit; extra == 'dev'
Requires-Dist: pytest; extra == 'dev'
Provides-Extra: examples
Requires-Dist: jaxonmodels; extra == 'examples'
Description-Content-Type: text/markdown

# statedict2pytree

![statedict2pytree](statedict2pytree.png "A ResNet demo")

## Important

This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more iterations, it will eventually become stable and well tested.
PRs and other contributions are *highly* welcome! :)

## Info

The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning the weights in the right order. statedict2pytree then simply iterates over both lists and matches the weight matrices.

Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.

(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
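To make the idea concrete, here is a minimal sketch of pairing two ordered lists of fields. It is an illustration only, not the package's actual implementation: the function `pair_in_order` and the field names and shapes are hypothetical, and plain tuples stand in for real tensors.

```python
from math import prod

# Hypothetical (name, shape) listings, written by hand for illustration.
# They are NOT produced by statedict2pytree.
torch_fields = [
    ("conv1.weight", (16, 3, 3, 3)),
    ("conv1.bias", (16,)),
    ("fc.weight", (10, 16)),
    ("fc.bias", (10,)),
]
jax_fields = [
    (".conv1.weight", (16, 3, 3, 3)),
    (".conv1.bias", (16, 1, 1)),
    (".fc.weight", (10, 16)),
    (".fc.bias", (10,)),
]


def pair_in_order(torch_fields, jax_fields):
    """Pair the i-th PyTorch field with the i-th JAX field.

    Raises ValueError if the lists differ in length or if a pair has a
    different total number of elements.
    """
    if len(torch_fields) != len(jax_fields):
        raise ValueError("both models must have the same number of fields")
    pairs = []
    for (t_name, t_shape), (j_name, j_shape) in zip(torch_fields, jax_fields):
        if prod(t_shape) != prod(j_shape):
            raise ValueError(f"{t_name} {t_shape} does not fit {j_name} {j_shape}")
        pairs.append((t_name, j_name))
    return pairs


pairs = pair_in_order(torch_fields, jax_fields)
for t_name, j_name in pairs:
    print(f"{t_name} -> {j_name}")
```

If the fields are declared in a different order, reordering one of the lists before pairing is all that is needed, as long as each resulting pair still has the same number of elements.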

## Shape Matching? What's that?

Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes are equal. For example:

`(8, 1, 1)` and `(8,)` match, because `8 * 1 * 1 = 8`.
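The rule above can be sketched in a few lines of Python. The helper name `shapes_match` is hypothetical and is not part of the package's API; it only demonstrates the "equal product" check.

```python
from math import prod


def shapes_match(shape_a, shape_b):
    """Return True if both shapes hold the same total number of elements."""
    return prod(shape_a) == prod(shape_b)


print(shapes_match((8, 1, 1), (8,)))  # True: 8 * 1 * 1 == 8
print(shapes_match((4, 2), (8,)))     # True: 4 * 2 == 8
print(shapes_match((8, 3), (8,)))     # False: 24 != 8
```

Note the second case: `(4, 2)` and `(8,)` also count as matching under this rule, which is why the matching is described as not sophisticated. It is up to you to make sure the paired weights really belong together.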

## Get Started

### Installation

Run

```bash
pip install statedict2pytree
```

### Docs

Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the `main.py` file for a more complete example.
