Metadata-Version: 2.1
Name: fodnet
Version: 1.0.1
Summary: FOD-Net Reimplementation.
Author: Matthew Lyon
Author-email: matthewlyon18@gmail.com
License: MIT License
Classifier: Programming Language :: Python
Classifier: Operating System :: Unix
Classifier: Operating System :: MacOS
Classifier: Operating System :: Microsoft :: Windows :: Windows 10
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: lightning>=2.0.0
Requires-Dist: numpy
Requires-Dist: einops
Requires-Dist: npy-patcher

# FODNet

FOD-Net reimplementation with training and inference pipeline. This module uses the FODNet model originally implemented [here](https://github.com/ruizengalways/FOD-Net).

If you use this code for your research, please cite:

FOD-Net: A Deep Learning Method for Fiber Orientation Distribution Angular Super Resolution.<br>
[Rui Zeng](https://sites.google.com/site/ruizenghomepage/), Jinglei Lv, He Wang, Luping Zhou, Michael Barnett, Fernando Calamante\*, Chenyu Wang\*. In [Medical Image Analysis](https://www.sciencedirect.com/science/article/abs/pii/S1361841522000822). (* equal contributions) [[Bibtex]](bib.txt).

## Requirements

This module requires the following Python packages:
- `torch >= 2.0.0`
- `lightning >= 2.0.0`
- `numpy`
- `einops`
- `npy-patcher`

These are installed automatically with this package. However, it is recommended to first install PyTorch independently by following its own installation instructions, so that the correct hardware optimizations are enabled.
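
To confirm which PyTorch build is active before installing this package, a quick check along the following lines may help. This is a minimal sketch: `describe_torch_install` is a hypothetical helper name, not part of `fodnet`, and it only distinguishes between CUDA and CPU.

```python
import importlib.util


def describe_torch_install() -> str:
    """Summarize the local PyTorch install without failing if it is absent."""
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed"
    import torch  # imported lazily so the check also runs without PyTorch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"torch {torch.__version__}, available accelerator: {device}"


print(describe_torch_install())
```

If the reported accelerator is `cpu` on a machine with a GPU, the installed PyTorch build likely lacks GPU support and may need to be reinstalled.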

## Installation

```
pip install fodnet
```

## Training

Follow the instructions below on how to train the FODNet model.


### Data Preprocessing

This training pipeline requires data to be saved in `.npy` format. Additionally, the spherical harmonic (SH) dimension must be the first dimension of each 4D array, because this module uses [npy-patcher](https://github.com/m-lyon/npy-cpp-patches) to extract training patches at runtime. Below is an example of how to convert `NIfTI` files into `.npy` using [nibabel](https://nipy.org/nibabel/).

```python
import numpy as np
import nibabel as nib

img = nib.load('/path/to/fod.nii.gz')
data = np.asarray(img.dataobj, dtype=np.float32)  # Load FOD data into memory
data = data.transpose(3, 0, 1, 2)  # Move the SH dimension to 0
np.save('/path/to/fod.npy', data, allow_pickle=False)  # Save in npy format. Ensure this is on an SSD.
```

**N.B.** *Training patches are read lazily from disk, so it is **highly** recommended to store the training data on an SSD; an HDD will bottleneck data loading during training.*
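
Since the SH dimension must come first, the reordering step can be wrapped in a small helper that also checks the input shape. The sketch below is illustrative only: `to_sh_first` is a hypothetical name, it assumes the SH axis is the last axis of the input (as in the conversion example above), and the 45 coefficients in the synthetic array are chosen purely for demonstration.

```python
import numpy as np


def to_sh_first(data: np.ndarray) -> np.ndarray:
    """Move the last (SH) axis of a 4D array to the front and return a C-contiguous copy."""
    if data.ndim != 4:
        raise ValueError(f"expected a 4D array, got {data.ndim}D")
    return np.ascontiguousarray(np.moveaxis(data, -1, 0))


# Synthetic stand-in for FOD data with shape (X, Y, Z, SH)
fod = np.zeros((4, 5, 6, 45), dtype=np.float32)
print(to_sh_first(fod).shape)  # (45, 4, 5, 6)
```

The result is equivalent to `data.transpose(3, 0, 1, 2)` and can be passed directly to `np.save`.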

### Training

```python
import lightning.pytorch as pl

from fodnet.core.model import FODNetLightningModel
from fodnet.core.dataset import Subject, FODNetDataModule

# Collect dataset filepaths
subj1 = Subject('/path/to/lowres_fod1.npy', '/path/to/highres_fod1.npy', '/path/to/mask1.npy')
subj2 = Subject('/path/to/lowres_fod2.npy', '/path/to/highres_fod2.npy', '/path/to/mask2.npy')
subj3 = Subject('/path/to/lowres_fod3.npy', '/path/to/highres_fod3.npy', '/path/to/mask3.npy')

# Create DataModule instance. This is a thin wrapper around `pl.LightningDataModule`.
data_module = FODNetDataModule(
    train_subjects=(subj1, subj2),
    val_subjects=(subj3,),  # Trailing comma makes this a one-element tuple
    batch_size=16,  # Batch size of each device
    num_workers=8,  # Number of CPU workers that load the data
)

# Load FODNet lightning model
model = FODNetLightningModel()

# Create `pl.Trainer` instance. `FODNetDataModule` is usable in DDP distributed training strategy.
trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=100)

# Start training
trainer.fit(model, data_module)
```
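
With many subjects, listing file paths by hand becomes tedious. The sketch below gathers them from disk instead. It assumes a hypothetical layout in which each subject has its own directory containing `lowres_fod.npy`, `highres_fod.npy`, and `mask.npy`; these file names and the `collect_subject_paths` helper are illustrative and not part of `fodnet`.

```python
from pathlib import Path
from typing import List, Tuple

# Hypothetical file names expected inside each subject directory
FILENAMES = ("lowres_fod.npy", "highres_fod.npy", "mask.npy")


def collect_subject_paths(root: str) -> List[Tuple[str, ...]]:
    """Return (lowres, highres, mask) path triples, one per complete subject directory."""
    triples = []
    for subj_dir in sorted(Path(root).iterdir()):
        if not subj_dir.is_dir():
            continue
        paths = [subj_dir / name for name in FILENAMES]
        if all(p.is_file() for p in paths):  # skip subjects with missing files
            triples.append(tuple(str(p) for p in paths))
    return triples


# Usage with the classes shown above:
# subjects = tuple(Subject(*paths) for paths in collect_subject_paths('/path/to/dataset'))
```

Each triple follows the same argument order as the `Subject` calls in the training example, so it can be unpacked directly into the constructor.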

#### Customization

This implementation uses a different training optimizer, loss, and learning rate from those used in the [original implementation](https://github.com/ruizengalways/FOD-Net). In particular, we use `AdamW`, `L1 Loss`, and `0.003` respectively.

Changing these hyperparameters is straightforward. Simply create a new class that inherits from `FODNetLightningModel` and override the properties/methods below. Use this class instead of `FODNetLightningModel` when training.

```python
import torch

from fodnet.core.model import FODNetLightningModel


class MyCustomModel(FODNetLightningModel):

    @property
    def loss_func(self):
        '''Different loss function'''
        return torch.nn.functional.mse_loss
    
    def configure_optimizers(self):
        '''Different Optimizer and learning rate'''
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-5)
        return optimizer
```

## Prediction

Coming soon.
