#!/usr/bin/env python

import nibabel as nib
from deepbrain import Extractor
import sys
import time
import argparse
import numpy as np
import os


def load_arguments():
    """Parse the command-line options of the skull-stripping script.

    Two string options are declared, each with a short and a long flag:
    ``-i/--input`` and ``-o/--output``. Neither is marked as required at the
    parser level, so an option that is not supplied comes back as ``None``.

    Returns:
        argparse.Namespace: exposes the parsed values as ``input`` and
        ``output``.
    """
    cli = argparse.ArgumentParser(
        description="Extract brain tissue from wT1 MRI (i.e. skull stripping)"
    )

    # (short flag, long flag, help text), registered in declaration order.
    option_table = (
        ("-i", "--input", "Path to nifti brain image"),
        (
            "-o",
            "--output",
            "Directory where the script will output the mask and the brain nifti images",
        ),
    )
    for short_flag, long_flag, help_text in option_table:
        cli.add_argument(short_flag, long_flag, help=help_text)

    return cli.parse_args()


def run(img_path, output_dir, p=0.5):
    """Skull-strip one image and write the mask and the masked brain to disk.

    The image at ``img_path`` is loaded with nibabel, its voxel data is passed
    to a ``deepbrain.Extractor``, and the extractor's output is thresholded
    with ``> p`` to obtain a binary mask. Two files are written into
    ``output_dir``, both reusing the input image's affine:

    * ``brain_mask.nii`` -- the mask stored as ``uint8`` (1 inside, 0 outside).
    * ``brain.nii`` -- the input voxel data with every voxel outside the mask
      set to 0.

    The extraction time is printed to stdout.

    Args:
        img_path: Path of the image to load with ``nib.load``.
        output_dir: Directory receiving the two output files. It is not
            created here, so it must already exist.
        p: Threshold applied to the extractor output; values strictly greater
            than ``p`` are kept. Defaults to 0.5.
    """
    image = nib.load(img_path)
    affine = image.affine
    data = image.get_fdata()

    extractor = Extractor()

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments the way a time.time() difference can.
    started = time.perf_counter()
    # NOTE(review): presumably a per-voxel brain-probability map with the same
    # shape as `data` -- confirm against the deepbrain documentation.
    prob = extractor.run(data)
    print("Extraction time: {0:.2f} secs.".format(time.perf_counter() - started))

    mask = prob > p
    mask_image = nib.Nifti1Image(mask.astype(np.uint8), affine)
    nib.save(mask_image, os.path.join(output_dir, "brain_mask.nii"))

    # Basic slicing (`data[:]`) only yields a view, so zeroing it would also
    # overwrite the loaded voxel data; work on an explicit copy instead.
    brain = data.copy()
    brain[~mask] = 0
    brain_image = nib.Nifti1Image(brain, affine)
    nib.save(brain_image, os.path.join(output_dir, "brain.nii"))


if __name__ == "__main__":
    # Script entry point: parse the options, insist that both were supplied,
    # make sure the output directory exists, then run the extraction.
    args = load_arguments()

    # Checked in order; the first missing option prints its message to stdout
    # and terminates the process with exit status 1.
    required_options = (
        (args.input, "Should specify input. See -h for instructions."),
        (args.output, "Should specify output. See -h for instructions."),
    )
    for supplied_value, complaint in required_options:
        if not supplied_value:
            print(complaint)
            sys.exit(1)

    # exist_ok=True makes re-running against an existing directory harmless.
    os.makedirs(args.output, exist_ok=True)
    run(args.input, args.output)
