#!/usr/bin/env python

import argparse
import asts
import logging
import os
import sys

import matplotlib.pylab as plt

# getLogger() called without a name returns the root logger, so handlers
# attached to it also receive records propagated from other modules' loggers.
logger = logging.getLogger()

def main(argv):
    """Command-line entry point for the AST tools.

    Two modes are supported:

    * With ``-t/--trilegal``: the first positional AST file is used to add
      corrections to the given TRILEGAL catalog (in place), unless the
      catalog header already contains a ``<filter>_cor`` column for either
      filter of the AST file.
    * Otherwise: for every positional AST file, the completeness is
      computed, the two values returned by ``get_completeness_fraction``
      are printed, and, with ``-p``, the ``*_comp.png`` and ``*_ast.png``
      diagnostic plots are saved.

    Log records are written at DEBUG level to ``asts.log`` in the current
    working directory.

    Parameters
    ----------
    argv : list of str
        Command-line arguments, without the program name.

    Raises
    ------
    SystemExit
        Raised by argparse on ``-h`` or on invalid arguments, including
        ``-t`` given without any AST file.
    """
    description = """Calculate completeness fraction, make AST plots or add AST corrections to a trilegal catalog
For example:
Find the 50% completeness mags of a given AST file:
$ asts -c 0.5 10190_M33-DISK1_F606W_F814W.gst.matchfake
 10190_M33-DISK1_F606W_F814W.gst.matchfake 0.5 completeness fraction:
 M33-DISK1            25.5470 24.6400

Make AST diagnostic plots:
asts -p 10190_M33-DISK1_F606W_F814W.gst.matchfake
(see *_comp.png, _ast.png)

Add lines at 90%, and 50% completeness to plot:
asts -p -f 0.9,0.5 10190_M33-DISK1_F606W_F814W.gst.matchfake
(see *_comp.png)

Use an AST file to "correct" a TRILEGAL file:
asts -t M33-DISK1.dat 10190_M33-DISK1_F606W_F814W.gst.matchfake
(see asts.log)
"""
    # The raw formatter keeps the line breaks of the usage examples; the
    # default formatter would re-wrap them into a single paragraph.
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('-c', '--comp_frac', type=float, default=0.9,
                        help='completeness fraction to calculate')

    parser.add_argument('-p', '--makeplots', action='store_true',
                        help='make AST plots')

    parser.add_argument('-m', '--bright_mag', type=float, default=20.,
                        help='brighest mag to consider for completeness frac')

    parser.add_argument('-f', '--plot_fracs', type=str, default=None,
                        help='comma separated completeness fractions to overplot')

    parser.add_argument('-t', '--trilegal', type=str, default=None,
                        help='Trilegal catalog to add ASTs to')

    parser.add_argument('fake', type=str, nargs='*', help='match AST file(s)')

    args = parser.parse_args(argv)

    # The trilegal mode indexes args.fake[0]; fail with a usage message
    # (before creating the log file) instead of an IndexError.
    if args.trilegal is not None and not args.fake:
        parser.error('-t/--trilegal requires a match AST file')

    logfile = 'asts.log'
    handler = logging.FileHandler(logfile)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)

    if args.trilegal is not None:
        # Only the first AST file is used in this mode.
        ast = asts.ASTs(args.fake[0])
        # Read just the header line and close the file right away.
        with open(args.trilegal, 'r') as catalog:
            header_cols = catalog.readline().split()
        if ast.filter1 + '_cor' in header_cols:
            logger.debug('{}_cor already in header'.format(ast.filter1))
        elif ast.filter2 + '_cor' in header_cols:
            logger.debug('{}_cor already in header'.format(ast.filter2))
        else:
            # The corrected catalog overwrites the input catalog.
            asts.ast_correct_starpop(args.trilegal, outfile=args.trilegal,
                                     fake_file=args.fake[0], hdf5=False,
                                     overwrite=True,
                                     diag_plot=args.makeplots)
    else:
        for fake in args.fake:
            ast = asts.ASTs(fake)
            ast.completeness(combined_filters=True, interpolate=True,
                             binsize=0.15)
            comp1, comp2 = ast.get_completeness_fraction(
                args.comp_frac, bright_lim=args.bright_mag)
            print('{} {} completeness fraction:'.format(fake, args.comp_frac))
            print('{0:20s} {1:.4f} {2:.4f}'.format(ast.target, comp1, comp2))

            if args.makeplots:
                comp_name = os.path.join(ast.base, ast.name + '_comp.png')
                ast_name = os.path.join(ast.base, ast.name + '_ast.png')

                ax = ast.completeness_plot()
                if args.plot_fracs is not None:
                    fracs = [float(frac)
                             for frac in args.plot_fracs.split(',')]
                    ast.add_complines(ax, *fracs, bright_lim=args.bright_mag)
                plt.savefig(comp_name)
                plt.close()

                ast.magdiff_plot()
                plt.savefig(ast_name)
                plt.close()


if __name__ == "__main__":
    # Script entry: hand everything after the program name to main().
    cli_args = sys.argv[1:]
    main(cli_args)