#!/usr/bin/env python

import argparse
import asyncio
import json
import os
import sys

from argparse import ArgumentParser
from loguru import logger
from termcolor import colored
from prettytable import PrettyTable

from bittensor.balance import Balance
from bittensor.crypto import encrypt, is_encrypted, decrypt_data, KeyError
from bittensor.crypto.keyfiles import load_keypair_from_data, KeyFileError
from bittensor.subtensor.client import WSClient
from bittensor.subtensor.interface import Keypair
from bittensor.utils import Cli
from bittensor.utils.cli_utils import cli_utils
from bittensor.subtensor.client import Neuron, Neurons


class CommandExecutor:
    """Implements the commands offered by the bittensor command line interface.

    The wallet commands (``new_wallet``, ``new_hotkey``, ``regen_wallet``) work
    on key files through ``cli_utils`` and never read the keypair or client.
    The chain commands (``overview``, ``stake``, ``unstake``, ``unstake_all``,
    ``transfer``) use the coldkey keypair and the chain client handed to the
    constructor.
    """
    __keypair : Keypair
    __client : WSClient

    def __init__(self, keypair : Keypair, client : WSClient):
        """Store the coldkey keypair and the chain client used by the chain commands.

        Both may be ``None`` when only the wallet commands are going to be run,
        since those do not touch either of them.
        """
        self.__keypair = keypair
        self.__client = client

    @staticmethod
    def _abort(message: str):
        """Print ``message`` in red and terminate the program with exit status 1.

        ``sys.exit`` is used instead of the interactive ``quit()`` helper so the
        failure is visible to the calling shell through a non-zero status.
        """
        print(colored(message, 'red'))
        sys.exit(1)

    @staticmethod
    def _ask_hotkey_path(wallet_path: str) -> str:
        """Ask whether to use the 'default' hotkey name and return the hotkey file path."""
        choice = input("Create 'default' hotkey ? (y/N) ")
        if choice == "y":
            return wallet_path + "/hotkeys/default"
        hotkey_name = input("Hotkey name: ")
        return wallet_path + "/hotkeys/" + hotkey_name

    @staticmethod
    def _create_hotkey(hotkey_path: str, n_words: int):
        """Generate a hotkey from a new ``n_words`` mnemonic and save it, unencrypted, at ``hotkey_path``."""
        hotkey_path = cli_utils.validate_create_path( hotkey_path )
        hotkey_keypair = cli_utils.gen_new_key( n_words )
        hotkey_data = json.dumps(hotkey_keypair.toDict()).encode()
        cli_utils.save_keys(hotkey_path, hotkey_data)
        cli_utils.set_file_permissions(hotkey_path)

    async def regen_wallet( self, wallet_path: str, wallet_name: str, mnemonic: str):
        """Rebuild a wallet's coldkey from ``mnemonic`` and create a new hotkey for it.

        Args:
            wallet_path: Directory holding all wallets; ``~`` is expanded.
            wallet_name: Name of the wallet directory; prompted for when ``None``.
            mnemonic: Mnemonic handed to ``cli_utils.validate_generate_mnemonic``
                to derive the coldkey.

        The coldkey is always encrypted with a password read through
        ``cli_utils.input_password``. The hotkey is newly generated (it is not
        derived from ``mnemonic``) and stored unencrypted.
        """
        if wallet_name is None:
            wallet_name = input("Wallet name: ")

        # expand wallet path.
        wallet_path = os.path.expanduser(wallet_path + '/' + wallet_name)
        os.makedirs(wallet_path, exist_ok=True)

        coldkey_path = wallet_path + "/coldkey"
        coldkey_path = cli_utils.validate_create_path( coldkey_path )
        coldkey_keypair = cli_utils.validate_generate_mnemonic( mnemonic )
        cli_utils.write_pubkey_to_text_file(coldkey_path, coldkey_keypair.public_key )

        password = cli_utils.input_password()
        print("Encrypting key ... (this might take a few moments)")
        json_data = json.dumps(coldkey_keypair.toDict()).encode()
        coldkey_data = encrypt(json_data, password)
        # Drop the reference to the plain-text key material once it is encrypted.
        del json_data

        # Written exactly once; a second identical write adds nothing.
        cli_utils.save_keys(coldkey_path, coldkey_data)
        cli_utils.set_file_permissions(coldkey_path)

        # Create the hotkey.
        os.makedirs(wallet_path + '/hotkeys/', exist_ok=True)
        cli_utils.create_hotkeys_dir_if_not_exists(wallet_path + "/hotkeys")
        hotkey_path = self._ask_hotkey_path(wallet_path)
        # 12 words: the mnemonic length must be one of 12/15/18/21/24 (the
        # values the CLI accepts for --n_words); 14 is not a valid length.
        self._create_hotkey(hotkey_path, 12)

    async def new_wallet( self, n_words:int, wallet_path: str, wallet_name:str,  use_password: bool):
        """Create a new wallet made of a coldkey and one hotkey.

        Args:
            n_words: Number of mnemonic words used for both generated keys.
            wallet_path: Directory holding all wallets; ``~`` is expanded.
            wallet_name: Name of the wallet directory; when ``None`` the user is
                asked whether to use 'default' or to type a name.
            use_password: When true the coldkey is encrypted with a password read
                through ``cli_utils.input_password``; otherwise it is stored as
                plain JSON. The hotkey is always stored unencrypted.
        """
        if wallet_name is None:
            choice = input("Create 'default' wallet ? (y/N) ")
            if choice == "y":
                wallet_name = 'default'
            else:
                wallet_name = input("Wallet name: ")

        # expand wallet path.
        wallet_path = os.path.expanduser(wallet_path + '/' + wallet_name)
        # Creates the wallet directory and its hotkeys sub-directory in one go.
        os.makedirs(wallet_path + '/hotkeys/', exist_ok=True)

        # Create coldkey.
        coldkey_path = wallet_path + "/coldkey"
        coldkey_path = cli_utils.validate_create_path( coldkey_path )
        coldkey_keypair = cli_utils.gen_new_key( n_words )
        cli_utils.display_mnemonic_msg( coldkey_keypair )
        cli_utils.write_pubkey_to_text_file(coldkey_path, coldkey_keypair.public_key )
        if use_password:
            password = cli_utils.input_password()
            print("Encrypting coldkey ... (this might take a few moments)")
            coldkey_json_data = json.dumps(coldkey_keypair.toDict()).encode()
            coldkey_data = encrypt(coldkey_json_data, password)
            del coldkey_json_data
        else:
            coldkey_data = json.dumps(coldkey_keypair.toDict()).encode()
        cli_utils.save_keys(coldkey_path, coldkey_data)
        cli_utils.set_file_permissions(coldkey_path)

        # Create the hotkey.
        cli_utils.create_hotkeys_dir_if_not_exists(wallet_path + "/hotkeys")
        hotkey_path = self._ask_hotkey_path(wallet_path)
        self._create_hotkey(hotkey_path, n_words)

    async def new_hotkey( self, n_words:int, wallet_path: str, wallet_name: str, hotkey_name: str):
        """Create an additional, unencrypted hotkey inside a wallet.

        Args:
            n_words: Number of mnemonic words used to generate the hotkey.
            wallet_path: Directory holding all wallets; ``~`` is expanded.
            wallet_name: Wallet to store the hotkey in; when ``None`` the user is
                asked whether to use 'default' or to type a name.
            hotkey_name: File name of the hotkey; when ``None`` the user is asked
                whether to use 'default' or to type a name.
        """
        if wallet_name is None:
            choice = input("Store in 'default' wallet ? (y/N) ")
            if choice == "y":
                wallet_name = 'default'
            else:
                wallet_name = input("Wallet name: ")

        if hotkey_name is None:
            choice = input("Create 'default' hotkey ? (y/N) ")
            if choice == "y":
                hotkey_name = 'default'
            else:
                hotkey_name = input("Hotkey name: ")

        # expand wallet path.
        wallet_path = os.path.expanduser(wallet_path + '/' + wallet_name)
        hotkeys_dir_path = wallet_path + '/hotkeys/'
        # Creates the wallet directory too when it does not exist yet.
        os.makedirs(hotkeys_dir_path, exist_ok=True)

        # Honour the requested mnemonic length instead of a fixed 12 words.
        self._create_hotkey(hotkeys_dir_path + hotkey_name, n_words)

    async def connect(self):
        """Start the chain client's connection and await its ``is_connected`` check."""
        self.__client.connect()
        await self.__client.is_connected()

    async def _associated_neurons(self) -> Neurons:
        """Return the neurons whose coldkey equals this executor's public key, with ``stake`` filled in."""
        pubkey = self.__keypair.public_key
        print(colored("Retrieving all nodes associated with cold key : {}".format(pubkey), 'white'))
        neurons = await self.__client.neurons(decorator=True)
        # Keep only the neurons registered under the provided cold key.
        associated_neurons = Neurons(neuron for neuron in neurons if neuron.coldkey == pubkey)
        # Load stakes
        for neuron in associated_neurons:
            neuron.stake = await self.__client.get_stake_for_uid(neuron.uid)
        return associated_neurons

    async def overview(self):
        """Print the coldkey balance and a UID / IP / STAKE table of its associated neurons."""
        await self.connect()
        balance = await self.__client.get_balance(self.__keypair.ss58_address)
        neurons = await self._associated_neurons()

        print("BALANCE: %s : [%s]" % (self.__keypair.ss58_address, balance))
        print()
        print("--===[[ STAKES ]]===--")
        t = PrettyTable(["UID", "IP", "STAKE"])
        t.align = 'l'
        for neuron in neurons:
            t.add_row([neuron.uid, neuron.ip, neuron.stake])
        print(t.get_string())

    async def unstake_all(self):
        """Unstake the full stake of every neuron associated with the coldkey."""
        await self.connect()
        neurons = await self._associated_neurons()
        for neuron in neurons:
            # Re-read the stake right before unstaking it.
            neuron.stake = await self.__client.get_stake_for_uid(neuron.uid)
            await self.__client.unstake(neuron.stake, neuron.hotkey)
            print(colored("Unstaked: {} Tao from uid: {} to coldkey.pub: {}".format(neuron.stake, neuron.uid, self.__keypair.public_key), 'green'))

    async def unstake(self, uid, amount: Balance):
        """Unstake ``amount`` from the neuron ``uid``.

        Raises:
            SystemExit: With status 1 when ``uid`` is not associated with the
                coldkey or when ``amount`` exceeds the neuron's stake.
        """
        await self.connect()
        neurons = await self._associated_neurons()
        neuron = neurons.get_by_uid(uid)
        if not neuron:
            self._abort("Neuron with uid: {} is not associated with coldkey.pub: {}".format(uid, self.__keypair.public_key))

        neuron.stake = await self.__client.get_stake_for_uid(uid)
        if amount > neuron.stake:
            self._abort("Neuron with uid: {} does not have enough stake ({}) to be able to unstake {}".format(uid, neuron.stake, amount))

        await self.__client.unstake(amount, neuron.hotkey)
        print(colored("Unstaked:{} from uid:{} to coldkey.pub:{}".format(amount.tao, uid, self.__keypair.public_key), 'green'))


    async def stake(self, uid, amount: Balance):
        """Stake ``amount`` from the coldkey balance to the neuron ``uid``.

        Raises:
            SystemExit: With status 1 when the balance is lower than ``amount``
                or when ``uid`` is not associated with the coldkey.
        """
        await self.connect()
        balance = await self.__client.get_balance(self.__keypair.ss58_address)
        if balance < amount:
            self._abort("Not enough balance ({}) to stake {}".format(balance, amount))

        neurons = await self._associated_neurons()
        neuron = neurons.get_by_uid(uid)
        if not neuron:
            self._abort("Neuron with uid: {} is not associated with coldkey.pub: {}".format(uid, self.__keypair.public_key))

        await self.__client.add_stake(amount, neuron.hotkey)
        print(colored("Staked: {} Tao to uid: {} from coldkey.pub: {}".format(amount.tao, uid, self.__keypair.public_key), 'green'))

    async def transfer(self, dest: str, amount: Balance):
        """Transfer ``amount`` from the coldkey account to ``dest``.

        Raises:
            SystemExit: With status 1 when the balance is lower than ``amount``.
        """
        await self.connect()
        balance = await self.__client.get_balance(self.__keypair.ss58_address)
        if balance < amount:
            self._abort("Not enough balance ({}) to transfer {}".format(balance, amount))

        await self.__client.transfer(dest, amount)
        print(colored("Transfered: {} Tao to dest: {} from coldkey.pub: {}".format(amount.tao, dest, self.__keypair.public_key), 'green'))



'''
Functions :
- Generate cold key
- View balance
- View hotkeys associated with the supplied cold key
- Stake funds into hotkey ( one by one / amount divided equally over keys)
- Unstake funds into coldkey (one by one / withdraw all)
'''
def _require_arg(is_present, flag):
    """Abort when an option needed by the selected command was not supplied.

    Args:
        is_present: Truthy when the option was given.
        flag: The option as typed on the command line (e.g. ``--uid``); used in
            the error message.

    Raises:
        SystemExit: With status 1 when ``is_present`` is falsy, so the failure
            is reported to the calling shell.
    """
    if not is_present:
        print(colored("The {} argument is required for this command".format(flag), 'red'))
        sys.exit(1)


def _chain_executor(args):
    """Load the coldkey of the selected wallet and return a CommandExecutor for ``args.chainendpoint``."""
    wallet_name = cli_utils.validate_wallet_name(args.wallet_name)
    coldkey_path = args.wallet_path + "/" + wallet_name + '/coldkey'
    cli_utils.validate_path(coldkey_path)
    keypair = cli_utils.load_key(coldkey_path)
    client = cli_utils.get_client(args.chainendpoint, keypair)
    return CommandExecutor(keypair, client)


def run(args):
    """Execute the command selected on the command line.

    Args:
        args: Parsed command-line namespace. ``args.command`` selects the
            command; the remaining attributes are the options of that command.

    Raises:
        SystemExit: With status 1 when an option required by the command is
            missing.

    A command name that matches none of the branches is ignored.
    """
    cli_utils.enable_debug(args.debug)
    cli_utils.create_dirs()
    loop = asyncio.get_event_loop()

    # --- New wallet command (needs neither keypair nor chain client)
    if args.command == "new_wallet":
        executor = CommandExecutor(None, None)
        loop.run_until_complete( executor.new_wallet( n_words = args.n_words, wallet_path = args.wallet_path, wallet_name = args.wallet_name, use_password = args.use_password) )

    # --- New hotkey command
    elif args.command == "new_hotkey":
        executor = CommandExecutor(None, None)
        loop.run_until_complete( executor.new_hotkey( n_words = args.n_words, wallet_path = args.wallet_path, wallet_name = args.wallet_name, hotkey_name = args.hotkey_name) )

    # --- Regen command. 'regen_key' is accepted as well because the CLI has
    # exposed the regeneration options under that name; without it that
    # command would parse successfully and then silently do nothing.
    elif args.command in ("regen_wallet", "regen_key"):
        executor = CommandExecutor(None, None)
        loop.run_until_complete( executor.regen_wallet( wallet_path = args.wallet_path, wallet_name = args.wallet_name, mnemonic = args.mnemonic ) )

    # --- Transfer command
    elif args.command == "transfer":
        _require_arg(args.chainendpoint, "--chain-endpoint")
        _require_arg(args.dest, "--dest")
        _require_arg(args.amount, "--amount")
        executor = _chain_executor(args)
        amount = Balance.from_float(args.amount)
        loop.run_until_complete( executor.transfer( dest = args.dest, amount = amount ) )

    # --- Overview command
    elif args.command == "overview":
        _require_arg(args.chainendpoint, "--chain-endpoint")
        executor = _chain_executor(args)
        loop.run_until_complete( executor.overview( ) )

    # --- Unstake command
    elif args.command == "unstake":
        _require_arg(args.chainendpoint, "--chain-endpoint")
        # Check the options before the coldkey is loaded, so a missing option
        # is reported without first going through the key-loading step.
        if not args.unstake_all:
            # uid 0 is a valid uid, hence the explicit None test.
            _require_arg(args.uid is not None, "--uid")
            _require_arg(args.amount, "--amount")

        executor = _chain_executor(args)

        if args.unstake_all:
            confirm = input("This will remove all stake from associated neurons, and transfer the balance in the account associated with the cold key. Continue? (y/N) ")
            if confirm not in ("Y", "y"):
                # Declined by the user: nothing to do and not an error.
                return
            loop.run_until_complete(executor.unstake_all())
            return

        amount = Balance.from_float(args.amount)
        loop.run_until_complete(executor.unstake(args.uid, amount))

    # --- Stake command
    elif args.command == "stake":
        _require_arg(args.uid is not None, "--uid")
        _require_arg(args.amount is not None, "--amount")
        _require_arg(args.chainendpoint, "--chain-endpoint")
        executor = _chain_executor(args)

        amount = Balance.from_float(args.amount)
        loop.run_until_complete(executor.stake(args.uid, amount))


def _add_wallets_path_option(cmd_parser):
    """Add the ``--wallets_path`` option (stored as ``wallet_path``) to ``cmd_parser``."""
    cmd_parser.add_argument('--wallets_path', dest="wallet_path", default='~/.bittensor/wallets/', type=str,
        help='''Path to your wallets directory (default: ~/.bittensor/wallets/)''')


def _add_chain_endpoint_option(cmd_parser):
    """Add the ``--chain-endpoint`` option (stored as ``chainendpoint``) to ``cmd_parser``."""
    cmd_parser.add_argument("--chain-endpoint", dest="chainendpoint", default="feynman.akira.bittensor.com:9944", required=False,
        help="The endpoint to the subtensor chain <hostname/ip>:<port>")


def _add_n_words_option(cmd_parser):
    """Add the ``--n_words`` mnemonic-length option to ``cmd_parser``."""
    cmd_parser.add_argument('--n_words', type=int, choices=[12,15,18,21,24], default=12,
        help='''The number of words representing the mnemonic. i.e. horse cart dog ... x 24''')


def main(argv=None):
    """Parse the command line and run the selected command.

    Args:
        argv: Arguments to parse, without the program name. Defaults to
            ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 0 after printing the help when no arguments
            are given; argparse also exits on ``--help`` and on invalid
            arguments.
    """
    if argv is None:
        argv = sys.argv[1:]

    # Build top level parser.
    parser = ArgumentParser(description="Bittensor cli", usage="bittensor-cli <args> <command>", add_help=True)
    parser._positionals.title = "commands"
    parser.add_argument("--debug", default=False, help="Turn on debugging information", action="store_true")

    # Add subparsers. Every command is registered exactly once: registering
    # the same name twice is rejected by newer argparse versions and silently
    # replaces the first parser on older ones.
    cmd_parsers = parser.add_subparsers(dest='command')
    overview_parser = cmd_parsers.add_parser('overview',
        help='''Show account overview.''')
    transfer_parser = cmd_parsers.add_parser('transfer',
        help='''Transfer Tao between accounts.''')
    unstake_parser = cmd_parsers.add_parser('unstake',
        help='''Unstake from hotkey accounts.''')
    stake_parser = cmd_parsers.add_parser('stake',
        help='''Stake to your hotkey accounts.''')
    # 'regen_key' is kept as an alias so the older command name still works.
    regen_wallet_parser = cmd_parsers.add_parser('regen_wallet', aliases=['regen_key'],
        help='''Regenerates a wallet from a passed mnemonic''')
    wallet_parser = cmd_parsers.add_parser('new_wallet',
        help='''Creates a new wallet containing your coldkey (for containing balance)
            and a default hotkey (for running your node). The default wallet is stored 
            in ~/.bittensor/wallets/default/. 
            ''')
    new_hotkey_parser = cmd_parsers.add_parser('new_hotkey',
        help='''Creates a new hotkey (for running a node) under the specified wallet path.
            The created hotkey is stored in plain text under ~/.bittensor/wallets/$wallet_name$/hotkeys/$hotkey_name$. 
            ''')

    # Fill arguments for the regen command.
    regen_wallet_parser.add_argument("--mnemonic", required=True, nargs="+",
        help='Mnemonic used to regen your key i.e. horse cart dog ...')
    _add_wallets_path_option(regen_wallet_parser)
    regen_wallet_parser.add_argument('--wallet', dest="wallet_name", default=None, type=str,
        help='''The wallet where the new hotkey will be stored. Will overwrite previous wallet if it is already in place.''')

    # Fill arguments for the new hotkey command.
    _add_n_words_option(new_hotkey_parser)
    _add_wallets_path_option(new_hotkey_parser)
    new_hotkey_parser.add_argument('--wallet', dest="wallet_name", default=None, type=str,
        help='''The wallet where the new hotkey will be stored (by default: ~/.bittensor/wallets/default)''')
    new_hotkey_parser.add_argument('--name', dest="hotkey_name", default=None, type=str,
        help='''The name of the hotkey which allows you to distinguish between the node uid you are running with bittensor.py --wallet $wallet_name$ --hotkey $hotkey_name$''')

    # Fill arguments for new wallet command
    _add_n_words_option(wallet_parser)
    wallet_parser.add_argument('--use_password', dest='use_password', action='store_true', help='''Set protect the generated bittensor key with a password.''')
    wallet_parser.add_argument('--no_password', dest='use_password', action='store_false', help='''Set off protects the generated bittensor key with a password.''')
    wallet_parser.set_defaults(use_password=True)
    _add_wallets_path_option(wallet_parser)
    wallet_parser.add_argument('--name',  dest="wallet_name", default=None, type=str,
        help='''The name of the wallet, (default: 'default' and stores in ~/.bittensor/wallets/default)''')

    # Fill arguments for unstake command.
    unstake_parser.add_argument('--all', dest="unstake_all", action='store_true')
    unstake_parser.add_argument('--uid', dest="uid", type=int, required=False)
    unstake_parser.add_argument('--amount', dest="amount", type=float, required=False)
    _add_chain_endpoint_option(unstake_parser)
    _add_wallets_path_option(unstake_parser)
    unstake_parser.add_argument("--wallet", dest="wallet_name", default='default', required=False,
        help="Path to the wallet which holds the hotkey-account you are unstaking to.")

    # Fill arguments for stake command.
    stake_parser.add_argument('--uid', dest="uid", type=int, required=False)
    stake_parser.add_argument('--amount', dest="amount", type=float, required=False)
    _add_chain_endpoint_option(stake_parser)
    _add_wallets_path_option(stake_parser)
    stake_parser.add_argument("--wallet", dest="wallet_name", default='default', required=False,
        help="Path to the wallet which holds the hotkey-account you are staking to.")

    # Fill arguments for overview command.
    _add_chain_endpoint_option(overview_parser)
    _add_wallets_path_option(overview_parser)
    overview_parser.add_argument("--wallet", dest="wallet_name", default='default', required=False,
        help="Path to your wallet")

    # Fill arguments for transfer
    transfer_parser.add_argument('--dest', dest="dest", type=str, required=True)
    transfer_parser.add_argument('--amount', dest="amount", type=float, required=True)
    _add_chain_endpoint_option(transfer_parser)
    _add_wallets_path_option(transfer_parser)
    transfer_parser.add_argument("--wallet", dest="wallet_name", default='default', required=False,
        help="Path to the wallet which holds the tokens for making the transfer.")

    # Hack to print formatted help when called without any argument.
    if not argv:
        parser.print_help()
        sys.exit(0)

    # Load args and run the script.
    args = parser.parse_args(argv)
    # argparse stores the name exactly as typed; map the alias onto the
    # canonical command name that run() dispatches on.
    if args.command == 'regen_key':
        args.command = 'regen_wallet'
    run(args)
 
# Script entry point: start the CLI only when this file is executed directly.
if __name__ == "__main__":
    main()
