#!/usr/bin/env python
import argparse
import sys

import matplotlib.pyplot as plt

from color_by_arg import StarPop, color_by_arg

def _float_list(text):
    """Convert a comma separated string such as ``'0.5,2'`` to a list of floats.

    Used as an argparse ``type`` so that malformed limits are reported as a
    regular usage error rather than a traceback.
    """
    try:
        return [float(item) for item in text.split(',')]
    except ValueError:
        raise argparse.ArgumentTypeError(
            'expected comma separated numbers, got {!r}'.format(text))


def main(argv):
    """Make a TRILEGAL diagnostic plot and save it next to the input file.

    Parameters
    ----------
    argv : list of str
        Command line arguments without the program name (``sys.argv[1:]``).
        See the parser definitions below for the accepted options. Limit
        values that start with ``-`` have to be attached to their flag
        (e.g. ``--xlim=-1,2``), otherwise argparse reads them as an option.

    Notes
    -----
    The figure is written to ``<infile without a trailing .dat>`` followed by
    ``_<xcolumn>_<ycolumn>_<ccolumn>.<ext>``; with ``--multiplot`` the color
    column part of the name is ``multi``. The characters ``[``, ``]`` and
    ``/`` are dropped from the appended part only, so column names such as
    ``[M/H]`` or ``C/O`` give a valid file name while the directory of the
    input file is left untouched.
    """
    parser = argparse.ArgumentParser(description="Make a TRILEGAL diagnostic plot")

    parser.add_argument('-c', '--cmap', type=str, default='Spectral',
                        help='Matplotlib colormap to use')

    parser.add_argument('-x', '--xlim', type=_float_list, default=None,
                        help='comma separated x axis limits')

    parser.add_argument('-y', '--ylim', type=_float_list, default=None,
                        help='comma separated y axis limits')

    parser.add_argument('-z', '--clim', type=_float_list, default=None,
                        help='comma separated color bar axis limits')

    parser.add_argument('-v', '--vlims', type=_float_list, default=None,
                        help='comma separated color bar vlimits (outside these limits will be gray)')

    parser.add_argument('-b', '--bins', type=int, default=None,
                        help='number of color bins to use')

    parser.add_argument('-e', '--ext', type=str, default='png',
                        help='image extension (without the .)')

    parser.add_argument('-a', '--multiplot', action='store_true',
                        help='make two default diagnostic plots, trilegal must have been run with -a and -l flags')

    parser.add_argument('infile', type=str, help='TRILEGAL file')

    parser.add_argument('xcolumn', type=str, help='x axis column')

    parser.add_argument('ycolumn', type=str, help='y axis column')

    parser.add_argument('ccolumn', type=str, help='color bar column')

    args = parser.parse_args(argv)

    # vmin and vmax are both read below, so reject a single value up front
    # (before the input file is loaded) instead of failing with IndexError.
    skw = {}
    if args.vlims is not None:
        if len(args.vlims) < 2:
            parser.error('--vlims needs two comma separated values (vmin,vmax)')
        skw = {'vmin': args.vlims[0], 'vmax': args.vlims[1]}

    sgal = StarPop(args.infile)

    if args.multiplot:
        # (color column, clim, bins, (vmin, vmax)) for each panel. A sequence
        # rather than a dict keeps the panel order the same on every run and
        # every Python version.
        panels = (('stage', (0, 9), 10, (1, 8)),
                  ('logAge', (6, 10.1), 20, (6, 10.1)),
                  ('m_ini', (0.9, 8.), 10, (0.9, 8.)),
                  ('[M/H]', (-2, 2), 10, (-2, 2)),
                  ('C/O', (0.48, 4), 10, (0.48, 4)),
                  ('logML', (-11, -4), 10, (-11, -4)))

        # NOTE: an observed CMD could be added as an extra left-hand column
        # of panels here; only the model panels are drawn for now.
        _, axs = plt.subplots(ncols=3, nrows=2, sharex=True,
                              sharey=True, figsize=(16, 12))

        for ax, (zcol, clim, bins, vlims) in zip(axs.flatten(), panels):
            # one model panel per color column
            color_by_arg(sgal, args.xcolumn, args.ycolumn, zcol, ax=ax,
                         clim=clim, bins=bins, xlim=args.xlim,
                         skw={'vmin': vlims[0], 'vmax': vlims[1]},
                         ylim=args.ylim)
        cstr = 'multi'
    else:
        cstr = args.ccolumn
        color_by_arg(sgal, args.xcolumn, args.ycolumn, args.ccolumn,
                     bins=args.bins, cmap=args.cmap, xlim=args.xlim,
                     ylim=args.ylim, clim=args.clim, skw=skw)

    # Only a trailing '.dat' is an extension; leave other occurrences alone.
    outpref = args.infile
    if outpref.endswith('.dat'):
        outpref = outpref[:-len('.dat')]

    # Sanitize just the appended part: stripping '/' from the whole path
    # would glue the input file's directory onto the file name.
    suffix = '_{}_{}_{}.{}'.format(args.xcolumn, args.ycolumn, cstr, args.ext)
    for char in '[]/':
        suffix = suffix.replace(char, '')

    outfile = outpref + suffix
    plt.savefig(outfile)
    print('wrote {}'.format(outfile))

# Run as a script: hand everything after the program name to main().
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)
