#!/usr/bin/env python3.6
import argparse
import vyper
from collections import Counter
from pprint import pprint
from vyper import compiler
from vyper.parser import (
    parser,
)
from vdb.vdb import (
    set_evm_opcode_debugger,
    set_evm_opcode_pass
)
from vdb.source_map import (
    produce_source_map
)
from eth_tester import (
    EthereumTester,
)
from web3.providers.eth_tester import (
    EthereumTesterProvider,
)
from web3 import (
    Web3,
)

# Command-line interface: a contract source file, a ';'-separated list of
# calls to make, and optional comma-separated constructor arguments.
aparser = argparse.ArgumentParser(
    description='Vyper {0} quick CLI runner'.format(vyper.__version__),
)
aparser.add_argument(
    'input_file',
    help='Vyper sourcecode to run',
)
aparser.add_argument(
    'call_list',
    help='call list, without parameters: func, with parameters func(1, 2, 3). Semicolon separated',
)
aparser.add_argument(
    '-i',
    dest='init_args',
    default=None,
    help='init args, comma separated',
)

args = aparser.parse_args()
# Until a debugging session is explicitly started, the debug opcode is
# simply passed over.
set_evm_opcode_pass()


def cast_types(args, abi_signature):
    """Convert raw CLI argument strings to values matching an ABI signature.

    An argument may carry a ``value:type`` annotation, the form
    ``get_func_abi`` reads to select an overloaded function. When the
    annotation equals the ABI type at that position it is stripped before
    conversion, so ``'5:int128'`` becomes ``5`` instead of failing in
    ``int()``, and ``'x:bytes'`` becomes ``b'x'``.

    :param args: list of arguments, positionally aligned with the ABI inputs.
    :param abi_signature: ABI entry (dict) whose ``'inputs'`` value is a list
        of dicts that each have a ``'type'`` key.
    :returns: a new list (``args`` is left untouched) in which ``int128`` /
        ``uint256`` positions are ``int``, ``bytes*`` positions are
        ``bytes``, and every other position is passed through.
    :raises IndexError: if ``args`` has fewer entries than the ABI inputs.
    :raises ValueError: if an integer position is not a valid integer literal.
    """
    newargs = args.copy()
    for idx, abi_arg in enumerate(abi_signature['inputs']):
        abi_type = abi_arg['type']
        raw = args[idx]

        # Drop a trailing ':<type>' only when it names this position's ABI
        # type, so values that merely contain a colon are left alone.
        if isinstance(raw, str):
            value, sep, annotation = raw.rpartition(':')
            if sep and annotation == abi_type:
                raw = value
                newargs[idx] = raw

        if abi_type in ('int128', 'uint256'):
            newargs[idx] = int(raw)
        elif abi_type.startswith('bytes'):
            newargs[idx] = raw.encode()
    return newargs


def get_tester():
    """Create an EthereumTester backend and a Web3 instance connected to it.

    The Web3 instance talks to the tester through EthereumTesterProvider and
    is configured with a gas price strategy that always returns 0.

    :returns: tuple ``(tester, w3)``.
    """
    def zero_gas_price_strategy(web3, transaction_params=None):
        # A zero gas price makes testing simpler.
        return 0

    eth_tester = EthereumTester()
    provider = EthereumTesterProvider(eth_tester)
    w3 = Web3(provider)
    w3.eth.setGasPriceStrategy(zero_gas_price_strategy)
    return eth_tester, w3


def get_contract(w3, source_code, *args, **kwargs):
    """Compile ``source_code``, deploy it through ``w3`` and return the contract.

    Positional ``args`` and any remaining ``kwargs`` are encoded as
    constructor data. Three keyword arguments are consumed here and never
    reach the constructor:

    * ``value`` - value sent with the deployment (default 0).
    * ``value_in_eth`` - if truthy, overrides ``value`` with
      ``value_in_eth * 10**18``.
    * ``gasPrice`` - gas price of the deployment transaction (default 0).

    The deployment is sent from ``w3.eth.accounts[0]``. The returned contract
    object additionally carries a ``_logfilter`` attribute holding a filter
    on the contract address.
    """
    abi = compiler.mk_full_signature(source_code)
    bytecode = '0x' + compiler.compile(source_code).hex()
    factory = w3.eth.contract(abi=abi, bytecode=bytecode)

    # Pull the deployment-only options out before encoding constructor data.
    wei_value = kwargs.pop('value', 0)
    eth_value = kwargs.pop('value_in_eth', 0)
    if eth_value:
        # Handle deploying with an eth value.
        wei_value = eth_value * 10**18
    gas_price = kwargs.pop('gasPrice', 0)

    tx_hash = w3.eth.sendTransaction({
        'from': w3.eth.accounts[0],
        'data': factory._encode_constructor_data(args, kwargs),
        'value': wei_value,
        'gasPrice': gas_price
    })
    receipt = w3.eth.getTransactionReceipt(tx_hash)
    deployed = w3.eth.contract(receipt['contractAddress'], abi=abi, bytecode=bytecode)

    # Filter logs.
    log_filter_params = {
        'fromBlock': w3.eth.blockNumber - 1,
        'address': deployed.address
    }
    deployed._logfilter = w3.eth.filter(log_filter_params)
    return deployed


def get_func_abi(abi, func_name, args):
    """Find the ABI entry of the function that a CLI call refers to.

    A function whose name occurs exactly once in ``abi`` is matched by name
    alone; its entry is returned only when ``len(args)`` equals its number of
    inputs, otherwise a message is printed and ``None`` is returned. Every
    other function entry is matched on its full signature ``name(type,...)``,
    where each argument's type comes from a ``value:type`` annotation or is
    guessed: ``int128`` if the text parses as an integer, ``bytes`` otherwise.

    :param abi: list of ABI entries (dicts with ``'name'`` and ``'type'``;
        function entries also have ``'inputs'``).
    :param func_name: name of the function to look up.
    :param args: list of raw argument strings.
    :returns: the matching ABI dict, or ``None`` if there is no match.
    """
    def guess_type(raw):
        # An explicit annotation wins over guessing.
        if ':' in raw:
            return raw.split(':')[1]
        try:
            int(raw)
        except ValueError:
            return 'bytes'
        return 'int128'

    name_counts = Counter(entry['name'] for entry in abi)

    for entry in abi:
        if entry["type"] != "function":
            continue

        # A uniquely named function is identified by its name alone.
        if entry["name"] == func_name and name_counts[entry['name']] == 1:
            if len(args) == len(entry['inputs']):
                return entry
            print('Incorrect arguments for {}'.format(func_name))
            return None

        # Otherwise (e.g. an overloaded function) compare full signatures.
        requested = "{}({})".format(
            func_name,
            ','.join(guess_type(arg) for arg in args),
        )
        declared = "{}({})".format(
            entry['name'],
            ','.join(inp['type'] for inp in entry['inputs']),
        )
        if declared == requested:
            return entry

    return None


def parse_call_list(call_list):
    """Split a ';'-separated call list into ``(name, args)`` tuples.

    Each entry is either a bare function name (``func``) or a name followed
    by a parenthesised, comma-separated argument list (``func(1, 2, 3)``).
    Whitespace around entries, names and arguments is ignored, empty
    arguments are dropped, and empty entries (such as the one produced by a
    trailing ``;``) are skipped.

    :param call_list: the raw call list string from the command line.
    :returns: list of ``(function_name, [argument strings])`` tuples.
    """
    calls = []
    for signature in call_list.split(';'):
        # Strip first so the closing parenthesis really is the last
        # character, even when the entry is followed by whitespace.
        signature = signature.strip()
        if not signature:
            continue

        name, paren, arg_str = signature.partition('(')
        call_args = []
        if paren:
            if arg_str.endswith(')'):
                arg_str = arg_str[:-1]
            call_args = [arg.strip() for arg in arg_str.split(',')]
            call_args = [arg for arg in call_args if arg]

        calls.append((name.strip(), call_args))
    return calls


if __name__ == '__main__':

    # The file only needs to stay open for the read itself.
    with open(args.input_file) as fh:
        code = fh.read()

    init_args = args.init_args.split(',') if args.init_args else []
    tester, w3 = get_tester()

    # Build list of calls to make.
    calls = parse_call_list(args.call_list)

    # The ABI depends only on the source, so it is built once.
    abi = compiler.mk_full_signature(code)

    # Format init args.
    if init_args:
        init_abi = next(
            (entry for entry in abi if entry.get('name') == '__init__'),
            None,
        )
        if init_abi is None:
            raise SystemExit('Init args were given, but the contract has no __init__.')
        init_args = cast_types(init_args, init_abi)

    # Compile contract to chain.
    contract = get_contract(w3, code, *init_args)

    event_names = [x['name'] for x in abi if x['type'] == 'event']

    # Execute calls
    for func_name, call_args in calls:
        if not hasattr(contract.functions, func_name):
            print('\n No method {} found, skipping.'.format(func_name))
            continue

        print('\n* Calling {}({})'.format(func_name, ','.join(call_args)))

        func_abi = get_func_abi(abi, func_name, call_args)
        if not func_abi:
            print('Did not find function in abi.')
            break

        cast_args = cast_types(call_args, func_abi)
        contract_func = getattr(contract.functions, func_name)
        # Not every ABI entry is guaranteed to carry a gas estimate.
        gas_estimate = func_abi.get('gas', 0)

        # NOTE(review): the result of this first call is discarded;
        # presumably it is a pre-flight check before the debugged
        # transaction - confirm.
        contract_func(*cast_args).call({'gas': gas_estimate + 50000})

        # Patch in vdb for the duration of the transaction only.
        source_map = produce_source_map(code)
        set_evm_opcode_debugger(source_code=code, source_map=source_map)
        tx_hash = contract_func(*cast_args).transact({'gas': gas_estimate + 50000})
        set_evm_opcode_pass()

        print('- Returns:')
        res = contract_func(*cast_args).call({'gas': gas_estimate + 92000})
        pprint('{}'.format(res))

        # Detect any new log events, and print them.
        print('- Logs:')
        tx_receipt = w3.eth.getTransactionReceipt(tx_hash)
        found_events = False
        for event_name in event_names:
            logs = getattr(contract.events, event_name)().processReceipt(tx_receipt)
            for log in logs:
                found_events = True
                print(log.event + ":")
                pprint(dict(log.args))
        # Report the absence of events once, and only when nothing was printed.
        if not found_events:
            print(' No events found.')
