#!/usr/bin/env python
# encoding: utf-8
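'''
scripts.hmfrun -- compute halo mass functions from the command line

Runs hmf.tools.get_hmf with options taken from the command line and writes
the requested mass- and wavenumber-dependent quantities to text files named
<filename>_MDATA_<label> and <filename>_KDATA_<label>.

The --help output reuses the first line of text above as its short description.
'''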

import sys
import os
import traceback

import hmf
import numpy as np
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter

__all__ = list()
__version__ = hmf.hmf.version
# Both dates end up in user-facing text: __date__ in the --help description,
# __updated__ in the --version message.
__date__ = __updated__ = "2014 - 01 - 23"

# Developer switches, all off (0) by default.  DEBUG and TESTRUN make main()
# re-raise exceptions instead of reporting them; DEBUG also injects -h/-v,
# TESTRUN runs doctests and PROFILE runs main() under cProfile.
DEBUG, TESTRUN, PROFILE = 0, 0, 0

class CLIError(Exception):
    '''Generic exception to raise and log different fatal errors.

    The message is stored prefixed with "E: " and returned by both
    str() and unicode().
    '''
    def __init__(self, msg):
        # Cooperative call into Exception.__init__; the previous
        # ``super(CLIError)`` form created an unbound super object and never
        # initialised the exception itself.
        super(CLIError, self).__init__(msg)
        self.msg = "E: %s" % msg

    def __str__(self):
        return self.msg

    def __unicode__(self):
        return self.msg

def _build_parser(h, description, version_message, m_attrs, k_attrs):
    '''Return the ArgumentParser for this script.

    The "[default: ...]" values quoted in help strings are read from *h*
    (an hmf.MassFunction built with library defaults).
    '''
    parser = ArgumentParser(description=description, formatter_class=RawDescriptionHelpFormatter)
    parser.add_argument("-v", "--verbose", dest="verbose", action="count", help="set verbosity level [default: %(default)s]")
    parser.add_argument('-V', '--version', action='version', version=version_message)

    # HMF specific arguments
    config = parser.add_argument_group("Config", "Variables of Configuration")
    config.add_argument("filename", help="filename to write to")
    config.add_argument("--get", nargs="*", default=["M", "dndm"],
                        choices=m_attrs + k_attrs)

    hmfargs = parser.add_argument_group("HMF", "HMF-specific arguments")
    hmfargs.add_argument("--M", nargs=3, type=float,
                         help="the mass range and intervals, min max step [default: %s %s %s]" %
                         (np.log10(h.M[0]), np.log10(h.M[-1]), np.log10(h.M[1]) - np.log10(h.M[0])))
    hmfargs.add_argument("--mf-fit", nargs="*", choices=hmf.Fits.mf_fits + ["all"],
                         help="fitting function(s) to use. 'all' uses all of them [default: %s]" % h.mf_fit)
    hmfargs.add_argument("--delta-h", nargs="*", type=float,
                         help="overdensity of halo w.r.t delta_wrt [default %s]" % h.delta_h)
    hmfargs.add_argument("--delta-wrt", choices=["mean", "crit"],
                         help="what delta_h is with respect to [default: %s]" % h.delta_wrt)
    # Format the quoted default in one step: `%` binds tighter than `+`.
    hmfargs.add_argument("--user-fit",
                         help="a custom fitting function defined as a string in terms of x for sigma [default: '%s']" % h.user_fit)
    hmfargs.add_argument("--no-cut-fit", action="store_true", help="whether to cut the fitting function at tested boundaries")
    hmfargs.add_argument("--z2", nargs="*", type=float, help="upper redshift for volume weighting")
    # A bin count, so parse it as an integer.
    hmfargs.add_argument("--nz", nargs="*", type=int, help="number of redshift bins for volume weighting")
    hmfargs.add_argument("--delta-c", nargs="*", type=float, help="critical overdensity for collapse [default: %s]" % h.delta_c)

    # Transfer-specific arguments
    transferargs = parser.add_argument_group("Transfer", "Transfer-specific arguments")
    transferargs.add_argument("--z", nargs="*", type=float, help="redshift of analysis [default: %s]" % h.transfer.z)
    transferargs.add_argument("--lnk", nargs=3, type=float, help="the wavenumber range and intervals, min max step [default: %s %s %s]" %
                              (h.transfer.lnk[0], h.transfer.lnk[-1], h.transfer.lnk[1] - h.transfer.lnk[0]))
    # NOTE(review): --maxk and --numk are parsed but never forwarded to get_hmf.
    transferargs.add_argument("--maxk", nargs="*", type=float, default=2e4, help="maximum wavenumber of analysis [default: %s]" % np.exp(h.transfer.lnk[-1]))
    transferargs.add_argument("--numk", nargs="*", type=int, default=250, help="number of wavenumbers in analysis [default: %s]" % len(h.transfer.lnk))
    transferargs.add_argument("--wdm-mass", nargs="*", type=float, help="warm dark matter mass (0 is CDM)")
    transferargs.add_argument("--transfer-fit", nargs="*", choices=hmf.transfer.Transfer.fits + ['all'],
                              help="which fit for the transfer function to use ('all' uses all of them) [default: %s]" % h.transfer.transfer_fit)

    cambargs = parser.add_argument_group("CAMB", "CAMB-specific arguments")
    cambargs.add_argument("--Scalar-initial-condition", nargs="*", type=int, choices=[1, 2, 3, 4, 5],
                          help="[CAMB] initial scalar perturbation mode [default: %s]" % h.transfer._camb_options["Scalar_initial_condition"])
    cambargs.add_argument("--lAccuracyBoost", nargs="*", type=float,
                          help="[CAMB] optional accuracy boost [default: %s]" % h.transfer._camb_options["lAccuracyBoost"])
    cambargs.add_argument("--AccuracyBoost", nargs="*", type=float,
                          help="[CAMB] optional accuracy boost [default: %s]" % h.transfer._camb_options["AccuracyBoost"])
    cambargs.add_argument("--w-perturb", action="store_true", help="[CAMB] whether w should be perturbed or not")
    cambargs.add_argument("--transfer--k-per-logint", nargs="*", type=float,
                          help="[CAMB] number of estimated wavenumbers per interval [default: %s]" % h.transfer._camb_options["transfer__k_per_logint"])
    cambargs.add_argument("--transfer--kmax", nargs='*', type=float,
                          help="[CAMB] maximum wavenumber to estimate [default: %s]" % h.transfer._camb_options["transfer__kmax"])
    cambargs.add_argument("--ThreadNum", type=int,
                          help="number of threads to use (0 is automatic detection) [default: %s]" % h.transfer._camb_options["ThreadNum"])

    # Cosmo-specific arguments
    cosmoargs = parser.add_argument_group("Cosmology", "Cosmology arguments")
    cosmoargs.add_argument("--default", nargs="*",
                           choices=['planck1_base'], help="base cosmology to use [default: %s]" % h.transfer.cosmo.default)
    cosmoargs.add_argument("--force-flat", action="store_true",
                           help="force cosmology to be flat (changes omega_lambda) [default: %s]" % h.transfer.cosmo.force_flat)
    cosmoargs.add_argument("--sigma-8", nargs="*", type=float, help="mass variance in top-hat spheres with r=8")
    cosmoargs.add_argument("--n", nargs="*", type=float, help="spectral index")
    cosmoargs.add_argument("--w", nargs="*", type=float, help="dark energy equation of state")
    cosmoargs.add_argument("--cs2-lam", nargs="*", type=float, help="constant comoving sound speed of dark energy")

    h_group = cosmoargs.add_mutually_exclusive_group()
    h_group.add_argument("--h", nargs="*", type=float, help="The hubble parameter")
    h_group.add_argument("--H0", nargs="*", type=float, help="The hubble constant")

    omegab_group = cosmoargs.add_mutually_exclusive_group()
    omegab_group.add_argument("--omegab", nargs="*", type=float, help="baryon density")
    omegab_group.add_argument("--omegab-h2", nargs="*", type=float, help="baryon density by h^2")

    omegac_group = cosmoargs.add_mutually_exclusive_group()
    omegac_group.add_argument("--omegac", nargs="*", type=float, help="cdm density")
    omegac_group.add_argument("--omegac-h2", nargs="*", type=float, help="cdm density by h^2")
    omegac_group.add_argument("--omegam", nargs="*", type=float, help="total matter density")

    cosmoargs.add_argument("--omegav", type=float, nargs="*", help="the dark energy density")
    return parser


def _build_kwargs(args):
    '''Translate parsed command-line options into keyword arguments for hmf.tools.get_hmf.

    Only options the user actually supplied end up in the returned dict.
    '''
    kwargs = {}
    for arg in ["omegab", "omegab_h2", "omegac", "omegac_h2", "omegam", "h", "H0",
                "sigma_8", "n", "w", "cs2_lam", "omegav", "ThreadNum", "transfer__kmax",
                "transfer__k_per_logint", "AccuracyBoost", "lAccuracyBoost",
                "Scalar_initial_condition", "z", "z2", "nz", "delta_c", "user_fit", "delta_h",
                "delta_wrt"]:
        value = getattr(args, arg)
        if value is not None:
            kwargs[arg] = value

    if args.M is not None:
        kwargs["M"] = np.arange(args.M[0], args.M[1], args.M[2])

    if args.mf_fit is not None:
        if "all" in args.mf_fit:
            # Build a fresh list: removing "user_model" in place would
            # permanently mutate hmf.Fits.mf_fits, which is shared module state.
            kwargs["mf_fit"] = [fit for fit in hmf.Fits.mf_fits if fit != "user_model"]
        else:
            kwargs["mf_fit"] = list(args.mf_fit)

    if args.user_fit is not None:
        # Pair --user-fit with the "user_model" fit.  Without --mf-fit there is
        # no list yet, so start one rather than failing with KeyError.
        mf_fit = kwargs.setdefault("mf_fit", [])
        if "user_model" not in mf_fit:
            mf_fit.append("user_model")

    if args.no_cut_fit:
        kwargs["cut_fit"] = False

    if args.w_perturb:
        kwargs["w_perturb"] = True

    if args.lnk is not None:
        kwargs["lnk"] = np.arange(args.lnk[0], args.lnk[1], args.lnk[2])

    if args.transfer_fit is not None:
        if "all" in args.transfer_fit:
            # Same source as the --transfer-fit choices; copied so the shared
            # class attribute is never handed out for mutation.
            kwargs["transfer_fit"] = list(hmf.transfer.Transfer.fits)
        else:
            kwargs["transfer_fit"] = args.transfer_fit
    return kwargs


def _save_table(res, attrs, nrows, path):
    '''Write attributes *attrs* of *res* as the columns of a text table at *path*.

    Each attribute fills one column of an (nrows, len(attrs)) float array;
    the header lists the attribute names separated by tabs.
    '''
    table = np.empty((nrows, len(attrs)))
    for col, attr in enumerate(attrs):
        table[:, col] = getattr(res, attr)
    np.savetxt(path, table, header="\t".join(attrs))


def main(argv=None):
    '''Generate halo mass functions and write them to file.

    Parameters
    ----------
    argv : list of str, optional
        Extra command-line arguments.  When given they are appended to
        ``sys.argv``, which is what the parser reads.  When omitted,
        ``sys.argv`` is parsed as-is.

    Returns
    -------
    int
        0 on success or on KeyboardInterrupt, 2 when an exception was caught
        and reported.  With DEBUG or TESTRUN set, exceptions propagate instead.
        argparse exits via SystemExit for --help / --version / usage errors.
    '''
    if argv is not None:
        # NOTE(review): mutates the global sys.argv; kept for backward compatibility.
        sys.argv.extend(argv)

    program_name = os.path.basename(sys.argv[0])
    program_version = "v%s" % __version__
    program_build_date = str(__updated__)
    program_version_message = '%%(prog)s %s (%s)' % (program_version, program_build_date)
    # Read this module's own docstring (not whichever module is __main__), and
    # tolerate it having been stripped by ``python -OO``.
    doc_lines = (__doc__ or "").split("\n")
    program_shortdesc = doc_lines[1] if len(doc_lines) > 1 else ""
    program_license = '''%s

  Created by user_name on %s.
  Copyright 2014 organization_name. All rights reserved.

  Licensed under the Apache License 2.0
  http://www.apache.org/licenses/LICENSE-2.0

  Distributed on an "AS IS" basis without warranties
  or conditions of any kind, either express or implied.

USAGE
''' % (program_shortdesc, str(__date__))

    try:
        h = hmf.MassFunction()
        m_attrs = ["M", "dndlog10m", "lnsigma", "n_eff", "sigma",
                   "dndm", "ngtm", "fsigma", "mgtm", "nltm", "dndlnm",
                   "how_big", "mltm", "_sigma_0", "_dlnsdlnm"]
        k_attrs = ["power", "delta_k", "lnk", "transfer", "nonlinear_power",
                   "_lnP_0", "_lnP_cdm_0", "_lnT_cdm", "_unnormalised_lnP",
                   "_unnormalised_lnT"]

        parser = _build_parser(h, program_license, program_version_message, m_attrs, k_attrs)
        args = parser.parse_args()
        kwargs = _build_kwargs(args)

        m_att = [a for a in args.get if a in m_attrs]
        k_att = [a for a in args.get if a in k_attrs]
        for res, label in hmf.tools.get_hmf(args.get, **kwargs):
            # Size each table from the result itself, not from the default `h`:
            # --M / --lnk change the number of rows.
            # NOTE(review): assumes every result exposes .M and .transfer.lnk
            # like hmf.MassFunction does -- confirm against get_hmf.
            if m_att:
                _save_table(res, m_att, len(res.M), args.filename + "_MDATA_" + label)
            if k_att:
                _save_table(res, k_att, len(res.transfer.lnk), args.filename + "_KDATA_" + label)

        return 0
    except KeyboardInterrupt:
        ### handle keyboard interrupt ###
        return 0
    except Exception as e:
        if DEBUG or TESTRUN:
            # Bare raise keeps the original traceback.
            raise
        traceback.print_exc()
        indent = len(program_name) * " "
        sys.stderr.write(program_name + ": " + repr(e) + "\n")
        sys.stderr.write(indent + "  for help use --help\n")
        return 2

if __name__ == "__main__":
    if DEBUG:
        # Show the help text (argparse exits after printing it) and add -v.
        sys.argv.append("-h")
        sys.argv.append("-v")
    if TESTRUN:
        import doctest
        doctest.testmod()
    if PROFILE:
        import cProfile
        import pstats
        profile_filename = 'scripts.hmfrun_profile.txt'
        cProfile.run('main()', profile_filename)
        # Text mode: pstats writes str, which a "wb" stream rejects on
        # Python 3.  The with-block closes the file even if printing fails.
        with open("profile_stats.txt", "w") as statsfile:
            stats = pstats.Stats(profile_filename, stream=statsfile)
            stats.strip_dirs().sort_stats('cumulative').print_stats()
        sys.exit(0)
    sys.exit(main())
