#!/usr/bin/python3
import os
import argparse
import math
import time
import numpy as np
from rivuletpy.trace import R2Tracer
from rivuletpy.utils.io import loadimg, writetiff3d, crop
from filtering.thresholding import rescale
from scipy.ndimage.interpolation import zoom

def show_logo():
    """Print the Rivulet2 welcome banner (a title line plus ASCII art) to stdout."""
    header = "====================Welcome to Rivulet2=================================="
    # Each row is quoted explicitly so trailing padding spaces are preserved.
    art_rows = [
        "8888888b.  d8b                   888          888           .d8888b.  ",
        "888   Y88b Y8P                   888          888          d88P  Y88b ",
        "888    888                       888          888                 888 ",
        "888   d88P 888 888  888 888  888 888  .d88b.  888888            .d88P ",
        "8888888P\"  888 888  888 888  888 888 d8P  Y8b 888           .od888P\"  ",
        "888 T88b   888 Y88  88P 888  888 888 88888888 888          d88P\"      ",
        "888  T88b  888  Y8bd8P  Y88b 888 888 Y8b.     Y88b.        888\"       ",
        "888   T88b 888   Y88P    \"Y88888 888  \"Y8888   \"Y888       888888888  ",
    ]
    # Two blank lines separate the header from the art; four newlines follow it.
    banner = header + "\n\n" + "\n".join(art_rows) + "\n\n\n\n"
    print(banner)

def build_parser():
    """Build the command-line argument parser for the Rivulet2 tracing script.

    Returns:
        argparse.ArgumentParser: parser accepting every option used by ``main``.
    """
    parser = argparse.ArgumentParser(
        description='Arguments to perform the Rivulet2 tracing algorithm.')
    parser.add_argument(
        '-f',
        '--file',
        type=str,
        default=None,
        required=True,
        help='The input file. An image file (*.tif, *.nii, *.mat).')
    parser.add_argument(
        '-o',
        '--out',
        type=str,
        default=None,
        required=False,
        help='The name of the output file')
    parser.add_argument(
        '-t',
        '--threshold',
        type=float,
        default=0,
        help='threshold to distinguish the foreground and background. Default 0. If threshold<0, otsu will be used.'
    )
    parser.add_argument(
        '-z',
        '--zoom_factor',
        type=float,
        default=1.,
        help='The factor to zoom the image to speed up the whole thing. Default 1.')

    # Arguments for soma detection
    parser.add_argument('--save-soma', dest='save_soma', action='store_true')
    parser.add_argument('--no-save-soma', dest='save_soma', action='store_false')
    parser.set_defaults(save_soma=False)

    parser.add_argument('--soma', dest='soma_detection', action='store_true')
    parser.add_argument('--no-soma', dest='soma_detection', action='store_false')
    parser.set_defaults(soma_detection=False)

    # Args for tracing
    parser.add_argument(
        '--speed',
        type=str,
        default='dt',
        help='The type of speed image to use (dt, ssm)')

    parser.add_argument('--quality', dest='quality', action='store_true')
    parser.add_argument('--no-quality', dest='quality', action='store_false')
    parser.set_defaults(quality=False)

    parser.add_argument('--clean', dest='clean', action='store_true')
    parser.add_argument('--no-clean', dest='clean', action='store_false')
    parser.set_defaults(clean=True)

    # MISC
    parser.add_argument('--silent', dest='silent', action='store_true')
    parser.add_argument('--no-silent', dest='silent', action='store_false')
    parser.set_defaults(silent=False)

    return parser


def main(argv=None):
    """Run the Rivulet2 tracing pipeline from the command line.

    Steps:
    1. Load the input image and crop it using ``--threshold``.
    2. Optionally zoom it by ``--zoom_factor``.
    3. Trace it with ``R2Tracer`` and write the SWC file, either to ``--out``
       or to ``<input stem>.r2.swc``.
    4. With ``--save-soma``, pad the soma mask with the crop region and save
       it as ``<swc stem>soma.tif``.

    Args:
        argv: Argument list to parse; ``None`` uses ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    if not args.silent:
        show_logo()

    starttime = time.time()
    img = loadimg(args.file)

    if not args.silent:
        print('The shape of the image is', img.shape)
    # TODO: modify the crop function so that it can crop the soma mask as well.
    img, crop_region = crop(img, args.threshold)  # Crop by default

    if args.zoom_factor != 1.:
        if not args.silent:
            print('-- Zooming image to %.2f of original size' % args.zoom_factor)
        img = zoom(img, args.zoom_factor)

    # Run rivulet2 for the first time.
    # NOTE(review): --soma/--no-soma (args.soma_detection) is parsed but never
    # passed on to the tracer here; confirm whether that is intended.
    tracer = R2Tracer(quality=args.quality, silent=args.silent, speed=args.speed, clean=args.clean)
    swc, soma = tracer.trace(img, args.threshold)
    # Presumably maps the nodes back through the crop/zoom; reset() lives in rivuletpy.
    swc.reset(crop_region, args.zoom_factor)

    # Resolve the output path once, because the soma mask filename derives from it.
    outswcfile = args.out if args.out else os.path.splitext(args.file)[0] + '.r2.swc'
    swc.save(outswcfile)

    # Save the soma mask if required.
    if args.save_soma:
        # NOTE(review): when zoom_factor != 1 the mask may still be in zoomed
        # coordinates while crop_region is in original ones; verify.
        soma.pad(crop_region)
        soma.save(os.path.splitext(outswcfile)[0] + 'soma.tif')

    print('-- Finished: %.2f sec.' % (time.time() - starttime))


if __name__ == '__main__':
    main()