#! /usr/bin/env python
from __future__ import print_function

import profile as _pprofile
import cProfile as _cprofile
from data_profiler import profile as _aprofile
from data_profiler import pstats as _apstats
import pstats as _ppstats
import timeit
import sys
import os
import re
from itertools import chain
try:
    from urllib.parse import unquote_plus
except ImportError:
    from urllib import unquote_plus

import cmd
try:
    import readline
except ImportError:
    pass

# Clock used by main() to report the elapsed run time of the executed script.
timer = timeit.default_timer

# Module-level state; main() rebinds these (it declares them ``global``).
# ProfileClass: profiler class chosen with -p, or None when no profiler is used.
ProfileClass = None
# StatsClass: statistics class matching the profiler; also read by
# ProfileBrowser.do_read when (re)loading a profile file.
StatsClass = None
# stats: the current statistics object; read by the visualizer request handler
# defined inside main().
stats = None

class Stats(_apstats.Stats):
    """Extend the base in different ways.

    * augment stats report with function call info (signature)
    * allow entries to be filtered out of the report by regular expression
    """

    def remove(self, pattern):
        """Remove unwanted entries based on pattern-matching.

        ``pattern`` is a regular expression; every entry of ``self.stats``
        whose ``func_info`` label matches it (``re.search`` semantics) is
        dropped.  ``self.stats`` is replaced by a new dictionary.
        """
        rex = re.compile(pattern)
        kept = dict((func, entry) for func, entry in self.stats.items()
                    if not rex.search(self.func_info(func)))
        self.stats = kept

        # NOTE(review): assumes the base class mirrors the standard
        # pstats.Stats, which caches a sorted key list in ``fcn_list`` and a
        # callee map in ``all_callees``.  Left untouched they would still
        # refer to the removed entries and break later reports, so bring them
        # in line with the filtered table.  Both steps are no-ops when the
        # attributes are absent or empty.
        if getattr(self, 'fcn_list', None):
            # Filter rather than reset so an existing sort order survives.
            self.fcn_list = [func for func in self.fcn_list if func in kept]
        if getattr(self, 'all_callees', None):
            self.all_callees = None

    def func_info(self, func):
        """Return the display label for a stats key.

        ``func`` is a 3-item key ``(filename, line, function)``:

        * keys starting with ``('~', 0)`` are built-ins; a name of the form
          ``<name>`` is rendered as ``{name}``, any other name is returned
          unchanged;
        * when the third item is exactly a ``str`` or a
          ``_aprofile.FunctionCall`` the label is ``filename:line:function``;
        * otherwise the third item must provide ``name`` and ``args`` and the
          label is ``filename:line:name(arg, ...)``.
        """
        if func[:2] == ('~', 0):
            # special case for built-in functions
            name = func[2]
            if name.startswith('<') and name.endswith('>'):
                return '{%s}' % name[1:-1]
            return name

        # Exact type match on purpose: subclasses take the annotated branch.
        if type(func[2]) in (str, _aprofile.FunctionCall):
            return '%s:%d:%s' % func

        # Annotate output here with function argument info.  Each element of
        # ``args`` is indexed as a pair of sequences whose second items are
        # joined as ``first:second``.
        arg_text = ', '.join('%s:%s' % (a[0][1], a[1][1])
                             for a in func[2].args)
        return '%s:%d:%s(%s)' % (func[0], func[1], func[2].name, arg_text)

# This class is adapted from the ProfileBrowser in the pstats.py module.
# We need to duplicate it here for it to use our extended Stats
# type (which reads in profile dumps created with pickle).
class ProfileBrowser(cmd.Cmd):
    """Interactive command-line browser for a profile statistics object.

    Each ``do_<name>`` method implements the command ``<name>`` and each
    ``help_<name>`` method prints its help text.  All output goes to
    ``self.stream``.  A command returning a true value ends the command loop.
    """

    def __init__(self, stats):
        """Start browsing ``stats``; it may be None until ``read`` loads one."""
        cmd.Cmd.__init__(self)
        self.prompt = "% "
        self.stats = stats
        self.stream = sys.stdout

    def generic(self, fn, line):
        """Call the stats method named ``fn`` with arguments parsed from ``line``.

        Whitespace-separated terms are converted to ``int`` when possible,
        otherwise to ``float`` (accepted only within [0, 1]; out-of-range
        values are reported and skipped), otherwise passed through as strings.
        Always returns 0.
        """
        processed = []
        for term in line.split():
            try:
                processed.append(int(term))
                continue
            except ValueError:
                pass
            try:
                frac = float(term)
            except ValueError:
                # Neither an integer nor a fraction: keep the raw text.
                processed.append(term)
                continue
            if frac > 1 or frac < 0:
                print("Fraction argument must be in [0, 1]", file=self.stream)
                continue
            processed.append(frac)
        if self.stats:
            getattr(self.stats, fn)(*processed)
        else:
            print("No statistics object is loaded.", file=self.stream)
        return 0

    def generic_help(self):
        """Print the argument description shared by the report commands."""
        print("Arguments may be:", file=self.stream)
        print("* An integer maximum number of entries to print.", file=self.stream)
        print("* A decimal fractional number between 0 and 1, controlling", file=self.stream)
        print("  what fraction of selected entries to print.", file=self.stream)
        print("* A regular expression; only entries with function names", file=self.stream)
        print("  that match it are printed.", file=self.stream)

    def do_add(self, line):
        """Pass ``line`` to the ``add`` method of the current statistics."""
        if self.stats:
            self.stats.add(line)
        else:
            print("No statistics object is loaded.", file=self.stream)
        return 0

    def help_add(self):
        print("Add profile info from given file to current statistics object.", file=self.stream)

    def do_callees(self, line):
        return self.generic('print_callees', line)

    def help_callees(self):
        print("Print callees statistics from the current stat object.", file=self.stream)
        self.generic_help()

    def do_callers(self, line):
        return self.generic('print_callers', line)

    def help_callers(self):
        print("Print callers statistics from the current stat object.", file=self.stream)
        self.generic_help()

    def do_EOF(self, line):
        """End the session on end-of-file, finishing the prompt line first."""
        print("", file=self.stream)
        return 1

    def help_EOF(self):
        print("Leave the profile browser.", file=self.stream)

    def do_quit(self, line):
        return 1

    def help_quit(self):
        print("Leave the profile browser.", file=self.stream)

    def do_read(self, line):
        """Load statistics from the file named by ``line``.

        With an empty ``line`` the file named in the prompt is reloaded.
        Failures are reported on ``self.stream`` and leave the current
        statistics untouched.  Returns None after a failed load, 0 otherwise.
        """
        if line:
            try:
                self.stats = StatsClass(line)
            except OSError as err:
                # ``strerror`` is the message of the usual (errno, message)
                # form; fall back to the exception text so an OSError built
                # from a single argument cannot raise IndexError here.
                print(err.strerror or err, file=self.stream)
                return
            except Exception as err:
                print(err.__class__.__name__ + ':', err, file=self.stream)
                return
            # The prompt doubles as the record of the file currently loaded.
            self.prompt = line + "% "
        elif len(self.prompt) > 2:
            # Strip the trailing "% " to recover the file name.
            line = self.prompt[:-2]
            self.do_read(line)
        else:
            print("No statistics object is current -- cannot reload.", file=self.stream)
        return 0

    def help_read(self):
        print("Read in profile data from a specified file.", file=self.stream)
        print("Without argument, reload the current file.", file=self.stream)

    def do_reverse(self, line):
        if self.stats:
            self.stats.reverse_order()
        else:
            print("No statistics object is loaded.", file=self.stream)
        return 0

    def help_reverse(self):
        print("Reverse the sort order of the profiling report.", file=self.stream)

    def do_remove(self, line):
        """Pass ``line`` as the pattern to the statistics ``remove`` method."""
        if self.stats:
            self.stats.remove(line)
        else:
            print("No statistics object is loaded.", file=self.stream)
        return 0

    def help_remove(self):
        print("Remove matching entries from the profiling report.", file=self.stream)

    def do_sort(self, line):
        """Sort by the keys in ``line``, or list the valid keys.

        The keys are applied only when every one of them is known to the
        current statistics object; otherwise the valid keys are listed.
        """
        if not self.stats:
            print("No statistics object is loaded.", file=self.stream)
            return
        abbrevs = self.stats.get_sort_arg_defs()
        keys = line.split()
        if line and all((x in abbrevs) for x in keys):
            self.stats.sort_stats(*keys)
        else:
            print("Valid sort keys (unique prefixes are accepted):", file=self.stream)
            for (key, value) in Stats.sort_arg_dict_default.items():
                print("%s -- %s" % (key, value[1]), file=self.stream)
        return 0

    def help_sort(self):
        print("Sort profile data according to specified keys.", file=self.stream)
        print("(Typing `sort' without arguments lists valid keys.)", file=self.stream)

    def complete_sort(self, text, *args):
        """Offer the sort keys that start with ``text`` as completions."""
        return [a for a in Stats.sort_arg_dict_default if a.startswith(text)]

    def do_stats(self, line):
        return self.generic('print_stats', line)

    def help_stats(self):
        print("Print statistics from the current stat object.", file=self.stream)
        self.generic_help()

    def do_strip(self, line):
        if self.stats:
            self.stats.strip_dirs()
        else:
            print("No statistics object is loaded.", file=self.stream)

    def help_strip(self):
        print("Strip leading path information from filenames in the report.", file=self.stream)

    def help_help(self):
        print("Show help for a given command.", file=self.stream)

    def postcmd(self, stop, line):
        """Normalise a command result: any false value becomes None."""
        if stop:
            return stop
        return None

def json_stats(stats):
    """
    Build a JSON-friendly summary of the call graph held by ``stats``.

    Every function key is rendered as the string ``file:line(name)``. The
    result maps each rendered key to a dictionary with the entries
    ``children`` (its callees), ``stats`` (the first four items of its stats
    record), ``callers`` (the last item of its stats record) and
    ``display_name`` (the key rendered with only the file's base name).

    """
    label = '{0}:{1}({2})'.format

    def _as_json(mapping):
        # String keys and list values, as JSON requires.
        return {label(*key): list(val) for key, val in mapping.items()}

    stats.calc_callees()

    nstats = {}
    for func, callees in stats.all_callees.items():
        record = stats.stats[func]
        nstats[label(*func)] = {
            'children': _as_json(callees),
            'stats': list(record[:4]),
            'callers': _as_json(record[-1]),
            'display_name': label(os.path.basename(func[0]), func[1], func[2]),
        }

    # Profiler cruft: entries that call nothing and that nothing calls.
    referenced = set()
    for entry in nstats.values():
        referenced.update(entry['children'])
    isolated = [name for name, entry in nstats.items()
                if not entry['children'] and name not in referenced]
    for name in isolated:
        del nstats[name]

    return nstats

    
def main():
    """Profile a script and/or inspect profile statistics from the command line.

    Options: ``-p`` selects the profiler, ``-f`` names a profile file (written
    after a profiled run, read when no run produced statistics), ``-s`` gives
    the sort key for the printed report, ``-r`` (repeatable) removes matching
    entries, ``-b`` opens the interactive browser and ``-v`` serves the
    snakeviz visualisation.  Arguments not recognised here are handed to the
    profiled script through ``sys.argv``.

    Side effects: rebinds the module globals ``ProfileClass``, ``StatsClass``
    and ``stats``; when a script is given, replaces ``sys.argv`` and puts the
    script's directory first on ``sys.path``.  Exits the process after running
    a script without a profiler.
    """
    global ProfileClass
    global StatsClass
    global stats

    import argparse

    parser = argparse.ArgumentParser(description='Profile a Python script.')
    parser.add_argument('script', nargs='?', help='a script to be profiled')
    # 'd' is the letter advertised for data_profiler; 'a' was the only letter
    # the parser used to accept for it, so both are kept and mean the same.
    parser.add_argument('-p', dest='profiler', choices=['n', 'p', 'c', 'a', 'd'],
        help='select a profiler: n == none, p == profile, c == cProfile, '
             'd (or a) == data_profiler')
    parser.add_argument('-f', dest='filename', help='profile filename to use.')
    parser.add_argument('-s', dest='sort', default=1,
        help="Sort order when printing to stdout, based on pstats.Stats class")
    parser.add_argument('-r', dest='remove', action='append', default=[],
        help="Remove functions matching the given pattern.")
    parser.add_argument('-b', action='store_true', dest='browse',
        help="Browse profile statistics in an interactive session.")
    parser.add_argument('-v', action='store_true', dest='visualize',
        help="Visualize profile statistics using snakeviz.")

    args, remainder = parser.parse_known_args()
    if args.profiler in ('a', 'd'):
        ProfileClass = _aprofile.Profile
        StatsClass = Stats
    elif args.profiler == 'c':
        ProfileClass = _cprofile.Profile
        StatsClass = _ppstats.Stats
    elif args.profiler == 'p':
        ProfileClass = _pprofile.Profile
        StatsClass = _ppstats.Stats

    # None when no profiler was selected ('n' or no -p option at all).
    profiler = ProfileClass and ProfileClass()
    # If a script is given, run it.
    if args.script:
        # Restore the argument vector so the executed script
        # can deal with it as if it was invoked directly.
        sys.argv = [args.script] + remainder
        # Let the script import modules that live next to it.
        sys.path.insert(0, os.path.dirname(args.script))
        with open(args.script, 'rb') as fp:
            code = compile(fp.read(), args.script, 'exec')
        # Namespace that makes the script believe it is the main program.
        globs = {
            '__file__': args.script,
            '__name__': '__main__',
            '__package__': None,
            '__cached__': None,
        }
        start = timer()
        if not profiler:
            exec(code, globs, None)
            print('elapsed time:', timer() - start)
            # Nothing left to do, quit.
            sys.exit()
        else:
            # Run the profiler
            profiler.runctx(code, globs, None)
            print('elapsed time:', timer() - start)
            stats = StatsClass(profiler)
            for r in args.remove:
                stats.remove(r)
            if args.filename:
                stats.dump_stats(args.filename)
            elif not (args.browse or args.visualize):
                stats.strip_dirs()
                stats.sort_stats(args.sort)
                stats.print_stats()

    # If we have a filename but no stats object, instantiate one.
    if not stats:
        if not args.filename:
            print('missing filename argument')
        elif StatsClass is None:
            # Without a selected profiler there is no statistics class to
            # read the file with; say so instead of failing on a None call.
            parser.error('a profiler must be selected with -p (p, c or d) '
                         'to read profile data from a file')
        else:
            stats = StatsClass(args.filename)

    if args.browse:
        # Run the browser
        try:
            browser = ProfileBrowser(stats)
            print("Welcome to the profile statistics browser.", file=browser.stream)
            browser.cmdloop()
            print("Goodbye.", file=browser.stream)
        except KeyboardInterrupt:
            pass

    elif args.visualize:
        import threading
        import webbrowser
        import snakeviz
        import snakeviz.stats
        import tornado.ioloop
        import tornado.web

        class VizHandler(tornado.web.RequestHandler):
            def get(self, profile_name):
                # profile_name is only echoed back to the template: the page
                # always shows the statistics prepared above, whatever path
                # was requested.
                self.render('viz.html', profile_name=profile_name,
                            table_rows=snakeviz.stats.table_rows(stats),
                            callees=json_stats(stats))

        # Run the visualizer
        try:
            handlers = [(r'/(.*)', VizHandler)]
            # Static files and templates are taken from the installed
            # snakeviz package.
            resdir = os.path.dirname(snakeviz.__file__)
            print('resdir', resdir)
            settings = {
                'static_path': os.path.join(resdir, 'static'),
                'template_path': os.path.join(resdir, 'templates'),
                'debug': True,
                'gzip': True
            }
            app = tornado.web.Application(handlers, **settings)
            hostname, port = 'localhost', 8080
            app.listen(port, address=hostname)
            url = "http://{0}:{1}".format(hostname, port)
            print(('snakeviz web server started on %s:%d; enter Ctrl-C to exit'%(hostname, port)))
            print(url)
            browser = webbrowser.get()
            def bt():
                browser.open(url, new=2)
            # Open the page from a separate thread; the I/O loop started
            # below occupies this one.
            threading.Thread(target=bt).start()
            tornado.ioloop.IOLoop.instance().start()
        except KeyboardInterrupt:
            tornado.ioloop.IOLoop.instance().stop()
            print('\nBye!')
        
# When invoked as main program, invoke the profiler on a script.
# main() parses the command line itself, so nothing is passed in here.
if __name__ == '__main__':
    main()

    
