#!/usr/bin/env python
# Copyright 2012 the rootpy developers
# distributed under the terms of the GNU General Public License
from __future__ import division

import os
import sys
import fnmatch

import ROOT

import rootpy
from rootpy import log
from rootpy import ROOTError
from rootpy.context import preserve_set_th1_add_directory
from rootpy.extern import argparse
from rootpy.utils.extras import humanize_bytes, print_table


def find_files(dirs, pattern=None):
    """Yield file paths found from a list of files and/or directories.

    Each entry of *dirs* that is an existing file is yielded unchanged
    (*pattern* is not applied to it). Each entry that is a directory is
    walked recursively with ``os.walk`` and every file below it is yielded;
    when *pattern* is not None only file names matching that ``fnmatch``
    pattern are kept. Entries that are neither a file nor a directory are
    silently ignored.
    """
    for location in dirs:
        if os.path.isfile(location):
            yield location
        elif os.path.isdir(location):
            for parent, _subdirs, names in os.walk(location):
                if pattern is None:
                    kept = names
                else:
                    kept = fnmatch.filter(names, pattern)
                for name in kept:
                    yield os.path.join(parent, name)


def make_chain(args, **kwargs):
    """Build a chain of the tree ``args.tree`` over all input files.

    Input files are located with ``find_files(args.files, args.pattern)``.

    If *args* has a true ``staged`` attribute a rootpy ``TreeChain`` is
    returned and **kwargs are forwarded to its constructor. Otherwise a
    plain ``ROOT.TChain`` is returned, each file name is printed first when
    ``args.verbose`` is set, and **kwargs are ignored.
    """
    # Not every sub-command defines --staged, hence the getattr default.
    if getattr(args, 'staged', False):
        from rootpy.tree import TreeChain
        return TreeChain(
            args.tree,
            list(find_files(args.files, args.pattern)),
            verbose=args.verbose,
            **kwargs)

    chain = ROOT.TChain(args.tree, '')
    for filename in find_files(args.files, args.pattern):
        if args.verbose:
            print(filename)
        chain.Add(filename)
    return chain


class formatter_class(
        argparse.ArgumentDefaultsHelpFormatter,
        argparse.RawTextHelpFormatter):
    """Help formatter combining default-value display with raw help text."""


# Top-level command line: options defined here are shared by every
# sub-command registered on ``subparsers`` further down.
parser = argparse.ArgumentParser(formatter_class=formatter_class)
parser.add_argument('-v', '--verbose', default=False, action='store_true')
parser.add_argument(
    '-V', '--version',
    action='version',
    help="show the version number and exit",
    version=rootpy.__version__)
parser.add_argument(
    '-d', '--debug',
    default=False,
    action='store_true',
    help="show a stack trace")
parser.add_argument(
    '-p', '--pattern',
    help="files must match this pattern when searching in directories",
    default='*.root*')
subparsers = parser.add_subparsers()


def entries(args):
    """Print the number of entries of the chained tree.

    When ``args.selection`` is given, only entries passing that cut are
    counted and the normalised cut string is echoed in the output.
    """
    chain = make_chain(args)

    if args.selection is not None:
        from rootpy.tree import Cut
        cut = str(Cut(args.selection))
        count = chain.GetEntries(cut)
        noun_ending = 'y' if count == 1 else 'ies'
        print("{0:d} entr{1} after selection {2}".format(
            count, noun_ending, cut))
        return

    count = chain.GetEntries()
    noun_ending = 'y' if count == 1 else 'ies'
    print("{0:d} entr{1}".format(count, noun_ending))


# 'entries' sub-command: dispatched to entries() through the ``op`` default.
parser_entries = subparsers.add_parser('entries')
parser_entries.add_argument(
    '-s', '--selection',
    help="only entries satisfying this cut will be included in total",
    default=None)
parser_entries.add_argument(
    'tree', help="name of tree (including path) in each file")
parser_entries.add_argument('files', nargs='+')
parser_entries.set_defaults(op=entries)


def draw(args):
    """Draw ``args.expression`` from the chained tree on a new canvas.

    ``args.selection`` (if any) is applied as a cut and ``args.draw`` is
    passed as the draw option. A canvas-refresh callback is handed to
    ``make_chain`` as ``onfilechange``. The function then blocks in
    ``wait(True)``.
    """
    from rootpy.interactive import wait
    from rootpy.plotting import Canvas
    canvas = Canvas()

    def refresh_canvas(*_args, **_kwargs):
        # Arguments supplied by the caller are deliberately ignored.
        canvas.Modified()
        canvas.Update()

    chain = make_chain(args, onfilechange=[(refresh_canvas, ())])

    if args.selection is None:
        cut = ''
    else:
        from rootpy.tree import Cut
        cut = str(Cut(args.selection))
    chain.Draw(args.expression, cut, args.draw)
    wait(True)


# 'draw' sub-command: dispatched to draw() through the ``op`` default.
parser_draw = subparsers.add_parser('draw')
parser_draw.add_argument(
    '-e', '--expression',
    help="expression to be drawn",
    required=True)
parser_draw.add_argument(
    '-s', '--selection',
    help="only entries satisfying this cut will be drawn",
    default=None)
parser_draw.add_argument('-d', '--draw', help="draw options", default='')
parser_draw.add_argument(
    '--staged',
    default=False,
    action='store_true',
    help="update the canvas after each file is drawn")
parser_draw.add_argument(
    'tree', help="name of tree (including path) in each file")
parser_draw.add_argument('files', nargs='+')
parser_draw.set_defaults(op=draw)


def hsum(args):
    """Sum the histogram ``args.hist`` over all input files and draw it.

    Input files are located with ``find_files``; directories are searched
    using the global ``--pattern`` option, as for the tree-based commands.
    A file for which opening or fetching the histogram raises ``ROOTError``
    is skipped with a warning. Every file that was opened is closed again,
    including when an error occurs. If at least one histogram was found the
    total is drawn with option ``args.draw`` and the function blocks in
    ``wait(True)``.
    """
    from rootpy.interactive import wait
    from rootpy.plotting import Canvas
    from rootpy.io import root_open as ropen

    total = None
    for filename in find_files(args.files, args.pattern):
        try:
            f = ropen(filename)
        except ROOTError:
            log.warning("skipping file {0}".format(filename))
            continue
        try:
            try:
                h = f.get(args.hist)
            except ROOTError:
                log.warning("skipping file {0}".format(filename))
                continue
            if total is None:
                # Detach the clone from the file so it remains usable once
                # the file has been closed.
                total = h.Clone()
                total.SetDirectory(0)
            else:
                total += h
        finally:
            f.close()
    if total is not None:
        # The canvas is created before drawing; the name stays bound while
        # wait() blocks (presumably to keep the canvas alive).
        canvas = Canvas()
        total.Draw(args.draw)
        wait(True)


# 'sum' sub-command: dispatched to hsum() through the ``op`` default.
parser_sum = subparsers.add_parser('sum')
parser_sum.add_argument('-d', '--draw', help="draw options", default='')
parser_sum.add_argument(
    'hist', help="name of histogram (including path) in each file")
parser_sum.add_argument('files', nargs='+')
parser_sum.set_defaults(op=hsum)


def merge(args):
    """Merge the tree ``args.tree`` from all input files into ``args.output``.

    Exits with an error message if the output path already exists. The
    baskets of the merged tree are ordered according to ``args.sort_by``
    ('offset', 'branch' or 'entry').
    """
    # Fail fast, before any directory is scanned for input files.
    if os.path.exists(args.output):
        sys.exit("Output destination already exists.")
    chain = make_chain(args)
    # Report the number of files actually chained: args.files may contain
    # directories, so its length is not the number of input files.
    print("Merging tree {0} in {1:d} files into {2} ...".format(
        args.tree, chain.GetNtrees(), args.output))
    chain.Merge(
        args.output,
        'fast SortBasketsBy{0}'.format(args.sort_by.capitalize()))


# 'merge' sub-command: dispatched to merge() through the ``op`` default.
parser_merge = subparsers.add_parser('merge')
parser_merge.add_argument(
    '-o', '--output',
    help="output file name",
    required=True)
parser_merge.add_argument(
    'tree', help="name of tree (including path) in each file")
parser_merge.add_argument(
    '--sort-by',
    default='offset',
    choices=('offset', 'branch', 'entry'),
    help="""\
When using 'offset' the baskets are written in
the output file in the same order as in the original file
(i.e. the basket are sorted on their offset in the original
file; Usually this also means that the baskets are sorted
on the index/number of the *last* entry they contain)

When using 'branch' all the baskets of each
individual branches are stored contiguously. This tends to
optimize reading speed when reading a small number (1->5) of
branches, since all their baskets will be clustered together
instead of being spread across the file. However it might
decrease the performance when reading more branches (or the full
entry).

When using 'entry' the baskets with the lowest
starting entry are written first. (i.e. the baskets are
sorted on the index/number of the first entry they contain).
This means that on the file the baskets will be in the order
in which they will be needed when reading the whole tree
sequentially.""")
parser_merge.add_argument('files', nargs='+')
parser_merge.set_defaults(op=merge)


def ls(args):
    """List the contents of each ROOT file in ``args.files``.

    Exits with an error message as soon as one of the paths is not an
    existing file. For every object the class name and object name are
    printed, indented by directory depth. With ``args.showinfo`` the
    entries, weight, integral and friends are printed as well, for objects
    that provide the corresponding methods. When several files are given,
    each listing is preceded by the file name and listings are separated
    by an empty line.
    """
    from rootpy.io import root_open as ropen
    several_files = len(args.files) > 1
    for i, filename in enumerate(args.files):
        if not os.path.isfile(filename):
            sys.exit("file {0} does not exist".format(filename))
        with ropen(filename) as f:
            if several_files:
                if i > 0:
                    # print("") rather than print(): without print_function
                    # a bare print() outputs "()" on Python 2.
                    print("")
                print("{0}:".format(filename))
            for dirpath, dirnames, objects in f.walk():
                depth = 0
                if dirpath:
                    depth = dirpath.count('/') + 1
                    print(dirpath)
                prefix = '   ' * depth
                for name in objects:
                    thing = f.Get(os.path.join(dirpath, name))
                    print("{0} {1}".format(
                        prefix + thing.__class__.__name__,
                        thing.GetName()))
                    if not args.showinfo:
                        continue
                    if hasattr(thing, "GetEntries"):
                        # GetEntries() may return a float (e.g. for
                        # histograms), which the 'd' format code rejects.
                        print(prefix + " Entries: {0:d}".format(
                            int(thing.GetEntries())))
                    if hasattr(thing, "GetWeight"):
                        print(prefix + " Weight: {0:e}".format(
                            thing.GetWeight()))
                    if hasattr(thing, "Integral"):
                        print(prefix + " Integral: {0:e}".format(
                            thing.Integral()))
                    if hasattr(thing, "GetListOfFriends"):
                        friends = thing.GetListOfFriends()
                        if friends:
                            print(prefix + " Friends:")
                            for friend in friends:
                                print(prefix + " {0} in {1}".format(
                                    friend.GetName(), friend.GetTitle()))

# 'ls' sub-command: dispatched to ls() through the ``op`` default.
parser_ls = subparsers.add_parser('ls')
parser_ls.add_argument(
    '-l', '--showinfo',
    default=False,
    action='store_true',
    help="display object properties")
parser_ls.add_argument('files', nargs='+')
parser_ls.set_defaults(op=ls)


def tree(args):
    """Print the branches of the tree ``args.tree`` in the file ``args.file``.

    Branches can be filtered with ``args.regex`` (matched from the start of
    the branch name with ``re.match``) and/or ``args.glob`` (matched against
    the whole branch name, case-sensitively). With ``args.showtypes`` a
    table of size, type and name is printed, sorted by decreasing size,
    followed by the total size of the listed branches and, if some branches
    were filtered out, their share of the full tree size. Sizes are
    compressed sizes unless ``args.uncompressed`` is set. The file is
    closed on exit, including when an error occurs.
    """
    from rootpy.io import root_open as ropen
    import re
    import fnmatch
    from operator import itemgetter

    def selected(name):
        """Return True if branch *name* passes the regex and glob filters."""
        if args.regex is not None and not re.match(args.regex, name):
            return False
        # fnmatchcase() matches the whole name. Searching with the
        # translated pattern is only anchored at the end, so it would also
        # accept names that merely end with something matching the glob.
        if (args.glob is not None and
                not fnmatch.fnmatchcase(name, args.glob)):
            return False
        return True

    rfile = ropen(args.file)
    try:
        tree_obj = rfile.Get(args.tree)

        if args.showtypes:
            totalsize = 0
            totalmatchedsize = 0
            table = []
            for branch in tree_obj.GetListOfBranches():
                if args.uncompressed:
                    branchsize = branch.GetTotBytes('*')
                else:
                    branchsize = branch.GetZipBytes('*')
                # The full-tree total includes filtered-out branches.
                totalsize += branchsize
                name = branch.GetName()
                if not selected(name):
                    continue
                typename = branch.GetClassName()
                if not typename:
                    # No class name: use the type of the first leaf.
                    typename = branch.GetListOfLeaves()[0].GetTypeName()
                table.append((branchsize, (
                    humanize_bytes(branchsize), typename, name)))
                totalmatchedsize += branchsize
            table.sort(key=itemgetter(0), reverse=True)
            print_table([row[1] for row in table])
            print("total size {0}".format(humanize_bytes(totalmatchedsize)))
            if totalmatchedsize != totalsize:
                print("{0:.3g}% of full tree size".format(
                    100. * totalmatchedsize / totalsize))
        else:
            for branch in tree_obj.GetListOfBranches():
                name = branch.GetName()
                if selected(name):
                    print(name)
    finally:
        rfile.Close()

# 'tree' sub-command: dispatched to tree() through the ``op`` default.
parser_tree = subparsers.add_parser('tree')
parser_tree.add_argument(
    '-t', '--tree', required=True, help="Tree name")
parser_tree.add_argument(
    '-l', '--showtypes',
    default=False,
    action="store_true",
    help="show branch types/classnames and sizes")
parser_tree.add_argument(
    '-e', '--regex',
    default=None,
    help="only show branches matching this regex")
parser_tree.add_argument(
    '-g', '--glob',
    default=None,
    help="only show branches matching this glob")
parser_tree.add_argument(
    '-z', '--uncompressed',
    default=False,
    action="store_true",
    help="show uncompressed branch sizes")
parser_tree.add_argument('file')
parser_tree.set_defaults(op=tree)

# Module-level name; browse() binds its own local ``browser`` and does not
# modify this one.
browser = None


def browse(args):
    """Open ROOT files for interactive browsing.

    The files named in ``args.file`` are opened; when none are given, every
    entry of the current directory whose name ends with ".root" is opened
    instead. A namespace is built that contains ``R`` and ``ROOT`` (both
    bound to ``rootpy.ROOT``), the opened files as ``f1`` .. ``fn`` and the
    objects behind the top-level keys of each file under their key names.

    If ``args.browser`` is set and at least one file was opened, a
    ``TBrowser`` showing the files is created. If ``args.ipython`` is set
    and IPython can be imported, an embedded IPython shell is started on
    the namespace. Finally, if a browser object is still truthy, the
    function waits for it to be closed.
    """

    from rootpy.io import root_open as ropen

    if not args.file:
        files = [ropen(f) for f in os.listdir(".") if f.endswith(".root")]
    else:
        files = [ropen(f) for f in args.file]

    import rootpy.ROOT as ROOTWrapper
    namespace = {"R": ROOTWrapper, "ROOT": ROOTWrapper}

    # Expose the files as f1, f2, ... (numbering starts at 1).
    for i, f in reversed(list(enumerate(files, 1))):
        namespace["f{0}".format(i)] = rootpy.asrootpy(f)

    if args.browser and files:
        # The browser is created on the first file; the others are added.
        browser = ROOT.TBrowser("browser", files[0])
        for f in files[1:]:
            browser.Add(f)
    else:
        browser = None

    with preserve_set_th1_add_directory(False):
        # Above context prevents apparent duplication of keys
        # Files are visited last-to-first and existing names are never
        # overwritten, so for a key name present in several files the
        # object from the last-listed file is kept, and R, ROOT and the
        # f<n> entries always win over key names.
        for f in reversed(files):
            for k in f.GetListOfKeys():
                n = k.GetName()
                if n not in namespace:
                    namespace[n] = rootpy.asrootpy(k.ReadObj(), warn=False)

    if args.ipython:
        vartext = ""
        try:
            from IPython.terminal.embed import InteractiveShellEmbed
        except ImportError:
            # IPython is optional: without it no shell is started.
            pass
        else:
            banner = "[[ IPython ]] files named f1 (up to f[n])" + vartext
            ip = InteractiveShellEmbed(banner1=banner)

            def pre_prompt_hook(_):
                # Registered below as IPython's 'pre_prompt_hook'.
                # NOTE(review): presumably lets ROOT process pending work
                # before each prompt -- confirm.
                ROOT.gInterpreter.EndOfLineAction()

            ip.set_hook('pre_prompt_hook', pre_prompt_hook)

            # Plain object whose attribute dictionary is the namespace; it
            # is handed to the shell as its ``module``.
            class RootpyBrowse(object):
                pass

            module = RootpyBrowse()
            module.__dict__ = namespace
            ip(local_ns=namespace, module=module)
            # Reached once the call to the embedded shell has returned.
            if args.browser and browser:
                browser.Delete()

    # NOTE(review): this check also runs after browser.Delete() above;
    # whether the deleted browser is still truthy is not visible here.
    if browser:
        from rootpy.interactive.rootwait import wait_for_browser_close
        wait_for_browser_close(browser)

# 'browse' sub-command: dispatched to browse() through the ``op`` default.
# Both switches are negative flags that store False into a True default.
parser_browse = subparsers.add_parser('browse')
parser_browse.add_argument(
    '-I', '--no-ipython',
    dest='ipython',
    default=True,
    action='store_false',
    help="don't try to start ipython")
parser_browse.add_argument(
    '-B', '--no-browser',
    dest='browser',
    default=True,
    action='store_false',
    help="don't try to start a browser")
parser_browse.add_argument('file', nargs="*")
parser_browse.set_defaults(op=browse)


def include(args):
    """Store or show the custom include paths.

    The paths are kept, one per line, in the file ``include_paths.list``
    inside rootpy's ``BINARY_PATH`` directory, which is created if it does
    not exist. When ``args.paths`` is non-empty the given paths are appended
    to that file exactly as given; otherwise the stored paths are printed,
    or a notice if the file does not exist yet.
    """
    from rootpy.userdata import BINARY_PATH

    if not os.path.exists(BINARY_PATH):
        os.makedirs(BINARY_PATH)

    include_list = os.path.join(
        BINARY_PATH, 'include_paths.list')

    if args.paths:
        with open(include_list, 'a') as inc_list:
            for path in args.paths:
                inc_list.write(path)
                inc_list.write('\n')
        return
    if not os.path.exists(include_list):
        print("There are no custom include paths stored in {0}".format(
            include_list))
    else:
        print("The following include paths are stored in {0}".format(
            include_list))
        # Read through a context manager so the handle is closed promptly.
        with open(include_list) as inc_list:
            print(inc_list.read())


# 'include' sub-command: dispatched to include() through the ``op`` default.
parser_include = subparsers.add_parser(
    'include',
    description=(
        'Add one or multiple include paths to be used when compiling '
        'dictionaries.\n'
        'If no paths are given then all currently stored paths will be '
        'shown.'))
parser_include.add_argument(
    'paths', help='absolute include paths', nargs='*')
parser_include.set_defaults(op=include)


# Script entry point: parse the command line and run the chosen sub-command.
args = parser.parse_args()
if not hasattr(args, 'op'):
    # Python 3's argparse treats sub-commands as optional, so ``op`` is
    # missing when none was given. Report a usage error instead of letting
    # an AttributeError surface below.
    parser.error("too few arguments")
try:
    args.op(args)
except KeyboardInterrupt:
    sys.exit(1)
except Exception as e:
    if args.debug:
        # If in debug mode show full stack trace
        import traceback
        traceback.print_exc()
        # The command failed, so still exit with a non-zero status.
        sys.exit(1)
    sys.exit("{0}: {1}".format(e.__class__.__name__, e))
