Metadata-Version: 2.1
Name: torch-training-loop
Version: 0.1.0
Summary: Simple Keras-inspired Training Loop for Pytorch.
Home-page: https://github.com/beekill95/torch-training-loop
License: MIT
Author: beekill
Author-email: nguyenmbquan95@gmail.com
Requires-Python: >=3.8,<4.0
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Dist: numpy
Requires-Dist: pandas
Requires-Dist: torch
Requires-Dist: torcheval
Requires-Dist: tqdm
Project-URL: CHANGELOG, https://github.com/beekill95/torch-training-loop/blob/main/CHANGELOG.md
Project-URL: Repository, https://github.com/beekill95/torch-training-loop
Description-Content-Type: text/markdown

[![Tests](https://github.com/beekill95/torch-training-loop/workflows/Tests/badge.svg)](https://github.com/beekill95/torch-training-loop/actions?query=workflow:"Tests")
[![License](https://img.shields.io/badge/License-MIT-blue)](#license)
[![PyPI - Version](https://img.shields.io/pypi/v/torch-training-loop)](https://pypi.org/project/torch-training-loop/)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torch-training-loop)

⚠️ The package is under development; expect bugs and breaking changes!

# Torch Training Loop

A simple Keras-inspired training loop for PyTorch.

## Installation

> pip install torch-training-loop

## Features

* Simple API for training Torch models;
* Supports training `DataParallel` models;
* Provides Keras-like callbacks for logging metrics to TensorBoard, checkpointing models,
and early stopping;
* Shows training and validation progress via `tqdm`;
* Displays metrics computed with `torcheval` during training and validation.

## Usage

This package provides two main classes for training Torch models:
`TrainingLoop` and `SimpleTrainingStep`.
To train a Torch model, instantiate both classes:

```python
import torch
from torch.optim import Adam
from torcheval.metrics import MulticlassAccuracy
from training_loop import TrainingLoop, SimpleTrainingStep
from training_loop.callbacks import EarlyStopping

model = ...
# Optionally wrap the model in DataParallel for multi-GPU training.
# model = torch.nn.DataParallel(model)

train_dataloader = ...
val_dataloader = ...

loop = TrainingLoop(
    model,
    step=SimpleTrainingStep(
        optimizer_fn=lambda params: Adam(params),
        loss=torch.nn.CrossEntropyLoss(),
        metrics=('accuracy', MulticlassAccuracy(num_classes=10)),
    ),
    device='cuda',
)
loop.fit(
    train_dataloader,
    val_dataloader,
    epochs=10,
    callbacks=[
        # Keep patience below `epochs`; otherwise early stopping can never trigger.
        EarlyStopping(monitor='val_loss', mode='min', patience=3),
    ],
)
```

In the example above, constructing `SimpleTrainingStep` and calling the `fit()` method
of `TrainingLoop` closely resemble the corresponding steps in the Keras API.
You can find more examples and documentation in the source code and in the `examples` folder.
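
To make the `mode` and `patience` arguments of the early-stopping callback more concrete, here is a minimal, library-independent sketch of Keras-style patience logic. It is not this package's implementation: `PatienceTracker` and `stopping_epoch` are hypothetical names used only for illustration, and the sketch assumes that `mode='min'` means a lower monitored value is better and that training stops after `patience` consecutive epochs without improvement.

```python
import math
from typing import Iterable, Optional


class PatienceTracker:
    """Hypothetical helper that mimics Keras-style early-stopping bookkeeping."""

    def __init__(self, mode: str = "min", patience: int = 3) -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience
        # Start from the worst possible value so the first epoch always improves.
        self.best = math.inf if mode == "min" else -math.inf
        self.epochs_without_improvement = 0

    def update(self, value: float) -> bool:
        """Record one epoch's monitored value; return True if training should stop."""
        if self.mode == "min":
            improved = value < self.best
        else:
            improved = value > self.best

        if improved:
            self.best = value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1

        return self.epochs_without_improvement >= self.patience


def stopping_epoch(
    values: Iterable[float], mode: str = "min", patience: int = 3
) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    tracker = PatienceTracker(mode=mode, patience=patience)
    for epoch, value in enumerate(values, start=1):
        if tracker.update(value):
            return epoch
    return None


if __name__ == "__main__":
    # Hypothetical validation losses, one per epoch.
    val_losses = [1.0, 0.8, 0.9, 0.85, 0.81]
    print(stopping_epoch(val_losses, mode="min", patience=3))
```

Under these assumptions, a `patience` value greater than or equal to the number of epochs can never trigger a stop, so the two settings are best chosen together.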

## License

Distributed under the MIT License. See `LICENSE.txt` for more information.

