#!/usr/bin/env python

import ast
import doctest
import logging
logging.basicConfig(level=logging.INFO)

import astor
from epc.server import EPCServer
from node_transformers import Annotator, FunctionExploder, SyntaxRewriter, IPythonEmbedder, LineNumberFinder

# EPC server exposing every function decorated with @server.register_function below.
# Port 0 asks the OS for any free port; the __main__ block calls server.print_port()
# to report which one was bound.
server = EPCServer(('localhost', 0))


@server.register_function
def annotate(*code):
    """Annotate code with code to make and eval cells

    Args:
        s (str): the code
        ns (str): the namespace

    `ns` is of the form

        - <module-name>                             (top-level statements only)
        - <module-name>.<func-name>                 (one top-level function)
        - <module-name>.<class-name>.<method-name>  (one method of a top-level class)

    code = [s, ns]

    Only the part of `s` selected by `ns` is kept. If the named function,
    class or method does not exist, nothing is selected (empty output body).
    For a method, the method is renamed to `ns` before annotation. The
    selected tree is then passed through FunctionExploder, SyntaxRewriter and
    Annotator (each constructed with ``buffer=ns``) and converted back to
    source with astor.

    Returns:
        str: the annotated source code.

    Raises:
        SyntaxError: if `s` cannot be parsed.
        ValueError: if `ns` has more than three dot-separated parts.

    >>> s = '''
    ...
    ... x
    ... class Foo:
    ...     def bar():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...     def biz():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...
    ... class Qux:
    ...     def quux():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...     def quuux():
    ...         \"\"\"function\"\"\"
    ...         pass
    ... y
    ...
    ... '''
    >>>
    >>> namespace = 'ast-server.Qux.quux'
    >>> code = [s, namespace]

    """
    source, namespace = code[0], code[1]
    tree = ast.parse(source)
    ns_tokens = namespace.split('.')

    if len(ns_tokens) == 1:  # top-level: drop every function and class definition
        tree.body = [stmt for stmt in tree.body
                     if not isinstance(stmt, (ast.FunctionDef, ast.ClassDef))]
    elif len(ns_tokens) == 2:  # top-level function(s) with the given name
        _, func_name = ns_tokens
        tree.body = [stmt for stmt in tree.body
                     if isinstance(stmt, ast.FunctionDef) and stmt.name == func_name]
    elif len(ns_tokens) == 3:  # method of a top-level class
        _, class_name, method_name = ns_tokens
        classdef = next((stmt for stmt in tree.body
                         if isinstance(stmt, ast.ClassDef) and stmt.name == class_name),
                        None)
        method = None
        if classdef is not None:
            method = next((stmt for stmt in classdef.body
                           if isinstance(stmt, ast.FunctionDef) and stmt.name == method_name),
                          None)
        if method is None:
            # Nothing matched: select nothing, as in the function case, rather
            # than falling through and annotating the whole module.
            tree.body = []
        else:
            method.name = namespace  # rename method for readability
            tree.body = [method]
    else:
        raise ValueError(
            f'namespace must have at most three dot-separated parts: {namespace!r}')

    exploded_tree = FunctionExploder(buffer=namespace).visit(tree)
    rewritten_tree = SyntaxRewriter(buffer=namespace).visit(exploded_tree)
    annotated_tree = Annotator(buffer=namespace).visit(rewritten_tree)
    return astor.to_source(annotated_tree)

@server.register_function
def parse_namespaces(*code):
    """Parse namespaces out of the code

    Args:
        s (str): the code
        module (str): the module name, used as the prefix of every namespace

    code = [s, module]

    Only top-level functions and methods defined directly inside top-level
    classes are reported; nested definitions are not. Every namespace string
    is prefixed with ``ns=``.

    Returns:
        namespaces (list): a list of 3-tuples = [
            (namespace, start-line, end-line),
            (namespace, start-line, end-line),
            .
            .
            .
            (namespace, start-line, end-line),
        ]

        Methods come first, then functions, each in reverse source order;
        the module entry ``('ns=<module>', -1, -1)`` is last. Line numbers
        are 1-based. The end line is the last line of the definition; on
        interpreters without ``end_lineno`` (before Python 3.8) it is the
        first line of the last body statement.

    >>> s = '''
    ...
    ... x
    ... class Foo:
    ...     def bar():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...     def biz():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...
    ... class Qux:
    ...     def quux():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...     def quuux():
    ...         \"\"\"function\"\"\"
    ...         pass
    ... y
    ...
    ... '''
    >>> code = [s, 'ast-server']

    """
    source, module_name = code[0], code[1]
    tree = ast.parse(source)

    def line_span(node):
        # body[-1].lineno is only the FIRST line of the last statement, so a
        # definition ending in a multi-line for/if/call would be truncated;
        # end_lineno (Python 3.8+) is the true last line of the definition.
        end = getattr(node, 'end_lineno', None)
        if end is None:
            end = node.body[-1].lineno
        return node.lineno, end

    funcs = [stmt for stmt in tree.body if isinstance(stmt, ast.FunctionDef)]
    methods = [
        (classdef, stmt)
        for classdef in tree.body if isinstance(classdef, ast.ClassDef)
        for stmt in classdef.body if isinstance(stmt, ast.FunctionDef)
    ]

    namespaces = [(f'ns={module_name}', -1, -1)]
    for func in funcs:
        start, end = line_span(func)
        namespaces.append((f'ns={module_name}.{func.name}', start, end))
    for classdef, method in methods:
        start, end = line_span(method)
        namespaces.append((f'ns={module_name}.{classdef.name}.{method.name}', start, end))
    namespaces.reverse()
    return namespaces

@server.register_function
def embed(*code):
    """Replace one function or method with a call to `IPython.embed()`.

    The arguments arrive positionally as ``code = [s, ns]``:

        s (str): the source code to rewrite
        ns (str): the namespace naming the target, one of

            - <module-name>.<func-name>
            - <module-name>.<class-name>.<method-name>

    The rewrite itself is delegated to ``IPythonEmbedder(ns)``; the resulting
    tree is turned back into source with ``astor.to_source`` and returned.

    >>> s = '''
    ...
    ... x
    ... class Foo:
    ...     def bar():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...     def biz():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...
    ... class Qux:
    ...     def quux():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...     def quuux():
    ...         \"\"\"function\"\"\"
    ...         pass
    ... y
    ...
    ... '''
    >>>
    >>> namespace = 'ast_server.Qux.quux'
    >>> code = [s, namespace]

    """
    source = code[0]
    target_ns = code[1]
    module_tree = ast.parse(source)
    embedder = IPythonEmbedder(target_ns)
    return astor.to_source(embedder.visit(module_tree))

def find_namespace(code, func_name, lineno):
    """Compute the fully qualified namespace of `func_name` at `lineno` from `code`

    Args:
        code (str): the code
        func_name (str): the function name
        lineno (int): the line that `func_name` is defined at in `code`

    Returns:
        A namespace string, or None if LineNumberFinder reports no match

        - <func-name>
        - <class-name>.<method-name>

    Raises:
        SyntaxError: if `code` cannot be parsed.

    >>> code = '''
    ...
    ... x
    ... class Foo:
    ...     def bar():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...     def biz():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...
    ... class Qux:
    ...     def bar():
    ...         \"\"\"function\"\"\"
    ...         pass
    ...     def biz():
    ...         \"\"\"function\"\"\"
    ...         pass
    ... y
    ...
    ... '''
    >>>
    >>> func_name = 'bar'
    >>> lineno = 5

    """
    # Parse outside the try so that a SyntaxError in `code` propagates instead
    # of being mistaken for a result by the except clause below.
    tree = ast.parse(code)
    try:
        LineNumberFinder(func_name, lineno).visit(tree)
    except Exception as e:
        # LineNumberFinder is expected to report a match by raising an exception
        # whose single argument is the namespace.
        # NOTE(review): `except Exception` also catches unrelated errors raised
        # during the visit; narrow it to the type LineNumberFinder raises.
        if len(e.args) != 1:
            raise
        namespace, = e.args
        return namespace
    return None


if __name__ == '__main__':
    # The server was bound to port 0 (an OS-chosen free port); print the bound
    # port so the launching client can connect, then block serving requests.
    server.print_port()
    server.serve_forever()

if __name__ == '__test__':
    # Manual smoke test for annotate(). It only runs when this file is executed
    # under the module name '__test__', e.g.
    # runpy.run_path('<this file>', run_name='__test__').
    import textwrap

    # dedent(): the snippet is indented to sit inside this `if`, but
    # ast.parse() rejects source whose first statement is indented.
    code = textwrap.dedent('''

    def foo(a):
        """This is a docstring

        >>> a = 7

        """
        for i in range(a):
            print(i)

    ''')
    active_funcname = 'foo'
    # annotate() needs '<module-name>.<func-name>' to select a function; a bare
    # 'foo' would be read as a module name and keep only top-level statements,
    # dropping foo. 'ast_server' is a placeholder: annotate() does not use the
    # module part to select code.
    annotated_code = annotate(code, f'ast_server.{active_funcname}')
    print(annotated_code)
