#!python

import argparse
import sys
from os.path import abspath, expanduser
import os
import warnings
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
sys.path.insert(0, os.path.abspath('..'))
from specdal.collection import Collection, proximal_join

parser = argparse.ArgumentParser(description='SpecDAL Pipeline')
# io options
parser.add_argument('-i', '--input_dir', metavar='PATH', default='./',
                    action='store',
                    help='input directory containing input files')
parser.add_argument('-o', '--output_dir', metavar='PATH',
                    default='./specdal_output', action='store',
                    help='output directory to store the results')
parser.add_argument('-of', '--output_figures', action='store_true',
                    help='output figures')
parser.add_argument('-od', '--output_data', action='store_true',
                    help='output data')
parser.add_argument('-oi', '--output_individual', action='store_true',
                    help='output individual spectrum')
parser.add_argument('-oa', '--output_aggregates', action='store_true',
                    help='output group aggregates')
# optional arguments
parser.add_argument('-n', '--name', type=str, action='store',
                    default='dataset', help='name of the dataset')
# NOTE: the single-valued options below deliberately omit nargs=1. With
# nargs=1 argparse stores a one-element *list* for a user-supplied value while
# the default stays a scalar, so downstream code would receive either `5` or
# `[5]` depending on whether the flag was given.
# resampler
parser.add_argument('-r', '--resample', default=None,
                    choices=['slinear', 'cubic'],
                    help='interpolation method')
parser.add_argument('-rs', '--resample_spacing', metavar='SPC',
                    type=int, default=1,
                    help='spacing for resampling (in nm)')
# overlap stitcher
parser.add_argument('-s', '--stitch', default=None,
                    choices=['mean', 'median', 'min', 'max'],
                    help='overlap stitching method')
# jump corrector
parser.add_argument('-j', '--jump_correct', default=None,
                    choices=['additive'],
                    help='jump correction method')
parser.add_argument('-js', '--jump_correct_splices', metavar='WVL',
                    default=[1000, 1800], type=int, nargs='+',
                    help='wavelengths of jump locations')
parser.add_argument('-jr', '--jump_correct_reference', metavar='REF',
                    type=int, default=0,
                    help='position of the reference detector')
# groupby
parser.add_argument('-g', '--group_by', action='store_true',
                    help='create groups using filenames')
parser.add_argument('-gs', '--group_by_separator',
                    metavar='S', default='_',
                    help='separator sequence to split the file names')
parser.add_argument('-gi', '--group_by_indices', metavar='I', nargs='*', type=int,
                    help='indices of the split filenames to define a group')
# The three aggregate flags share dest='aggr'. argparse seeds the namespace
# from the first action registered for a dest, so args.aggr is [] when none
# of them is given and e.g. ['mean', 'std'] in command-line order otherwise.
parser.add_argument('--group_mean', dest='aggr', action='append_const',
                    default=[],
                    const='mean', help='calculate group means')
parser.add_argument('--group_median', dest='aggr', action='append_const',
                    const='median', help='calculate group median')
parser.add_argument('--group_std', dest='aggr', action='append_const',
                    const='std', help='calculate group standard deviation')
# Only the aggregates defined above can ever be appended; restricting the
# choices turns a typo into an error instead of a silent no-op.
parser.add_argument('--group_append', dest='aggr_append', nargs='+', default=[],
                    choices=['mean', 'median', 'std'], metavar='AGGR',
                    help='append aggregates (mean, median, std) to group figures')
# misc
parser.add_argument('-d', '--debug', action='store_true',
                    help='debug mode: print parsed arguments and remove '
                    'existing output before running')
parser.add_argument('-q', '--quiet', default=False, action='store_true',
                    help='suppress progress messages')
parser.add_argument('--proximal', default=None, metavar='PATH', action='store',
                    help='path containing base measurement spectral files')

args = parser.parse_args()
if args.debug:
    print('args = {}'.format(args))

################################################################################
# main
################################################################################
VERBOSE = not args.quiet

def print_if_verbose(*values, **print_kwargs):
    """Forward all arguments to print(), unless -q/--quiet was given."""
    if not VERBOSE:
        return
    print(*values, **print_kwargs)
    
indir = abspath(expanduser(args.input_dir))
outdir = abspath(expanduser(args.output_dir))
datadir = os.path.join(outdir, 'data')
figdir = os.path.join(outdir, 'figures')
# Validate with parser.error() rather than assert: asserts are stripped under
# `python -O`, and a bare AssertionError tells the user nothing.
if not os.path.isdir(indir):
    parser.error('input directory does not exist: {}'.format(indir))
if args.debug and os.path.isdir(outdir):
    # Debug runs are repeated often, so start from a clean output directory.
    # This clears the directory actually selected with -o/--output_dir; the
    # previous hard-coded './specdal_output/' removed an unrelated directory
    # whenever -o pointed elsewhere.
    import shutil
    shutil.rmtree(outdir)
if os.path.exists(outdir):
    parser.error('output directory already exists: {} '
                 '(choose another with -o/--output_dir)'.format(outdir))
# make output directories (outdir first, then its subdirectories)
for d in (outdir, datadir, figdir):
    os.makedirs(d)

# Resolve and validate the base-measurement directory up front, expanding '~'
# and relative paths the same way as --input_dir. This way a bad path fails
# before the (possibly slow) target read rather than after it.
proxdir = None
if args.proximal:
    proxdir = abspath(expanduser(args.proximal))
    if not os.path.isdir(proxdir):
        parser.error('proximal directory does not exist: {}'.format(proxdir))

c = Collection(name=args.name)
print_if_verbose('Reading target measurements from ' + indir)
c.read(directory=indir)

if args.proximal:
    print_if_verbose('Reading base measurements from ' + proxdir)
    c_base = Collection(name=args.name + '_base')
    c_base.read(directory=proxdir)

# Every enabled preprocessing step runs on the target collection and, when
# --proximal was given, on the base collection too (target first). The steps
# run in the fixed order resample -> stitch -> jump correct.
to_process = [c, c_base] if args.proximal else [c]
if args.resample:
    print_if_verbose('Resampling...')
    for coll in to_process:
        coll.resample(spacing=args.resample_spacing, method=args.resample)
if args.stitch:
    print_if_verbose('Stitching...')
    for coll in to_process:
        coll.stitch(method=args.stitch)
if args.jump_correct:
    print_if_verbose('Jump correcting...')
    for coll in to_process:
        coll.jump_correct(splices=args.jump_correct_splices,
                          reference=args.jump_correct_reference,
                          method=args.jump_correct)

if args.proximal:
    # Combine base and target measurements with proximal_join, keyed on
    # 'gps_time_tgt' with direction='nearest'; the joined result replaces
    # the target collection for all later steps.
    print_if_verbose('Joining proximal data...')
    c = proximal_join(c_base, c, on='gps_time_tgt', direction='nearest')

# group by: `groups` stays None unless -g/--group_by was given; otherwise it
# holds the result of Collection.groupby (iterated via .items()/.values()
# further down, presumably a dict of group id -> collection).
if args.group_by:
    print_if_verbose('Grouping...')
    groups = c.groupby(separator=args.group_by_separator,
                       indices=args.group_by_indices)
else:
    groups = None

# output individual spectra: one CSV and one PNG per spectrum, written to
# 'indiv' subdirectories of the data and figure directories.
if args.output_individual:
    print_if_verbose('Saving individual spectrum outputs...')
    indiv_datadir = os.path.join(datadir, 'indiv')
    indiv_figdir = os.path.join(figdir, 'indiv')
    for new_dir in (indiv_datadir, indiv_figdir):
        os.mkdir(new_dir)
    for spectrum in c.spectra:
        stem = spectrum.name
        csv_path = os.path.join(indiv_datadir, stem + '.csv')
        png_path = os.path.join(indiv_figdir, stem + '.png')
        spectrum.to_csv(csv_path)
        spectrum.plot(legend=False)
        plt.savefig(png_path, bbox_inches='tight')
        # close each figure so memory does not grow with the spectrum count
        plt.close()

# output whole and group data: the full collection is saved as
# '<name>.csv', followed by one '<group id>.csv' per group (if grouping ran).
if args.output_data:
    print_if_verbose('Saving data files...')
    csv_jobs = [(c.name, c)]
    if groups:
        csv_jobs.extend(groups.items())
    for stem, coll in csv_jobs:
        coll.to_csv(os.path.join(datadir, stem + '.csv'))

# calculate group aggregates
# Deduplicate the requested aggregates while keeping command-line order. This
# stops a repeated flag (e.g. --group_mean twice) from computing the same
# aggregate twice and, with --group_append, appending it to every group twice.
requested_aggrs = []
for aggr_name in args.aggr:
    if aggr_name not in requested_aggrs:
        requested_aggrs.append(aggr_name)
if requested_aggrs and not groups:
    # Aggregates are computed per group; previously this request was
    # dropped without any notice.
    warnings.warn('group aggregates ({}) were requested but there are no '
                  'groups (did you pass -g/--group_by?); skipping'
                  .format(', '.join(requested_aggrs)))
if groups:
    if requested_aggrs:
        print_if_verbose('Calculating group aggregates...')
    for aggr in requested_aggrs:
        # append=True is passed when the aggregate was listed in
        # --group_append (help text: "append aggregates to group figures").
        # This must happen before the group figures are drawn below.
        append = aggr in args.aggr_append
        aggr_coll = Collection(name=c.name + '_' + aggr,
                               spectra=[getattr(group_coll, aggr)(append=append)
                                        for group_coll in groups.values()],
                               measure_type=c.measure_type)
        # output
        if args.output_aggregates:
            print_if_verbose('Saving group {} outputs...'.format(aggr))
            aggr_coll.to_csv(os.path.join(datadir, aggr_coll.name + '.csv'))
            aggr_coll.plot(legend=False)
            plt.savefig(os.path.join(figdir, aggr_coll.name + '.png'),
                        bbox_inches='tight')
            plt.close()

# output whole and group figures (possibly with aggregates appended): the
# full collection is drawn as '<name>.png', followed by one
# '<group id>.png' per group (if grouping ran).
if args.output_figures:
    print_if_verbose('Saving entire and grouped figure outputs...')
    figure_jobs = [(c.name, c)]
    if groups:
        figure_jobs.extend(groups.items())
    for stem, coll in figure_jobs:
        coll.plot(legend=False)
        plt.savefig(os.path.join(figdir, stem + ".png"), bbox_inches="tight")
        plt.close()
