Metadata-Version: 2.1
Name: pytorch-cpr
Version: 0.2.0
Summary: Constrained Parameter Regularization for PyTorch
Project-URL: Homepage, https://github.com/automl/CPR
Author-email: Jörg Franke <frankej@cs.uni-freiburg.de>
License-File: LICENSE
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Requires-Python: >=3.10
Provides-Extra: full
Requires-Dist: pytorch>=2.0.0; extra == 'full'
Description-Content-Type: text/markdown


# Constrained Parameter Regularization

This repository contains the PyTorch implementation of [**Constrained Parameter Regularization**](https://arxiv.org/abs/2311.09058).




## Install

```bash
pip install pytorch-cpr
```

## Getting started

### Usage of `apply_CPR` Optimizer Wrapper

The `apply_CPR` function applies CPR (Constrained Parameter Regularization) to a given optimizer: it first creates parameter groups and then wraps the actual optimizer class.

#### Arguments

- `model`: The PyTorch model whose parameters are to be optimized.
- `optimizer_cls`: The class of the optimizer to be used (e.g., `torch.optim.Adam`).
- `kappa_init_param`: Initial value for the kappa parameter in CPR; its meaning depends on the initialization method.
- `kappa_init_method` (default `'warm_start'`): The method to initialize the kappa parameter. Options include `'warm_start'`, `'uniform'`, and `'dependent'`.
- `reg_function` (default `'l2'`): The regularization function to be applied. Options include `'l2'` or `'std'`.
- `kappa_adapt` (default `False`): Flag to determine if kappa should adapt during training.
- `kappa_update` (default `1.0`): The rate at which kappa is updated in the Lagrangian method.
- `normalization_regularization` (default `False`): Flag to apply regularization to normalization layers.
- `bias_regularization` (default `False`): Flag to apply regularization to bias parameters.
- `embedding_regularization` (default `False`): Flag to apply regularization to embedding parameters.
- `**optimizer_args`: Additional optimizer arguments to pass to the optimizer class.
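
The three `*_regularization` flags decide which parameters end up in a regularized group. The sketch below illustrates one way such a split could be done in plain PyTorch. It is not the library's actual implementation: the helper `split_param_groups`, the list of normalization layer types, and the bias check by parameter name are all assumptions made for illustration.

```python
import torch
from torch import nn


def split_param_groups(model, normalization_regularization=False,
                       bias_regularization=False,
                       embedding_regularization=False):
    """Hypothetical helper (not part of pytorch-cpr): split trainable
    parameters into a regularized and an unregularized list."""
    # Assumed set of normalization layers; a real implementation may differ.
    norm_types = (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm)
    regularized, unregularized = [], []
    seen = set()  # guards against parameters shared between modules
    for module in model.modules():
        # recurse=False pairs each parameter with the module that owns it
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad or id(param) in seen:
                continue
            seen.add(id(param))
            if isinstance(module, norm_types):
                keep = normalization_regularization
            elif isinstance(module, nn.Embedding):
                keep = embedding_regularization
            elif name.endswith("bias"):
                keep = bias_regularization
            else:
                keep = True  # ordinary weights are regularized
            (regularized if keep else unregularized).append(param)
    return regularized, unregularized


model = nn.Sequential(
    nn.Embedding(10, 8), nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 2)
)
regularized, unregularized = split_param_groups(model)
print(len(regularized), len(unregularized))
```

With the default flags, only the two `nn.Linear` weight matrices in this toy model land in the regularized list; the embedding, the normalization parameters, and the biases do not.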

#### Example usage

```python
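import torch
from pytorch_cpr import apply_CPR  # the module name uses an underscore

model = YourModel()  # placeholder: any torch.nn.Module
# Extra keyword arguments (lr, betas) are passed on to torch.optim.Adam.
optimizer = apply_CPR(model, torch.optim.Adam, kappa_init_param=1000, kappa_init_method='warm_start',
                      lr=0.001, betas=(0.9, 0.98))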
```


## Run examples

We provide scripts to replicate the experiments from our paper. Please use a system with at least one GPU. Install the package and the requirements for the examples:

```bash
python3 -m venv venv
source venv/bin/activate
pip install -r examples/requirements.txt
pip install pytorch-cpr
```


### Modular Addition / Grokking Experiment

The grokking experiment should run within a few minutes. The results will be saved in the `grokking` folder.
To replicate the results in the paper, run variations with the following arguments:

####  For AdamW:
```bash
python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.1
```

####  For Adam + Rescaling:
```bash
python examples/train_grokking_task.py --optimizer adamw --weight_decay 0.0 --rescale 0.8
```

####  For AdamCPR with L2 norm as regularization function:
```bash
python examples/train_grokking_task.py --optimizer adamcpr --kappa_init_method dependent --kappa_init_param 0.8
```
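
To run the three grokking variations in one go, a small driver script can help. The following is a minimal sketch, assuming it is started from the repository root. The names `VARIANTS`, `build_command`, and `run_all` are hypothetical; the script path and flags are copied from the commands above. By default it only prints the commands (dry run) and executes nothing.

```python
import shlex
import subprocess
import sys

SCRIPT = "examples/train_grokking_task.py"

# Flags copied from the three commands above; the keys are only labels.
VARIANTS = {
    "adamw": ["--optimizer", "adamw", "--weight_decay", "0.1"],
    "adam_rescaling": ["--optimizer", "adamw", "--weight_decay", "0.0",
                       "--rescale", "0.8"],
    "adamcpr_l2": ["--optimizer", "adamcpr", "--kappa_init_method", "dependent",
                   "--kappa_init_param", "0.8"],
}


def build_command(variant):
    """Return the argv list for one variant, using the current interpreter."""
    return [sys.executable, SCRIPT, *VARIANTS[variant]]


def run_all(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for name in VARIANTS:
        cmd = build_command(name)
        print(f"[{name}] {shlex.join(cmd)}")
        if not dry_run:
            # check=True stops at the first failing run
            subprocess.run(cmd, check=True)


run_all(dry_run=True)
```

Calling `run_all(dry_run=False)` would launch the runs one after another.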



### Image Classification Experiment

The CIFAR-100 experiment should run within 20-30 minutes. The results will be saved in the `cifar100` folder.

####  For AdamW:
```bash
python examples/train_resnet.py --optimizer adamw --lr 0.001 --weight_decay 0.001
```

####  For Adam + Rescaling:
```bash
python examples/train_resnet.py --optimizer adamw --lr 0.001 --weight_decay 0 --rescale_alpha 0.8
```

####  For AdamCPR with L2 norm as regularization function and kappa initialization depending on the parameter initialization:
```bash
python examples/train_resnet.py --optimizer adamcpr --lr 0.001 --kappa_init_method dependent --kappa_init_param 0.8
```

####  For AdamCPR with L2 norm as regularization function and kappa initialization with warm start:
```bash
python examples/train_resnet.py --optimizer adamcpr --lr 0.001 --kappa_init_method warm_start --kappa_init_param 1000
```



## Citation

Please cite our paper if you use this code in your work:
    
```bibtex
@misc{franke2023cpr,
      title={Constrained Parameter Regularization}, 
      author={Jörg K. H. Franke and Michael Hefenbrock and Gregor Koehler and Frank Hutter},
      year={2023},
      eprint={2311.09058},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

