#!python
import sys
import os
import argparse
import asyncio
import pprint
import pathlib
import enum

import prompt_toolkit.history
from hgdb import HGDBClient, DebugSymbolTable
from prompt_toolkit import PromptSession
from prompt_toolkit.patch_stdout import patch_stdout
from abc import abstractmethod
import pygments
import pygments.formatters
import pygments.lexers
from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
from prompt_toolkit.key_binding import KeyBindings


def get_arguments():
    """Parse the debugger's command line from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding ``db``, ``port``,
        ``no_db_connection`` and ``directory``.
    """
    cli = argparse.ArgumentParser("hgdb debugger")
    # the only positional argument: the symbol table file
    cli.add_argument("db", help="Filename to the symbol table", type=str)
    cli.add_argument("--port", "-p", default=8888, type=int, dest="port",
                     help="Port number for the simulator")
    cli.add_argument("--no-db-connection", default=False, action="store_true", dest="no_db_connection",
                     help="If present, will not inform the simulator about the debug symbol table")
    cli.add_argument("--dir", "-d", "-w", "--workspace", default="", type=str, dest="directory",
                     help="Workspace directory to search when the filename is relative")
    return cli.parse_args()


def index_filenames(filenames):
    """Build lookup tables between user-typed names and symbol table filenames.

    Every filename is always reachable by its exact name. In addition, a
    filename is reachable by its basename as long as that basename is
    unambiguous, i.e. no other file shares it and no file is stored under
    exactly that name.

    Args:
        filenames: iterable of filenames as stored in the symbol table.

    Returns:
        A ``(result, shorten)`` tuple. ``result`` maps every accepted name
        (exact or shorthand) to the stored filename; ``shorten`` maps a stored
        filename to its shorthand, for the files that have one.
    """
    # materialise once: the input is walked twice below
    filenames = list(filenames)

    by_basename = {}
    conflicted = set()
    for filename in filenames:
        basename = os.path.basename(filename)
        # nothing to shorten, or the shorthand is already known to be ambiguous
        if basename == filename or basename in conflicted:
            continue
        if basename in by_basename:
            if by_basename[basename] != filename:
                # two different files share the basename: drop the shorthand
                del by_basename[basename]
                conflicted.add(basename)
            continue
        by_basename[basename] = filename

    # exact names are registered last so that they always win over a shorthand
    result = dict(by_basename)
    for filename in filenames:
        result[filename] = filename

    shorten = {value: name for name, value in result.items() if name != value}
    return result, shorten


class DebuggingInformation:
    """Plain state holder shared by the debugger commands.

    It records the most recent breakpoint response (location, breakpoint id,
    local variables, simulation time) together with the filename lookup tables
    built from the symbol table.
    """

    def __init__(self, db, formatter, workspace):
        # scope used when evaluating expressions; set to the breakpoint id (as a string) by parse()
        self.current_scope = ""
        # list of commands
        self.commands = {}
        self.commands_help = {}
        # handed to read_line() as its cache argument
        self.file_context_cache = {}

        # location and id of the breakpoint reported last
        self.current_breakpoint_fn = ""
        self.current_breakpoint_ln = 0
        self.current_breakpoint_cn = 0
        self.current_breakpoint_id = 0

        # local variables of the selected instance
        self.local_vars = {}
        # names of variables that have been overridden with "set"
        self.set_values = set()
        # directory used to resolve relative filenames
        self.workspace = workspace
        self.current_instance_index = 0
        self.current_breakpoint_resp = None
        # simulation time of the last breakpoint
        self.current_time = 0

        name_map, short_map = index_filenames(db.get_filenames())
        self.filename_map = name_map
        self.shortened_map = short_map
        self.formatter = formatter

        # callable that prints the top-level help; assigned by the caller after construction
        self.print_help = None

    def parse(self, resp):
        """Update the state from a breakpoint response and print the breakpoint.

        Args:
            resp: response dictionary whose ``payload`` carries ``filename``,
                ``line_num``, ``column_num``, ``instances`` and ``time``.
        """
        body = resp["payload"]
        self.current_breakpoint_fn = body["filename"]
        self.current_breakpoint_ln = body["line_num"]
        self.current_breakpoint_cn = body["column_num"]
        # a fresh response always selects the first instance
        self.current_instance_index = 0
        selected = body["instances"][self.current_instance_index]
        self.current_breakpoint_id = selected["breakpoint_id"]
        self.current_scope = str(self.current_breakpoint_id)
        self.local_vars = parse_local(selected["local"])
        self.current_breakpoint_resp = resp
        self.current_time = body["time"]

        render_breakpoint(self.current_breakpoint_fn, self.current_breakpoint_ln, self.current_breakpoint_cn,
                          self, self.current_breakpoint_id)


class SubCommand:
    """Base class of every interactive debugger command.

    Constructing a subclass registers a sub-parser (name plus aliases) on
    ``parent`` and stores an async ``dispatch`` callable as a parser default,
    so the main loop can run the command from the parsed namespace.
    """

    def __init__(self, parent, info, client, commands, help_str):
        """
        Args:
            parent: the object returned by ``ArgumentParser.add_subparsers()``.
            info: shared ``DebuggingInformation`` state.
            client: debugger client used to talk to the simulator.
            commands: command name followed by its aliases.
            help_str: help text of the command.
        """
        self.parent = parent
        self.info: DebuggingInformation = info
        self.client = client

        self.parser: argparse.ArgumentParser = self.add_command(commands, help_str)

    def add_command(self, commands, help_str):
        """Create the sub-parser for this command and attach the dispatcher to it."""
        parser = self.parent.add_parser(commands[0], help=help_str, aliases=commands[1:], add_help=False)
        # add dispatch to it

        async def function(args):
            return await self.dispatch(args)
        parser.set_defaults(dispatch=function)
        return parser

    @staticmethod
    def check_error(resp):
        """Print the failure reason of an error response.

        Returns:
            ``False`` when ``resp["status"]`` is ``"error"``, ``True`` otherwise.
        """
        if resp["status"] == "error":
            print(resp["payload"]["reason"])
            return False
        return True

    @abstractmethod
    async def dispatch(self, args):
        """Run the command with the parsed arguments.

        The main loop treats a truthy return value as "wait for the next
        breakpoint response".
        """
        pass

    def _parse_fn_ln(self, expr):
        """Parse ``filename:line[:column]`` into a resolved location.

        An empty or ``None`` expression refers to the current breakpoint
        location. The filename is looked up in ``info.filename_map``, so both
        exact names and registered shorthands are accepted.

        Returns:
            ``(filename, line_num, column_num)`` or ``None`` when the
            expression is malformed or the file is unknown.
        """
        if not expr:
            # use current line and line number; checked before any string
            # operation so that None does not raise
            tokens = []
            filename = self.info.current_breakpoint_fn
            line_num = self.info.current_breakpoint_ln
        else:
            tokens = expr.split(":")
            if len(tokens) < 2 or len(tokens) > 3:
                return None
            filename = tokens[0]
            if not tokens[1].isdigit():
                return None
            line_num = int(tokens[1])

        if filename not in self.info.filename_map:
            return None
        filename = self.info.filename_map[filename]

        column_num = 0
        if len(tokens) == 3:
            if not tokens[2].isdigit():
                return None
            column_num = int(tokens[2])
        return filename, line_num, column_num


class InsertBreakpointCommand(SubCommand):
    """``b`` / ``break``: insert a breakpoint at ``filename:line[:column]``."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["b", "break"], "Set breakpoint")
        self.parser.add_argument("filename", type=str, help="Breakpoint line number information")

    async def dispatch(self, args):
        """Resolve the location and send the breakpoint request to the client."""
        location = self._parse_fn_ln(args.filename)
        if location is None:
            print("Invalid breakpoint", args.filename)
            return
        fn, ln, cn = location
        # errors are reported through check_error() rather than by the client call
        response = await self.client.set_breakpoint(fn, ln, cn, check_error=False)
        self.check_error(response)


class RemoveBreakpointCommand(SubCommand):
    """``delete`` / ``d`` / ``del``: remove a breakpoint by its ID."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["delete", "d", "del"], "Delete breakpoints")
        self.parser.add_argument("breakpoint_id", type=int, help="Breakpoint ID")

    async def dispatch(self, args):
        """Send the removal request and print the reason if it fails."""
        # NOTE(review): the positional False presumably disables the client's own error check - confirm
        response = await self.client.remove_breakpoint_id(args.breakpoint_id, False)
        self.check_error(response)


class ClearBreakpointCommand(SubCommand):
    """``clear``: remove the breakpoint at a given ``filename:line_number``."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["clear"], "Clear breakpoints")
        self.parser.add_argument("filename", default="", nargs='?', help="Breakpoint filename:line_number")

    async def dispatch(self, args):
        """Remove the breakpoint at the given location; clearing everything is not implemented."""
        target = args.filename
        if not target:
            # TODO:
            print("clear all not implemented")
            return
        location = self._parse_fn_ln(target)
        if location is None:
            print("Unable to parse breakpoint", target)
            return
        fn, ln, cn = location
        await self.client.remove_breakpoint(fn, ln, cn)


class ContinueCommand(SubCommand):
    """``c`` / ``continue``: resume the execution."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["c", "continue"], "Continue the execution")

    async def dispatch(self, _):
        """Send the continue request; the truthy result makes the main loop wait for a breakpoint."""
        await self.client.continue_()
        return True


class StepOverCommand(SubCommand):
    """``n`` / ``step-over``: step over the execution."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["n", "step-over"], "Step over the execution")

    async def dispatch(self, _):
        """Send the step-over request; the truthy result makes the main loop wait for a breakpoint."""
        await self.client.step_over()
        return True


class StepBackCommand(SubCommand):
    """``step-back``: step the execution backwards."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["step-back"], "Step back the execution")

    async def dispatch(self, _):
        """Send the step-back request; the truthy result makes the main loop wait for a breakpoint."""
        await self.client.step_back()
        return True


class ReverseContinueCommand(SubCommand):
    """``rc`` / ``reverse-continue``: continue the execution in reverse."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["rc", "reverse-continue"], "Reverse continue the execution")

    async def dispatch(self, _):
        """Send the reverse-continue request; the truthy result makes the main loop wait for a breakpoint."""
        await self.client.reverse_continue()
        return True


class PrintCommand(SubCommand):
    """``p`` / ``print`` / ``eval``: show a variable or evaluate an expression."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["p", "print", "eval"],
                         "Print value stored in variable or evaluate a expression")
        self.parser.add_argument("expression", type=str, nargs="+", help="Variable/expression to print out")

    async def dispatch(self, args):
        """Print the value from the cached local variables, or ask the client to evaluate it."""
        expr = " ".join(args.expression)
        # a variable overridden with "set" is never served from the cached locals
        if expr.strip() not in self.info.set_values:
            cached = locate_local_vars(expr, self.info.local_vars)
            if cached is not None:
                pprint.pprint(cached)
                return

        resp = await self.client.evaluate(self.info.current_scope, expr, is_context=True, check_error=False)
        if not self.check_error(resp):
            return
        if "result" in resp["payload"]:
            print(resp["payload"]["result"])
        else:
            print("Error in protocol setup\nRESP:", resp)


class SetValueCommand(SubCommand):
    """``set``: assign an integer value to a variable (``variable=value``)."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["set"],
                         "Set the value to a variable. This will not work in replay mode")
        self.parser.add_argument("assignment", type=str, help="Format variable=value, e.g. a=6")

    async def dispatch(self, args):
        """Validate the assignment, send it, and remember the variable on success."""
        assignment = args.assignment
        pieces = assignment.split("=")
        if len(pieces) != 2:
            print("Invalid set value expression", assignment)
            self.parser.print_usage()
            return

        name = pieces[0].strip()
        literal = pieces[1].strip()
        if not literal.isdigit():
            print("Value has to be an integer", literal)
            self.parser.print_usage()
            return

        scope = self.info.current_scope
        # the scope is the breakpoint id stored as a string; it is empty until a breakpoint is parsed
        if not scope.isdigit():
            print("Invalid scope")
            return

        resp = await self.client.set_value(name, int(literal), breakpoint_id=int(scope), check_error=False)
        if self.check_error(resp):
            self.info.set_values.add(name)


class ListCommand(SubCommand):
    """``l`` / ``list``: print source lines around a location."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["l", "list"], "List source code")
        self.parser.add_argument("filename", type=str, nargs="?", default="",
                                 help="File name and line number in the format of filename:line_number")
        self.parser.add_argument("-n", type=int, default=5, required=False,
                                 help="Number of lines centered at the current breakpoint")

    async def dispatch(self, args):
        """Print ``-n`` lines on each side of the requested (or current) location."""
        expr = args.filename
        location = self._parse_fn_ln(expr)
        if location is None:
            if not expr:
                print("Invalid filename information")
            else:
                print("Invalid filename at", expr)
            return

        path = resolve_filename(location[0], self.info.workspace)
        # need to read out the filename given these lines
        numbered = read_line(path, location[1], self.info.file_context_cache, args.n)
        if numbered is None:
            print("Unable to find file", path)
            return
        for index, text in numbered:
            print_line(path, index, text, self.info.formatter)


class InfoCommand(SubCommand):
    """``info``: print breakpoints, threads or the simulation time."""

    class InfoEnum(enum.Enum):
        """Kinds of information that can be requested."""
        breakpoint = "breakpoint"
        threads = "threads"
        time = "time"

        def __repr__(self):
            # show the plain value wherever argparse renders the choices
            return self.value

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["info"], "Print out simulation and debugging information")
        self.parser.add_argument("type", type=InfoCommand.InfoEnum, choices=list(InfoCommand.InfoEnum))

    async def dispatch(self, args):
        """Print the requested kind of information."""
        kind = args.type
        if kind == InfoCommand.InfoEnum.time:
            # print out the time
            print(self.info.current_time)
        elif kind == InfoCommand.InfoEnum.threads:
            # print out the threads; nothing is printed before the first breakpoint
            last_resp = self.info.current_breakpoint_resp
            if last_resp is not None:
                print_threads(last_resp, self.info.current_instance_index)
        elif kind == InfoCommand.InfoEnum.breakpoint:
            # need to get the current inserted breakpoints
            resp = await self.client.get_info("breakpoints")
            for bp in resp["payload"]["breakpoints"]:
                location = get_fn_ln_cn(bp["filename"], bp["line_num"], bp["column_num"], self.info.shortened_map)
                print("{0:8}\t{1}".format(bp["id"], location))


class ThreadCommand(SubCommand):
    """``thread``: select one of the instances of the last breakpoint response."""

    def __init__(self, parent, info, client):
        commands = ["thread"]
        help_str = "Switch thread based on instance ID"
        super(ThreadCommand, self).__init__(parent, info, client, commands, help_str)

        self.parser.add_argument("instance_id", type=int, help="Instance ID")

    async def dispatch(self, args):
        """Point ``info.current_instance_index`` at the instance with the given ID.

        Prints a message instead of failing when no breakpoint response has
        been recorded yet or when no instance carries the requested ID.
        """
        instance_id = args.instance_id
        resp = self.info.current_breakpoint_resp
        if resp is None:
            # nothing to switch between before the first breakpoint
            print("No breakpoint has been hit yet")
            return
        for i, instance in enumerate(resp["payload"]["instances"]):
            if instance["instance_id"] == instance_id:
                print("Switching to thread", i)
                self.info.current_instance_index = i
                return
        print("Unknown instance ID", instance_id)


class ConditionCommand(SubCommand):
    """``condition``: attach a condition expression to a breakpoint ID."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["condition"], "Set breakpoint conditions")
        self.parser.add_argument("breakpoint_id", help="Breakpoint ID", type=int)
        self.parser.add_argument("expr", help="Breakpoint condition", nargs="+", type=str)

    async def dispatch(self, args):
        """Join the expression words with spaces and send the condition to the client."""
        joined = " ".join(args.expr)
        response = await self.client.set_breakpoint_id(args.breakpoint_id, cond=joined)
        self.check_error(response)


class GoCommand(SubCommand):
    """``go``: jump to a simulation time."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["go"], "Jump to a particular time. Only works for replay mode")
        self.parser.add_argument("time", help="Simulation time", type=int)

    async def dispatch(self, expr):
        """Request the jump; return ``True`` on success so the main loop waits for a breakpoint."""
        response = await self.client.jump(expr.time)
        # check the response
        # only the replay tool can be used to jump time
        if not self.check_error(response):
            return None
        # due to the event-driven set up. once we change the time it will start to evaluate everything
        return True


class ExitCommand(SubCommand):
    """``q`` / ``exit``: close the client connection and leave the debugger."""

    class ExitException(Exception):
        """Raised by ``dispatch`` to tell the main loop to terminate."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["q", "exit"], "Exit the debugger")

    async def dispatch(self, _):
        """Print ``exit``, close the client and raise ``ExitException``."""
        print("exit")
        await self.client.close()
        raise ExitCommand.ExitException()


class HelpCommand(SubCommand):
    """``help``: print the general help or the help of a single command."""

    def __init__(self, parent, info, client):
        super().__init__(parent, info, client, ["help"], "Help")
        self.parser.add_argument("command", default="", nargs="?",
                                 help="If provided, print out the help message for the particular command. "
                                      "Otherwise print out the help message for hgdb")

    async def dispatch(self, args):
        """Print the help of ``args.command`` if given, else the top-level help."""
        requested = args.command
        if not requested:
            self.info.print_help()
            return
        # parent.choices maps every registered command name and alias to its sub-parser
        registered = self.parent.choices
        if requested in registered:
            registered[requested].print_help()
        else:
            print("Unknown command", requested)


async def get_client(filename, port, skip_connection=False):
    """Create a client for the simulator on localhost and connect it.

    Args:
        filename: symbol table filename handed to the client.
        port: port the simulator listens on.
        skip_connection: when true, ``None`` is handed to the client instead
            of the symbol table filename.

    Returns:
        The connected client, or ``None`` when the connection failed.
    """
    uri = "ws://localhost:{0}".format(port)
    # create a client and try to connect
    client = HGDBClient(uri, None if skip_connection else filename)
    # connect
    try:
        await client.connect()
    except (AttributeError, OSError):
        # AttributeError is what the original code handled. OSError covers
        # refused/unreachable sockets (ConnectionRefusedError is a subclass),
        # which is presumably what connect() raises when no simulator is
        # listening - confirm against the client implementation.
        print("Unable to connect to the debugger. Is the debugger running?")
        return None
    payload = await client.get_info("status")
    print(payload["payload"]["status"], end="")
    return client


def read_line(filename, line_num, file_context_cache, line_range=0):
    """Read the lines of ``filename`` around ``line_num``.

    Args:
        filename: path of the file to read.
        line_num: 1-based line number at the centre of the window.
        file_context_cache: dict mapping a filename to the list of its lines;
            it is filled on the first read so a file is only read once.
        line_range: number of lines to include on each side of ``line_num``.

    Returns:
        A list of ``(index, text)`` tuples where ``index`` is the 0-based line
        index (notice the +1 conversion needed to display it) and ``text`` has
        trailing whitespace removed, or ``None`` when the file cannot be read.
    """
    lines = file_context_cache.get(filename)
    if lines is None:
        try:
            with open(filename) as f:
                lines = f.readlines()
        except OSError:
            # missing file, directory, permission problem, ...
            return None
        file_context_cache[filename] = lines

    # clamp the 1-based window to the file
    min_line = max(1, line_num - line_range)
    max_line = min(len(lines), line_num + line_range)
    return [(i, lines[i].rstrip()) for i in range(min_line - 1, max_line)]


def print_line(filename, line_num, line, formatter):
    """Print one syntax-highlighted source line prefixed with its line number.

    Args:
        filename: used to pick the lexer from the file name.
        line_num: 0-based line index; it is printed as ``line_num + 1``.
        line: text of the line.
        formatter: pygments formatter used to render the line.
    """
    # local import keeps the block self-contained
    from pygments.util import ClassNotFound

    try:
        lexer = pygments.lexers.get_lexer_for_filename(filename)
    except ClassNotFound:
        # no lexer is registered for this file name: show the text unhighlighted
        lexer = pygments.lexers.get_lexer_by_name("text")
    print(line_num + 1, pygments.highlight(line, lexer, formatter), end="")


def get_fn_ln_cn(filename, line_num, column_num, shorten_filename_map):
    """Format a location as ``filename:line`` or ``filename:line:column``.

    The filename is replaced by its shorthand when ``shorten_filename_map``
    has one. The column is left out when it is 0, and an empty filename
    yields an empty string.
    """
    if len(filename) == 0:
        return ""
    display_name = shorten_filename_map[filename] if filename in shorten_filename_map else filename
    if column_num == 0:
        return "{0}:{1}".format(display_name, line_num)
    return "{0}:{1}:{2}".format(display_name, line_num, column_num)


class AutoSuggestFromHistoryPath(AutoSuggest):
    """Suggest completions from the prompt history first, then from known filenames."""

    # Commands whose argument starts with a filename. "break" is the alias
    # registered by the breakpoint command; "breakpoint" is kept from the
    # original set.
    PATH_COMMANDS = frozenset({"list", "l", "b", "break", "breakpoint", "clear"})

    def __init__(self, path):
        """
        Args:
            path: iterable of filenames offered as completions.
        """
        self.path = path

    def get_suggestion(
            self, buffer, document):
        """Return a ``Suggestion`` with the text to append, or ``None``."""
        history = buffer.history
        # only the last line of the input is completed
        text = document.text.rsplit("\n", 1)[-1]
        if not text.strip():
            return None

        # match history first
        for string in reversed(list(history.get_strings())):
            for line in reversed(string.splitlines()):
                if line.startswith(text):
                    return Suggestion(line[len(text):])

        # then complete the filename argument of the path-taking commands
        command, separator, path = text.partition(" ")
        if separator and path and command in self.PATH_COMMANDS:
            for filename in self.path:
                if filename.startswith(path):
                    return Suggestion(filename[len(path):])
        return None


def fix_local_str_num(target):
    """Recursively turn digit strings into ints and index-keyed dicts into lists.

    A dict becomes a list only when its keys are exactly ``"0"`` ..
    ``str(len - 1)``. Any other dict is updated in place and returned.
    Non-dict values are returned unchanged unless they are digit strings.
    """
    if not isinstance(target, dict):
        is_number = isinstance(target, str) and target.isdigit()
        return int(target) if is_number else target

    size = len(target)
    # only if all the keys are there and are numerical
    if all(str(index) in target for index in range(size)):
        as_list = [None] * size
        for key, value in target.items():
            as_list[int(key)] = fix_local_str_num(value)
        return as_list

    # not an array: keep this layer as a dict and convert its values
    for key in target:
        target[key] = fix_local_str_num(target[key])
    return target


def parse_local(local_vars):
    """Rebuild the variable hierarchy from the flat mapping sent by the debug server.

    Names such as ``a.b[0]`` are split on ``.`` and ``[...]`` into nested
    dictionaries. The result is then passed through ``fix_local_str_num``,
    which converts digit strings to ints and index-keyed dicts to lists.

    Args:
        local_vars: mapping from flattened variable names to values.

    Returns:
        The nested structure of local variables.
    """
    result = {}
    # notice that everything is flat from the debug server
    for name, value in local_vars.items():
        # replace [] with .
        tokens = name.replace("[", ".").replace("]", "").split(".")
        # build up the hierarchy
        target = result
        for token in tokens[:-1]:
            target = target.setdefault(token, {})
        # Values are stored untouched: the numeric conversion happens once in
        # fix_local_str_num, which also copes with values that are not strings.
        target[tokens[-1]] = value

    # second pass recursively change the map into array
    return fix_local_str_num(result)


def locate_local_vars(expr, local_vars):
    """Look up a variable expression such as ``a.b[1]`` in the nested locals.

    Args:
        expr: a single variable reference; anything containing whitespace is
            treated as a general expression and not looked up.
        local_vars: nested dict/list structure of local variables.

    Returns:
        The value found, or ``None`` when the expression is not a plain
        variable reference or does not exist in ``local_vars``.
    """
    if len(expr.split()) > 1:
        return None
    name = expr.replace("[", ".").replace("]", "")
    target = local_vars
    for var_name in name.split("."):
        if isinstance(target, dict):
            if var_name in target:
                target = target[var_name]
            elif var_name.isdigit() and int(var_name) in target:
                # integer keys are still accepted for backward compatibility
                target = target[int(var_name)]
            else:
                return None
        elif isinstance(target, list):
            # a list is addressed by position, not by value membership
            if not var_name.isdigit() or int(var_name) >= len(target):
                return None
            target = target[int(var_name)]
        else:
            # cannot descend into a scalar value
            return None
    return target


def resolve_filename(filename, workspace):
    """Resolve a relative filename against the workspace directory.

    The workspace tree is walked top-down and the first directory that
    contains ``filename`` wins. ``filename`` may be a bare file name or a
    relative path with directory components (e.g. ``rtl/top.v``).

    Args:
        filename: absolute or relative filename.
        workspace: directory to search; an empty string disables the search.

    Returns:
        The path of the first match, or ``filename`` unchanged when it is
        absolute, no workspace is given, or nothing matches.
    """
    if len(workspace) == 0 or os.path.isabs(filename):
        return filename
    # need to query the filesystem to find a match
    # might be slow if there are tons of files
    for root, _dirs, files in os.walk(workspace):
        candidate = os.path.join(root, filename)
        # `files` only holds bare names, so a relative path with directory
        # components has to be checked on disk
        if filename in files or os.path.isfile(candidate):
            return candidate
    return filename


def print_threads(resp, current_idx):
    """Print a table of the instances in a breakpoint response.

    Each row shows the instance ID and the instance name; the row whose
    position equals ``current_idx`` is marked with ``*``. The ID column is as
    wide as the largest ID, with a minimum of two characters.
    """
    instances = resp["payload"]["instances"]
    largest_id = max([entry["instance_id"] for entry in instances])
    width = max(len(str(largest_id)), 2)

    header_indent = " " * (width - 2)
    print("  " + header_indent + "ID\tInstance")
    for position, entry in enumerate(instances):
        marker = "*" if position == current_idx else " "
        print("{0} {1:{2}}\t{3}".format(marker, entry["instance_id"], width, entry["instance_name"]))


def render_breakpoint(filename, line_num, column_num, info, breakpoint_id):
    """Print the breakpoint location and, when the file is readable, its source line.

    Args:
        filename: filename reported for the breakpoint.
        line_num: 1-based line number of the breakpoint.
        column_num: column number of the breakpoint (0 means none).
        info: the ``DebuggingInformation`` providing the shorthand map,
            workspace, file cache and formatter.
        breakpoint_id: ID of the breakpoint that was hit.
    """
    fn_ln_info = get_fn_ln_cn(filename, line_num, column_num, info.shortened_map)
    print("Breakpoint", breakpoint_id, "at", fn_ln_info)
    # resolve relative filenames against the workspace, as the list command does
    source_path = resolve_filename(filename, info.workspace)
    line_text = read_line(source_path, line_num, info.file_context_cache)
    if line_text:
        # read_line yields the 0-based index; print_line converts it back for display
        line_index, line = line_text[0]
        print_line(source_path, line_index, line, info.formatter)


async def main_loop(client: HGDBClient, session, info, parser):
    """Read commands from the user and dispatch them until exit.

    Args:
        client: connected debugger client.
        session: prompt session used when stdin is a terminal (may be ``None``
            otherwise, in which case ``input`` is used).
        info: shared ``DebuggingInformation`` state.
        parser: argument parser holding all command sub-parsers.

    Returns when the simulator exits. The exit command terminates the process
    with status 0.
    """
    # for tab auto-complete
    bindings = KeyBindings()

    @bindings.add('tab')
    def _(event):
        # accept the current auto-suggestion, if any
        buffer = event.current_buffer
        suggestion: Suggestion = buffer.suggestion
        if suggestion is not None:
            buffer.text = buffer.text + suggestion.text
            buffer.cursor_position = len(buffer.text)

    # stateless, so one instance serves every prompt
    auto_suggest = AutoSuggestFromHistoryPath(info.filename_map)

    while True:
        try:
            if sys.stdin.isatty():
                with patch_stdout():
                    result = await session.prompt_async("(hgdb) ",
                                                        auto_suggest=auto_suggest,
                                                        rprompt=get_fn_ln_cn(info.current_breakpoint_fn,
                                                                             info.current_breakpoint_ln,
                                                                             info.current_breakpoint_cn,
                                                                             info.shortened_map),
                                                        key_bindings=bindings)
            else:
                result = input("(hgdb) ")
        except KeyboardInterrupt:
            continue
        # split on any run of whitespace so repeated spaces do not produce
        # empty arguments; a blank line is simply ignored
        args = result.split()
        if not args:
            continue
        try:
            args_result = parser.parse_args(args)
        except (argparse.ArgumentError, argparse.ArgumentTypeError, SystemExit):
            # argparse reports the problem itself; keep the prompt alive
            continue
        if hasattr(args_result, "dispatch"):
            try:
                cmd_result = await args_result.dispatch(args_result)
                if cmd_result:
                    # the command resumed the simulation: wait for the next breakpoint
                    resp = await client.recv_bp()
                    if resp is None:
                        print("Simulator exited")
                        return
                    else:
                        info.parse(resp)
            except ExitCommand.ExitException:
                sys.exit(0)


def main():
    """Entry point: load the symbol table, connect to the simulator and run the prompt."""
    args = get_arguments()
    filename = args.db
    if not os.path.exists(filename):
        print("Unable to find", filename, file=sys.stderr)
        sys.exit(1)

    # load the table locally as well
    # use absolute path
    filename = os.path.abspath(filename)
    db = DebugSymbolTable(filename)

    # Create the event loop explicitly: asyncio.get_event_loop() without a
    # running loop is deprecated. The same loop serves the connection and the
    # main loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    client: HGDBClient = loop.run_until_complete(get_client(filename, args.port, args.no_db_connection))
    if client is None:
        sys.exit(1)
    # prompt session
    # use ~/.hgdb as persistent history
    prompt_history = os.path.join(pathlib.Path.home(), ".hgdb")
    session = PromptSession(history=prompt_toolkit.history.FileHistory(prompt_history)) if sys.stdin.isatty() else None

    # formatter
    formatter = pygments.formatters.get_formatter_by_name("terminal")

    info = DebuggingInformation(db, formatter, args.directory)
    parser = argparse.ArgumentParser("hgdb", add_help=False)
    sub_parsers = parser.add_subparsers()
    info.print_help = parser.print_help

    # every SubCommand subclass registers its own sub-parser when constructed
    command_classes = SubCommand.__subclasses__()
    for cls in command_classes:
        cls(sub_parsers, info, client)

    # loop until finish
    try:
        loop.run_until_complete(main_loop(client, session, info, parser))
    except (KeyboardInterrupt, EOFError):
        print("exit")

# Start the debugger only when this file is executed as a script.
if __name__ == "__main__":
    main()
