#!/usr/bin/python3
import os
import argparse
import math
import time
import numpy as np
from rivuletpy.trace import R2Tracer
from rivuletpy.utils.io import loadimg, writetiff3d, crop
from filtering.thresholding import rescale
from scipy.ndimage.interpolation import zoom

def show_logo():
    """Print the Rivulet2 welcome banner followed by the ASCII-art logo."""
    header = "====================Welcome to Rivulet2=================================="
    # Each row is quoted explicitly so that the trailing spaces, which are part
    # of the printed banner, stay visible and survive whitespace-stripping editors.
    art_rows = (
        '8888888b.  d8b                   888          888           .d8888b.  ',
        '888   Y88b Y8P                   888          888          d88P  Y88b ',
        '888    888                       888          888                 888 ',
        '888   d88P 888 888  888 888  888 888  .d88b.  888888            .d88P ',
        '8888888P"  888 888  888 888  888 888 d8P  Y8b 888           .od888P"  ',
        '888 T88b   888 Y88  88P 888  888 888 88888888 888          d88P"      ',
        '888  T88b  888  Y8bd8P  Y88b 888 888 Y8b.     Y88b.        888"       ',
        '888   T88b 888   Y88P    "Y88888 888  "Y8888   "Y888       888888888  ',
    )
    # A blank line separates the header from the art; four newlines trail the
    # logo (print appends one more).
    banner = header + "\n\n" + "\n".join(art_rows) + "\n\n\n\n"
    print(banner)

def build_parser():
    """Create the command-line parser for the Rivulet2 tracing script.

    Returns:
        argparse.ArgumentParser: parser exposing the input/output, threshold,
        zoom, soma, tracing and verbosity options used by ``main``.
    """
    parser = argparse.ArgumentParser(
        description='Arguments to perform the Rivulet2 tracing algorithm.')
    parser.add_argument(
        '-f',
        '--file',
        type=str,
        default=None,
        required=True,
        help='The input file. A image file (*.tif, *.nii, *.mat).')
    parser.add_argument(
        '-o',
        '--out',
        type=str,
        default=None,
        required=False,
        help='The name of the output file')
    parser.add_argument(
        '-t',
        '--threshold',
        type=float,
        default=0,
        help='threshold to distinguish the foreground and background. Defulat 0. If threshold<0, otsu will be used.'
    )
    parser.add_argument(
        '-z',
        '--zoom_factor',
        type=float,
        default=1.,
        help='The factor to zoom the image to speed up the whole thing. Default 1.')

    # Arguments for soma detection
    parser.add_argument('--save-soma', dest='save_soma', action='store_true',
        help="Save the automatically reconstructed soma volume along with the SWC.")
    parser.add_argument('--no-save-soma', dest='save_soma', action='store_false',
        help="Don't save the automatically reconstructed soma volume along with the SWC (default)")
    parser.set_defaults(save_soma=False)

    # NOTE(review): ``soma_detection`` is parsed but never read by this
    # script; confirm whether it should be forwarded to the tracer.
    parser.add_argument('--soma', dest='soma_detection', action='store_true',
        help="Use the morphological operator based soma detection")
    parser.add_argument('--no-soma', dest='soma_detection', action='store_false',
        help="Don't use the morphological operator based soma detection (default)")
    parser.set_defaults(soma_detection=False)

    # Args for tracing
    parser.add_argument(
        '--speed',
        type=str,
        default='dt',
        help='The type of speed image to use (dt, ssm). dt would work for most of the cases. ssm provides slightly better curves with extra computing time')

    parser.add_argument('--quality', dest='quality', action='store_true',
        help="Reconstruct the neuron with higher quality and slightly more computing time")
    # The help text used to say "more computing time" here as well, which
    # contradicted the --quality description.
    parser.add_argument('--no-quality', dest='quality', action='store_false',
        help="Reconstruct the neuron with lower quality and slightly less computing time (default)")
    parser.set_defaults(quality=False)

    parser.add_argument('--clean', dest='clean', action='store_true',
        help="Remove the unconnected segments (default). It is relatively safe to do with the Rivulet2 algorithm")
    parser.add_argument('--no-clean', dest='clean', action='store_false',
        help="Keep the unconnected segments")
    parser.set_defaults(clean=True)

    # MISC
    parser.add_argument('--silent', dest='silent', action='store_true', help="Omit the terminal outputs")
    parser.add_argument('--no-silent', dest='silent', action='store_false', help="Show the terminal outputs & the nice logo (default)")
    parser.set_defaults(silent=False)

    return parser


def output_paths(infile, out=None):
    """Derive the output file names for a tracing run.

    Args:
        infile: path of the input image.
        out: explicit SWC output path; when empty or ``None`` the SWC is
            written next to ``infile`` with its extension replaced by
            ``.r2.swc``.

    Returns:
        tuple: ``(swc_path, soma_path)`` where ``soma_path`` is ``swc_path``
        with its extension replaced by ``soma.tif``.
    """
    swc_path = out if out else os.path.splitext(infile)[0] + '.r2.swc'
    soma_path = os.path.splitext(swc_path)[0] + 'soma.tif'
    return swc_path, soma_path


def main(argv=None):
    """Run the Rivulet2 pipeline: load, crop, optionally zoom, trace, save.

    Args:
        argv: list of command-line tokens; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    if not args.silent:
        show_logo()

    starttime = time.time()
    img = loadimg(args.file)

    if not args.silent:
        print('The shape of the image is', img.shape)
    # TODO: Modify the crop function so that it can crop somamask as well
    img, crop_region = crop(img, args.threshold)  # Crop by default

    if args.zoom_factor != 1.:
        if not args.silent:
            print('-- Zooming image to %.2f of original size' % args.zoom_factor)
        img = zoom(img, args.zoom_factor)

    # Resolve both output names once so the SWC and the soma mask agree.
    # The soma branch previously referenced an undefined ``outswcfile``,
    # raising NameError only after the whole trace had finished.
    outswcfile, somafile = output_paths(args.file, args.out)

    # Run rivulet2 for the first time
    tracer = R2Tracer(quality=args.quality, silent=args.silent, speed=args.speed, clean=args.clean)
    swc, soma = tracer.trace(img, args.threshold)
    # Presumably maps the traced coordinates back to the uncropped, unzoomed
    # image space -- confirm against the SWC implementation.
    swc.reset(crop_region, args.zoom_factor)
    swc.save(outswcfile)

    # Save the soma mask if required
    if args.save_soma:
        soma.pad(crop_region)
        soma.save(somafile)

    print('-- Finshed: %.2f sec.' % (time.time() - starttime))


if __name__ == '__main__':
    main()