Metadata-Version: 2.1
Name: gateloop
Version: 0.0.2
Project-URL: Homepage, https://github.com/axrwl/gateloop
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE

# GateLoop

A JAX implementation of GateLoop ([arxiv.org/abs/2311.01927](https://arxiv.org/abs/2311.01927)), based on [lucidrains/gateloop-transformer](https://github.com/lucidrains/gateloop-transformer).
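
GateLoop is built around a data-controlled linear recurrence. The sketch below illustrates that idea and is not taken from this package. It makes three simplifying assumptions: the gates `a_t` are real-valued and lie in (0, 1) (the paper's formulation is more general), the state is a matrix updated as `h_t = a_t ⊙ h_{t-1} + k_t v_tᵀ`, and the readout is `y_t = h_tᵀ q_t`. The function names are hypothetical. The recurrence is computed twice: once with a sequential `jax.lax.scan`, and once with `jax.lax.associative_scan`. The parallel version works because composing affine maps `h ↦ a·h + b` is associative.

```python
import jax
import jax.numpy as jnp


def gated_recurrence_sequential(a, kv):
    """Hypothetical reference loop: h_t = a_t * h_{t-1} + kv_t, with h_0 = 0.

    a:  (seq_len, d_k) real-valued gates (a simplifying assumption),
        broadcast over the d_v axis of the state.
    kv: (seq_len, d_k, d_v) outer products k_t v_t^T.
    Returns every state h_1..h_T, shape (seq_len, d_k, d_v).
    """
    def step(h, inputs):
        a_t, kv_t = inputs
        h = a_t[:, None] * h + kv_t
        return h, h

    h0 = jnp.zeros(kv.shape[1:], kv.dtype)
    _, hs = jax.lax.scan(step, h0, (a, kv))
    return hs


def gated_recurrence_parallel(a, kv):
    """Same recurrence, computed with a parallel associative scan.

    Each element (a_t, kv_t) represents the affine map h -> a_t * h + kv_t.
    Starting from h_0 = 0, the offset of each prefix composition is h_t.
    """
    def combine(left, right):
        a_l, kv_l = left
        a_r, kv_r = right
        # Apply the left map first, then the right one:
        # a_r * (a_l * h + kv_l) + kv_r = (a_l * a_r) * h + (a_r * kv_l + kv_r)
        return a_l * a_r, a_r[..., None] * kv_l + kv_r

    _, hs = jax.lax.associative_scan(combine, (a, kv))
    return hs


key_a, key_k, key_v, key_q = jax.random.split(jax.random.PRNGKey(0), 4)
seq_len, d_k, d_v = 16, 4, 8

a = jax.nn.sigmoid(jax.random.normal(key_a, (seq_len, d_k)))  # gates in (0, 1)
k = jax.random.normal(key_k, (seq_len, d_k))
v = jax.random.normal(key_v, (seq_len, d_v))
q = jax.random.normal(key_q, (seq_len, d_k))
kv = jnp.einsum("tk,tv->tkv", k, v)  # per-step outer products k_t v_t^T

hs_seq = gated_recurrence_sequential(a, kv)
hs_par = gated_recurrence_parallel(a, kv)
y = jnp.einsum("tk,tkv->tv", q, hs_par)  # readout y_t = h_t^T q_t, shape (seq_len, d_v)
print(jnp.allclose(hs_seq, hs_par, atol=1e-5))
```

The associative scan takes O(log T) parallel steps rather than the T steps of the sequential loop, which makes it the natural choice for training on long sequences.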

### Installation
`pip install gateloop`

### Usage
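```py
import jax
import jax.numpy as jnp
from gateloop import GateLoopTransformer

# Use separate PRNG keys for parameter initialisation and for sampling dummy inputs.
init_key, data_key = jax.random.split(jax.random.PRNGKey(0))

model = GateLoopTransformer(
    num_tokens = 1000,  # vocabulary size
    dim = 512,          # model width
    depth = 12          # number of layers
)

# Initialise parameters from a dummy batch of token ids, shape (batch, seq_len).
params = model.init(init_key, jnp.ones((1, 10), jnp.int32))

# Random token ids in [0, num_tokens).
ids = jax.random.randint(data_key, (1, 10), 0, 1000)
logits = model.apply(params, ids)
```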
