#!/usr/bin/python3
import argparse
import json
import os
import shlex
import subprocess
import threading
import time

import numpy as np
from rivuletpy.utils.io import loadswc, saveswc, loadimg, writetiff3d
from scipy.ndimage.interpolation import zoom


class imgThread(threading.Thread):
    """Worker thread that runs ``doimg`` on a single image.

    The thread holds one permit of ``sema`` while ``doimg`` is running, so
    the semaphore's initial value bounds how many images are processed at
    the same time.

    Args:
        imgid: Identifier of the image; concatenated into log messages.
        img: Index entry of the image, forwarded to ``doimg``.
        method: Name of the method, forwarded to ``doimg``.
        root: Root directory, forwarded to ``doimg``.
        sema: Semaphore shared by all workers to limit concurrency.
        centos: Forwarded to ``doimg``.
        zoom_factor: Forwarded to ``doimg``.
    """

    def __init__(self,
                 imgid,
                 img,
                 method,
                 root,
                 sema,
                 centos=False,
                 zoom_factor=1.):
        threading.Thread.__init__(self)
        self.imgid = imgid
        self.img = img
        self.method = method
        self.root = root
        self.sema = sema
        self.centos = centos
        self.zoom_factor = zoom_factor
        # Stays 0 until ``doimg`` returns successfully.
        self._time_passed = 0

    def run(self):
        """Run ``doimg`` under the semaphore and remember its return value."""
        print("Starting " + self.method + ' on ' + self.name)
        # ``with`` releases the permit even when doimg raises; a leaked
        # permit would leave the remaining workers blocked forever.
        with self.sema:
            self._time_passed = doimg(self.imgid, self.img, self.method,
                                      self.root, self.centos,
                                      self.zoom_factor)
        print('Exiting ' + self.name + ' with image ' + self.imgid)

    def join(self, timeout=None):
        """Wait for the thread and return the value recorded by ``run``.

        Args:
            timeout: Optional timeout in seconds, as in
                ``threading.Thread.join``.

        Returns:
            The value returned by ``doimg``, or 0 if ``run`` has not
            completed successfully.
        """
        threading.Thread.join(self, timeout)
        return self._time_passed


def makepp(json_path):
    """Write the ``.pp.json`` companion of an index file.

    Every image entry of the index gets its ``imagepath`` extension replaced
    by ``.pp.tif`` and its ``misc.threshold`` set to 0. The result is saved
    next to ``json_path`` with the extension replaced by ``.pp.json``; the
    original index file is left untouched.

    Args:
        json_path: Path of the index file to read.
    """
    with open(json_path) as f:
        dataidx = json.load(f)

    for dataset in dataidx["data"].values():
        for entry in dataset.values():
            stem, _ = os.path.splitext(entry["imagepath"])
            entry["imagepath"] = stem + '.pp.tif'
            # ``misc`` is optional in the index, so create it on demand
            # instead of failing with a KeyError.
            entry.setdefault("misc", {})["threshold"] = 0

    jpath, _ = os.path.splitext(json_path)
    jpath += '.pp.json'

    with open(jpath, 'w') as f:
        json.dump(dataidx, f)


def matlab_cmd(matlab_script):
    """Return the command line that passes ``matlab_script`` to ``matlab -r``.

    The script is wrapped in double quotes and appended to the
    ``matlab -nodesktop -nosplash -r`` invocation.
    """
    launcher = "matlab -nodesktop -nosplash -r "
    quoted_script = " \"%s\"" % matlab_script
    return launcher + quoted_script


def doimg(imgid, img, method, root, centos=False, zoom_factor=1.):
    """Run one external tracing/preprocessing tool on one image and time it.

    Args:
        imgid: Identifier of the image; only used in log output.
        img: Index entry of the image. ``imagepath`` is required. ``out``
            (output path) and ``misc`` (``threshold`` plus optional
            per-method parameter dicts keyed by method name) are optional.
        method: One of 'rivulet2', 'rivulet1', 'rpp', 'app2' or 'smart'.
        root: Directory that ``imagepath`` and ``out`` are relative to.
        centos: When true, Vaa3D is started behind an ``Xvfb`` display.
        zoom_factor: When different from 1, the image is resampled with
            ``zoom`` and written next to the original before the tool runs;
            the resulting swc columns 2:5 are scaled back afterwards
            (not for 'rpp').

    Returns:
        Wall-clock seconds spent in the external command.

    Raises:
        NotImplementedError: If ``method`` is not one of the names above.
        KeyError: If ``imagepath`` is missing, or if the environment
            variable needed by the method is not set (``V3DPATH`` for
            'app2'/'smart', ``RNTPATH`` for 'rivulet1').
    """
    print('---- Working on image', imgid, 'with method ', method)

    # Reject unknown methods before doing any expensive image work.
    if method not in ('rivulet2', 'rivulet1', 'rpp', 'app2', 'smart'):
        raise NotImplementedError('Method %s not implemented yet' % method)

    # ``misc`` is optional; a missing section means threshold 0 and no
    # per-method parameters.
    misc = img.get('misc') or {}
    threshold = misc.get('threshold', 0)

    if zoom_factor == 1.:
        inpath = img['imagepath']
    else:
        # Keep ``inpath`` relative to ``root`` so that joining with ``root``
        # below is also correct when ``root`` itself is a relative path.
        inpath = img['imagepath'] + '.%.2fzoom.tif' % zoom_factor
        imgmat = loadimg(os.path.join(root, img['imagepath']))
        zoomed = zoom(imgmat, zoom_factor)
        writetiff3d(os.path.join(root, inpath), zoomed.astype(imgmat.dtype))
    infile = os.path.join(root, inpath)

    if method == 'rpp':
        out = os.path.splitext(img['out'] if 'out' in img else img[
            'imagepath'])[0] + '.pp.tif'
    else:
        # The smart tracing plugin totally ignores the -o option, so the
        # default name mirrors the name the plugin chooses itself.
        # NOTE(review): with zoom_factor != 1 the plugin is given the zoomed
        # file, so its own output name is presumably derived from that file
        # -- confirm.
        suffix = {
            'rivulet2': '.r2.swc',
            'rivulet1': '.r1.swc',
            'app2': '.app2.swc',
            'smart': '_smartTracing.swc',
        }[method]
        out = img['out'] if 'out' in img else img['imagepath'] + suffix
    outfile = os.path.join(root, out)

    def vaa3d_launcher():
        # Only the Vaa3D based methods need V3DPATH, so look it up lazily.
        vaa3d = shlex.quote('%s/vaa3d' % os.environ['V3DPATH'])
        if centos:
            return "export DISPLAY=:30;Xvfb :30 -auth /dev/null & " + vaa3d
        return vaa3d

    # Paths are quoted because the command is run through the shell;
    # shlex.quote leaves plain paths unchanged.
    if method == 'rivulet2':
        cmd = 'python3 rivulet2 --file %s --out %s' % (shlex.quote(infile),
                                                       shlex.quote(outfile))
        cmd += ' --threshold ' + str(threshold)
        cmd += ' --no-msfm '  # For now
        cmd += ' --silence --no-clean'
        if method in misc:
            params = misc[method]
            cmd += ' --speed %s --ssmiter %d' % (
                shlex.quote(str(params.get('speed', 'dt'))),
                params.get('ssmiter', 40))
    elif method == 'rivulet1':
        rntpath = os.environ['RNTPATH']
        matlab_script = "cd %s; addpath(genpath('%s'));" % (rntpath, rntpath)
        matlab_script += "I=loadtif('%s'); I=I>%d;" % (infile, threshold)
        matlab_script += 'swc=trace(I, false, 0.98, false, 15, false, true, 1.3, 8, false, 0, false, false);'
        matlab_script += "saveswc(swc, '%s');" % outfile
        matlab_script += "exit;"
        print('Matlab script to run:', matlab_script)
        cmd = matlab_cmd(matlab_script)
    elif method == 'rpp':
        cmd = 'python3 rpp --file %s --out %s' % (shlex.quote(infile),
                                                  shlex.quote(outfile))
        if method in misc:
            params = misc[method]
            cmd += """ --sigma %f --median_size %f --threshold %f --pipeline %s --ssmiter %d""" % (
                params.get('sigma', 1.0), params.get('median_size', 2.0),
                threshold, shlex.quote(str(params.get('pipeline', 'T'))),
                params.get('ssmiter', 20))
    elif method == 'app2':
        cmd = vaa3d_launcher() + ' -x vn2 -f app2 -i ' + shlex.quote(
            infile) + ' -o ' + shlex.quote(outfile) + ' -p  NULL 0 0 ' + str(
                threshold) + ' 1 1 1 4 0 0 0'
    else:  # method == 'smart'
        cmd = vaa3d_launcher() + ' -x smartTrace -f smartTrace -i ' + \
            shlex.quote(infile) + ' -o ' + shlex.quote(outfile)

    print(cmd)
    start_time = time.time()
    completed = subprocess.run(cmd, shell=True)
    time_passed = time.time() - start_time
    if completed.returncode != 0:
        print('---- Warning: command for image', imgid,
              'exited with status', completed.returncode)

    # When it's finished, zoom back the swc
    if method != "rpp" and zoom_factor != 1.:
        swc = loadswc(outfile)
        swc[:, 2:5] *= 1. / zoom_factor
        saveswc(outfile, swc)

    return time_passed


def _remove_intermediates(root):
    """Delete regular files below ``root`` whose name contains 'ini' or 'tmp'."""
    for dirpath, _dirnames, filenames in os.walk(root):
        for fname in filenames:
            if 'ini' not in fname and 'tmp' not in fname:
                continue
            fullpath = os.path.join(dirpath, fname)
            # Only plain files are removed; symlinks and special files stay.
            if os.path.islink(fullpath) or not os.path.isfile(fullpath):
                continue
            try:
                os.remove(fullpath)
            except OSError as err:
                # Cleanup is best effort and must not abort the run.
                print('-- Could not remove %s: %s' % (fullpath, err))


def tracejsonfile(fpath,
                  pbsfile=None,
                  method='rivulet2',
                  nthread=1,
                  centos=False,
                  zoom_factor=1.):
    """Run ``method`` on every image listed in a json index and print timings.

    One ``imgThread`` is started per image; at most ``nthread`` of them run
    at the same time. Afterwards intermediate files below the root directory
    are removed, and for the 'rpp' method the ``.pp.json`` index is written
    with ``makepp``.

    Args:
        fpath: Path of the json index. It must contain ``rootpath`` and
            ``data`` (dataset name -> image id -> image entry). A relative
            ``rootpath`` is resolved against the directory of ``fpath``.
        pbsfile: Accepted for interface compatibility; not used here.
        method: Name of the method, forwarded to each worker.
        nthread: Maximum number of images processed concurrently; values
            below 1 are treated as 1.
        centos: Forwarded to each worker.
        zoom_factor: Forwarded to each worker.
    """
    # A semaphore value below 1 would block every worker forever.
    sema = threading.BoundedSemaphore(value=max(1, nthread))

    with open(fpath) as f:
        dataidx = json.load(f)

    root = dataidx['rootpath']
    if not os.path.isabs(root):  # Make the root path absolute if it is not yet
        root = os.path.join(os.path.dirname(os.path.abspath(fpath)), root)

    thpool = []
    for datasetname, dataset in dataidx['data'].items():
        print('-- Working on dataset', datasetname)
        for imgid, img in dataset.items():
            th = imgThread(imgid, img, method, root, sema, centos, zoom_factor)
            th.start()
            thpool.append(th)
    times = [th.join() for th in thpool]

    # Remove the intermediate files left behind by the external tools.
    # NOTE(review): this matches every file name containing 'ini' or 'tmp'
    # below root, which may include unrelated files -- confirm this is
    # intended.
    _remove_intermediates(root)

    # Use the ``method`` argument rather than the script-level ``args`` so
    # that the function also works when called from other modules.
    if method == 'rpp':
        makepp(fpath)

    print('--rjson finished')
    if not times:
        # Nothing was processed, so there are no statistics to report.
        print('--No images were processed')
        return

    total_time = sum(times)
    print('--Total time running algorithm: %f seconds' % total_time)
    print('--Mean time running algorithm: %f seconds' % (total_time /
                                                         len(times)))
    print('--Std time running algorithm: %f seconds' %
          np.std(np.asarray(times)))
    print('--Min time running algorithm: %f seconds' % min(times))
    print('--Max time running algorithm: %f seconds' % max(times))


def build_arg_parser():
    """Create the command-line parser of this script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--file``,
        ``--run``, ``--pbsfile``, ``--thread``, ``--zoom_factor`` and the
        ``--centos`` / ``--no-centos`` switch.
    """
    parser = argparse.ArgumentParser(
        description='Arguments for tracing files index with a .json file.')
    parser.add_argument(
        '-f',
        '--file',
        type=str,
        default=None,
        required=True,
        help='The input json file. Pls see ./tests/data/test.json for example.')
    parser.add_argument(
        '-x',
        '--run',
        type=str,
        default='rivulet2',
        required=False,
        # Reject unknown methods up front instead of failing inside every
        # worker thread.
        choices=('rivulet2', 'rivulet1', 'rpp', 'app2', 'smart'),
        help='The method to use (rivulet2/rivulet1/rpp/app2/smart) ')
    parser.add_argument(
        '--pbsfile',
        type=str,
        default=None,
        required=False,
        help='The pbs file to run (currently unused).')
    parser.add_argument(
        '--thread',
        type=int,
        default=1,
        required=False,
        help='The number of threads to parallelise the tasks. Default 1')
    parser.add_argument(
        '-z',
        '--zoom_factor',
        type=float,
        default=1.,
        help="""The factor to zoom the image to speed up the whole thing.
                     The result swc will be zoomed back according
                     to the outfile name. Default 1.""")
    parser.add_argument('--centos', dest='centos', action='store_true')
    parser.add_argument('--no-centos', dest='centos', action='store_false')
    parser.set_defaults(centos=False)
    return parser


if __name__ == '__main__':
    # ``args`` stays a module-level name because other code in this file
    # may read it.
    args = build_arg_parser().parse_args()
    tracejsonfile(
        args.file,
        args.pbsfile,
        method=args.run,
        nthread=args.thread,
        centos=args.centos,
        zoom_factor=args.zoom_factor)
