#!/usr/bin/python3
# The preprocessing pipeline of Rivulet2
import numpy as np
import argparse
import os
import random
import string
import subprocess
from pathlib import Path
from rivuletpy.utils.io import loadimg, writetiff3d
from filtering.thresholding import suppress, rescale
import skfmm
from scipy.ndimage.filters import gaussian_filter, median_filter
from scipy.ndimage.interpolation import zoom
from filtering.morphology import ssmdt

try:
    from skimage import filters
except ImportError:
    from skimage import filter as filters


def anisotropic_enhance(img):
    '''
    Enhance an image with the anisodiff_littlequick plugin of Vaa3D.

    The image is written to a randomly named temporary tif file in the
    current working directory, filtered by the Vaa3D plugin, converted from
    .raw back to tif with the Vaa3D convert_file_format plugin and loaded
    back into python. Every temporary file is removed before the function
    returns, also when one of the steps fails.

    To use this option, add the following 2 lines in ~/.bashrc
    export V3DPATH=PATH/2/Vaa3Droot;
    export LD_LIBRARY=$LD_LIBRARY:$V3DPATH;
    NOTE(review): the second variable is presumably meant to be
    LD_LIBRARY_PATH -- confirm against the Vaa3D setup instructions.

    Args:
        img: the image to enhance; handed unchanged to writetiff3d.

    Returns:
        The enhanced image, as returned by loadimg.

    Raises:
        KeyError: if the V3DPATH environment variable is not set.
        Exception: if Vaa3D did not produce one of the expected output files.
    '''
    # Resolve the executable first so that a missing V3DPATH fails before
    # any temporary file has been created.
    vaa3d = os.environ['V3DPATH'] + '/vaa3d'

    N = 10
    tmpfile = ''.join(
        random.SystemRandom().choice(string.ascii_uppercase + string.digits)
        for _ in range(N)) + '.tif'
    rawfile = tmpfile + '_anisodiff.raw'  # Output name checked after the plugin
    tmptif = rawfile + '.tif'

    try:
        writetiff3d(tmpfile, img)  # Save the img to a tmp path

        # The exit status is not inspected: the presence of the output file
        # is used as the success criterion.
        subprocess.call([
            vaa3d, '-x', 'anisodiff_littlequick', '-f',
            'anisodiff_littlequick_func', '-i', tmpfile
        ])
        if not Path(rawfile).is_file():
            raise Exception('V3D Anisotropic Diffusion Failed!')

        # Convert .raw to tif with v3d
        subprocess.call([
            vaa3d, '-x', 'convert_file_format', '-f', 'convert_format', '-i',
            rawfile, '-o', tmptif
        ])
        if not Path(tmptif).is_file():
            raise Exception('V3D File Format Conversion Failed!')

        # Load it back to python
        return loadimg(tmptif)
    finally:
        # Clean up whatever was produced, also on the failure paths.
        for leftover in (tmpfile, rawfile, tmptif):
            if os.path.isfile(leftover):
                os.remove(leftover)


if __name__ == '__main__':
    # Command line entry point: load an image, run the requested chain of
    # filters on it and write both the filtered image (<out>) and a
    # segmentation of it (<out without extension>.seg.tif).
    parser = argparse.ArgumentParser(
        description='Arguments to perform a simple preprocessing pipeline.')
    parser.add_argument(
        '-f',
        '--file',
        type=str,
        default=None,
        required=True,
        help='The input file. A image file (*.tif, *.nii, *.mat).')
    parser.add_argument(
        '-o',
        '--out',
        type=str,
        default=None,
        required=False,
        help='The name of the output file')

    # Arguments for filtering
    parser.add_argument(
        '-t',
        '--threshold',
        type=float,
        default=[0, ],
        nargs='+',
        help="""The thresohld to suppress the weak signals.
                  Voxels values  <= threshold will be set 0)."""
    )
    parser.add_argument(
        '--pipeline',
        type=str,
        default='TA',
        help="""A string with the pipeline components to run in order.
                     T for thresholding; A for anisotropic diffusion (VED);
                     G for Gaussian Filtering;
                     M for median filtering;
                     O for Otsu thresholding;
                     S for ssmdt on the distance transform of the foreground."""
    )
    parser.add_argument(
        '--sigma',
        type=float,
        default=[3., ],
        nargs='+',
        required=False,
        help='The sigma value used for gaussian filter. Default 3.0')
    parser.add_argument(
        '--median_size',
        type=float,
        default=[3., ],
        nargs='+',
        required=False,
        help='The window size used for median filter. Default 3.0')
    parser.add_argument(
        '--ssmiter',
        type=int,
        default=20,
        required=False,
        help='The number of iterations passed to ssmdt. Default 20')
    parser.add_argument(
        '-z',
        '--zoom_factor',
        type=float,
        default=1.,
        help="""The factor to zoom the image to speed up the whole thing.
                      Default 1."""
    )
    args = parser.parse_args()

    img = loadimg(args.file)
    if args.zoom_factor != 1:
        img = zoom(img, args.zoom_factor)  # Zoom to speed up
    imgtype = img.dtype  # Remembered so the outputs are saved with this dtype

    # Per-step parameters. list.pop() takes the LAST remaining value, so
    # multiple values given on the command line are consumed right-to-left;
    # once a list is exhausted the most recently used value is reused.
    lastthreshold = args.threshold[0]
    lastsigma = args.sigma[0]
    lastmediansz = args.median_size[0]
    for c in args.pipeline:
        if c == 'T':
            if args.threshold:
                lastthreshold = args.threshold.pop()
            img = suppress(img, lastthreshold)
        elif c == 'G':
            if args.sigma:
                lastsigma = args.sigma.pop()
            img = gaussian_filter(img, lastsigma)
        elif c == 'M':
            if args.median_size:
                lastmediansz = args.median_size.pop()
            # The option is parsed as float for backward compatibility, but
            # the filter window size has to be an integer.
            img = median_filter(img, int(lastmediansz))
        elif c == 'A':
            img = anisotropic_enhance(img)
        elif c == 'O':
            threshold = filters.threshold_otsu(img)
            img[img <= threshold] = 0
        elif c == 'S':
            # Distance transform of the foreground mask, then ssmdt on it;
            # every voxel left positive is set to the constant 180.
            dt = skfmm.distance((img > 0).astype('int'), dx=5e-2)
            img = ssmdt(dt, ssmiter=args.ssmiter)
            img[img > 0] = 180
        else:
            raise Exception(
                'Pipeline %s not defined. Valid options are T/G/M/A/O/S' % c)

    if args.out:
        outfile = args.out
    else:
        basename, _ = os.path.splitext(args.file)
        outfile = basename + '.pp.tif'

    # img = np.flipud(img) # Needed to show in v3d
    if args.zoom_factor != 1.:
        # NOTE(review): zooming back with the reciprocal factor may not give
        # exactly the input shape because of rounding -- confirm if it matters.
        img = zoom(img, 1 / args.zoom_factor)  # Zoom back

    img = np.floor(img)
    img = rescale(img)
    print('Saving Image to %s' % outfile)
    writetiff3d(outfile, img.astype(imgtype))

    # Segmentation output: every positive voxel is set to the constant 180.
    img[img > 0] = 180
    writetiff3d(os.path.splitext(outfile)[0] + '.seg.tif', img.astype(imgtype))
    print('--rpp finished')
