Metadata-Version: 2.1
Name: exciting_environments
Version: 0.1.1
Summary: Physical differential equations wrapped into Gymnasium environments
Author-email: Oliver Schweins <oliverjs@mail.uni-paderborn.de>, Hendrik Vater <vater@lea.uni-paderborn.de>
Project-URL: Homepage, https://excitingsystems.github.io/exciting-environments/
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: matplotlib ==3.9.0
Requires-Dist: jax ==0.4.28
Requires-Dist: jaxlib ==0.4.28
Requires-Dist: chex ==0.1.86
Requires-Dist: numpy ==1.26.4
Requires-Dist: scipy ==1.13.1
Requires-Dist: pytest ==8.2.1
Requires-Dist: pytest-cov ==5.0.0
Requires-Dist: diffrax ==0.5.1
Requires-Dist: jax-dataclasses ==1.6.0

# exciting-environments

## Overview
The exciting-environments package is a toolbox for simulating physical [differential equations](https://en.wikipedia.org/wiki/Differential_equation) wrapped into [Gymnasium](https://github.com/Farama-Foundation/Gymnasium)-inspired environments, implemented with [JAX](https://github.com/google/jax). JAX compiles functions just in time (JIT) and can vectorize them over a batch dimension. As a result, this kind of implementation can simulate many environments in parallel at high speed.
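The speed argument rests on two JAX transformations. `jax.jit` compiles a function with XLA, and `jax.vmap` vectorizes it over a leading batch axis. The following minimal sketch is written in plain JAX and is not the exciting-environments implementation. The function `pendulum_euler_step` and the constants `G`, `LENGTH` and `TAU` are hypothetical and exist only for illustration. The input torque is treated directly as an angular acceleration, which assumes unit inertia.

```python
import jax
import jax.numpy as jnp

# Hypothetical physical parameters, for illustration only.
G = 9.81      # gravitational acceleration in m/s^2
LENGTH = 1.0  # pendulum length in m
TAU = 2e-2    # step size in s


def pendulum_euler_step(state, torque):
    """Hypothetical explicit Euler step for a single pendulum.

    state: array [theta, omega]; torque: scalar input, treated as an
    angular acceleration (unit inertia assumed).
    """
    theta, omega = state[0], state[1]
    derivative = jnp.stack([omega, -G / LENGTH * jnp.sin(theta) + torque])
    return state + TAU * derivative


# vmap maps over the leading (batch) axis of both arguments;
# jit compiles the vectorized step once and reuses it on later calls.
batched_step = jax.jit(jax.vmap(pendulum_euler_step))

states = jnp.zeros((5, 2))            # batch of 5 pendulums at rest
torques = jnp.linspace(-1.0, 1.0, 5)  # one input per batch element
next_states = batched_step(states, torques)
print(next_states.shape)  # (5, 2)
```

A single call advances all five pendulums at once. Later calls with the same input shapes reuse the compiled function.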

## Getting Started

A basic routine is as simple as:
```py
import jax.numpy as jnp
import exciting_environments as excenvs

# 5 parallel pendulums, torque constraint 15, step size tau = 2e-2
env = excenvs.make("Pendulum-v0", batch_size=5, action_constraints={"torque": 15}, tau=2e-2)
obs, state = env.reset()

# 1000 actions ramping from -1 to 1, shape (batch_size, 1000, 1)
actions = jnp.linspace(start=-1, stop=1, num=1000)[None, :, None]
actions = actions.repeat(env.batch_size, axis=0)

observations = []
observations.append(obs)

# step all batch elements in parallel, one action per step
for idx in range(actions.shape[1]):
    obs, reward, terminated, truncated, state = env.vmap_step(
        state, actions[:, idx, :]
    )
    observations.append(obs)
# stack along a new time axis (axis 1)
observations = jnp.stack(observations, axis=1)

print("actions shape:", actions.shape)
print("observations shape:", observations.shape)
```

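which produces five identical trajectories in parallel, because all batch elements start from the same initial state and receive the same action sequence: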

![Pendulum simulation example](https://github.com/ExcitingSystems/exciting-environments/blob/main/fig/excenvs_pendulum_simulation_example.png?raw=true)
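The Python loop in the basic example dispatches one call per step and then stacks the results along axis 1. Assuming each `obs` has shape `(batch_size, obs_dim)`, this yields an array of shape `(batch_size, num_steps + 1, obs_dim)`. When the step function is a pure JAX function, the same rollout can be expressed with `jax.lax.scan`, which compiles the whole loop into a single computation.

The sketch below uses a hypothetical, self-contained pendulum step (`step`, `rollout`, `G_OVER_L` and `TAU` are illustrative names) instead of `env.vmap_step`. It does not claim that the environment's own step can be traced inside `scan`.

```python
import jax
import jax.numpy as jnp

TAU = 2e-2       # hypothetical step size in s
G_OVER_L = 9.81  # hypothetical g / l in 1/s^2


def step(state, action):
    """Hypothetical Euler step; action has shape (1,). Returns (carry, output)."""
    theta, omega = state[0], state[1]
    torque = action[0]
    derivative = jnp.stack([omega, -G_OVER_L * jnp.sin(theta) + torque])
    next_state = state + TAU * derivative
    return next_state, next_state  # the observation is simply the full state here


def rollout(init_state, actions):
    """Simulate one trajectory; actions has shape (num_steps, 1)."""
    _, observations = jax.lax.scan(step, init_state, actions)
    # prepend the initial observation, mirroring the obs returned by reset()
    return jnp.concatenate([init_state[None, :], observations], axis=0)


batch_size, num_steps = 5, 1000
actions = jnp.linspace(-1.0, 1.0, num_steps)[None, :, None].repeat(batch_size, axis=0)
init_states = jnp.zeros((batch_size, 2))

# vmap over the batch axis, then jit the whole batched rollout
observations = jax.jit(jax.vmap(rollout))(init_states, actions)
print(observations.shape)  # (5, 1001, 2)
```

Here the batch elements share the same initial state and action sequence, so all five trajectories coincide, as in the figure.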

Alternatively, full trajectories can be simulated in a single call with `vmap_sim_ahead`:

```py
import jax.numpy as jnp
import exciting_environments as excenvs
import diffrax

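# use the Tsit5 solver from diffrax instead of the default explicit Euler
env = excenvs.make(
    "Pendulum-v0", solver=diffrax.Tsit5(), batch_size=5, action_constraints={"torque": 15}, tau=2e-2
)
obs, state = env.reset()

# 2000 actions ramping from -1 to 1, shape (batch_size, 2000, 1)
actions = jnp.linspace(start=-1, stop=1, num=2000)[None, :, None]
actions = actions.repeat(env.batch_size, axis=0)

# simulate the whole action sequence in a single call for all batch elements
observations, rewards, terminations, truncations, last_state = env.vmap_sim_ahead(
    init_state=state,
    actions=actions,
    obs_stepsize=env.tau,  # time between recorded observations
    action_stepsize=env.tau,  # time between action changes
)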

print("actions shape:", actions.shape)
print("observations shape:", observations.shape)
```

which again produces five identical trajectories in parallel:

![Pendulum simulation example with vmap_sim_ahead](https://github.com/ExcitingSystems/exciting-environments/blob/main/fig/excenvs_pendulum_simulation_example_advanced.png?raw=true)

Note that this example uses the Tsit5 ODE solver instead of the default explicit Euler solver.
All solvers are taken from the [diffrax](https://docs.kidger.site/diffrax/) library.
