#!python

"""Makes sequence logo plots.

Written by Jesse Bloom."""


import sys
import os
import re
import time
import logging
import natsort
import numpy
import pandas
import phydmslib.weblogo
import dms_tools2.parseargs
import dms_tools2.utils
import dms_tools2.prefs


def main():
    """Main body of script.

    Parses the command-line arguments, reads either amino-acid preferences
    (``prefs``) or differential selection (``diffsel``) from a CSV file,
    optionally reads up to three overlay CSV files, and draws the logo plot
    with `phydmslib.weblogo.LogoPlot`.

    Two output files are written into ``--outdir`` (or the current
    directory if no output directory is given): ``<name>.log`` and
    ``<name>_<datatype>.pdf``.

    If anything fails after the logger is created, the exception is written
    to the log, any output file other than the log is deleted, and the
    function returns normally (the exception is not re-raised).
    """

    # Parse command line arguments
    parser = dms_tools2.parseargs.logoplotParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # what type of data are we plotting? Exactly one of these must be set.
    datatype = [x for x in ['prefs', 'diffsel'] if args[x]]
    assert len(datatype) == 1
    datatype = datatype[0]

    # define output file names
    if args['outdir']:
        if not os.path.isdir(args['outdir']):
            os.mkdir(args['outdir'])
    else:
        args['outdir'] = ''
    filesuffixes = {
            'log':'.log',
            'logo':'_{0}.pdf'.format(datatype),
            }
    files = {f: os.path.join(args['outdir'],
                             '{0}{1}'.format(args['name'], s))
             for (f, s) in filesuffixes.items()}

    # do we need to proceed?
    if args['use_existing'] == 'yes' and all(map(
            os.path.isfile, files.values())):
        print("Output files already exist and '--use_existing' is 'yes', "
                "so exiting with no furhter action.")
        sys.exit(0)

    logger = dms_tools2.utils.initLogger(files['log'], prog, args)

    # log in try / except / finally loop
    try:
        # remove expected output files if they already exist
        for (ftype, f) in files.items():
            if os.path.isfile(f) and ftype != 'log':
                logger.info("Removing existing file {0}".format(f))
                os.remove(f)

        # some checking on arguments
        assert dms_tools2.parseargs.checkName(args['name'], 'name')
        assert args['nperline'] >= 1
        assert args['numberevery'] >= 1
        assert args['stringency'] >= 0

        # read data
        logger.info("Reading {0} from file {1}...".format(datatype, 
                args[datatype]))
        assert os.path.isfile(args[datatype]), "Can't find {0}".format(
                args[datatype])
        data = pandas.read_csv(args[datatype])
        assert 'site' in data.columns, "no 'site' column"
        # sites are handled as strings throughout (data and overlays)
        data['site'] = data['site'].astype(str)
        sites = data['site'].unique()
        logger.info("Read data for {0} sites.\n".format(len(sites)))
        if args['sortsites'] == 'yes':
            sites = natsort.natsorted(sites)
        elif args['sortsites'] != 'no':
            raise ValueError("invalid --sortsites")

        if datatype == 'prefs':
            if args['stringency'] != 1:
                logger.info("Re-scaling preferences by stringency "
                    "parameter {0}".format(args['stringency']))
                data = dms_tools2.prefs.rescalePrefs(data, 
                        args['stringency'])
            nosepline = True
            ylimits = None

        elif datatype == 'diffsel':
            # logo plot of differential selection: go from one row per
            # (site, mutation) to one row per site, one column per mutation
            data = (data.pivot_table(index='site', values='mutdiffsel',
                                columns='mutation')
                        .fillna(0)
                        # `reindex` replaces `reindex_axis`, which was
                        # removed from pandas; it puts rows in `sites` order
                        .reindex(sites)
                        )
            if args['restrictdiffsel'] == 'positive':
                data = data.clip(0, None)
            elif args['restrictdiffsel'] == 'negative':
                data = data.clip(None, 0)
            elif args['restrictdiffsel'] != 'all':
                raise ValueError('invalid restrictdiffsel')
            data = data.reset_index()

            args['letterheight'] *= 2 # taller letter stacks for diffsel
            nosepline = {'no':True, 'yes':False}[args['sepline']]

            # data range: most negative / most positive per-site stack
            sitevalues = data.drop('site', axis=1)
            ymin = sitevalues.clip(None, 0).sum(axis=1).min()
            ymax = sitevalues.clip(0, None).sum(axis=1).max()
            if args['diffselrange']:
                # user-specified range must fully contain the data range
                (rangemin, rangemax) = args['diffselrange']
                if (rangemin > ymin) or (rangemax < ymax):
                    raise ValueError("Invalid diffselrange does not "
                            "fully include data range of {0} to {1}"
                            .format(ymin, ymax))
                ymin = rangemin
                ymax = rangemax
            # pad the plotted range by 2% of its height on each side
            dy = ymax - ymin
            ylimits = (ymin - 0.02 * dy, ymax + 0.02 * dy)

        else:
            raise ValueError("Invalid datatype {0}".format(datatype))

        # exclude stop codons if specified
        assert (set(['site'] + dms_tools2.AAS) <= set(data.columns) 
                <= set(['site'] + dms_tools2.AAS_WITHSTOP)), (
                "invalid columns")
        if args['excludestop'] == 'yes':
            data = data[['site'] + dms_tools2.AAS]
            if datatype == 'prefs':
                # re-norm after excludestop
                data = dms_tools2.prefs.rescalePrefs(data, 
                        args['stringency'])
        elif args['excludestop'] != 'no':
            raise ValueError("invalid excludestop")
            
        # convert data from wide data frame to dict for logo plot
        data = data.set_index('site').to_dict('index')

        # read any overlays; each entry is (site -> value dict,
        # shortname, longname)
        overlay = []
        for i in range(3): # loop over possibly overlays
            overlayarg = 'overlay{0}'.format(i + 1)
            if not args[overlayarg]:
                continue
            (overlayfile, shortname, longname) = args[overlayarg]
            logger.info("Reading overlay for {0} from {1}...".format(
                    shortname, overlayfile))
            assert (len(shortname) < 6) or (shortname == 'wildtype'), \
                "{0} SHORTNAME too long".format(overlayarg)
            overlaydf = pandas.read_csv(overlayfile)
            assert {'site', shortname} <= set(overlaydf.columns), \
                    "No 'site' and {0} columns in {1} FILE".format(
                    shortname, overlayarg)
            overlaydf = overlaydf[['site', shortname]].drop_duplicates()
            overlaydf['site'] = overlaydf['site'].astype('str')
            # after dropping exact duplicates, a repeated site means
            # conflicting values for that site
            assert len(overlaydf['site']) == len(set(overlaydf['site'])),\
                    "Duplicate sites in {0} FILE".format(overlayarg)
            extrasites = set(overlaydf['site']) - set(sites)
            assert not extrasites, "Extra sites in {0}:\n{1}".format(
                    overlayarg, extrasites)
            overlay.append((
                    dict(zip(overlaydf['site'], overlaydf[shortname])),
                    shortname, longname))
            logger.info("Read overlay for {0} sites.\n".format(
                    len(overlaydf['site'])))
        # Reverse once, after all overlays are collected. Reversing inside
        # the loop scrambles the order when three overlays are given.
        if args['underlay'] == 'yes':
            overlay.reverse()
        shortnames = [tup[1] for tup in overlay]
        assert len(set(shortnames)) == len(shortnames), ("Duplicate "
                "SHORTNAME in overlay arguments")

        # make logo plot
        logger.info("Making logo plot {0}...".format(files['logo']))
        phydmslib.weblogo.LogoPlot(
                    sites=sites, 
                    datatype=datatype, 
                    data=data, 
                    plotfile=files['logo'],
                    nperline=args['nperline'],
                    numberevery=args['numberevery'],
                    allowunsorted=True,
                    ydatamax=1.01, # no meaning for prefs or diffsel
                    overlay=overlay,
                    fix_limits={},
                    fixlongname=False,
                    overlay_cmap=args['overlaycolormap'],
                    ylimits=ylimits,
                    relativestackheight=args['letterheight'],
                    custom_cmap=args['colormap'],
                    map_metric=args['mapmetric'],
                    noseparator=nosepline,
                    underlay={'no':False, 'yes':True}[args['underlay']],
                    )
        logger.info("Successfully created logo plot.\n")

    # NOTE(review): the bare except is kept so that partial output is also
    # cleaned up on KeyboardInterrupt; the error is logged but not
    # re-raised, so the process still exits with status 0 -- confirm that
    # callers rely on the log / missing output rather than the exit code.
    except:
        logger.exception('Terminating {0} with ERROR.'.format(prog))
        for (fname, fpath) in files.items():
            if fname != 'log' and os.path.isfile(fpath):
                logger.exception("Deleting file {0}".format(fpath))
                os.remove(fpath)

    else:
        logger.info('Successful completion of {0}'.format(prog))

    finally:
        logging.shutdown()



# Run the script only when executed directly, not when imported.
if __name__ == "__main__":
    main()
