#!python
import os
import collections
import argparse
import sys
import logging
try:
    from ConfigParser import SafeConfigParser
except ImportError:
    import configparser
    SafeConfigParser = configparser.ConfigParser

from tmdeploy.utils import write_yaml_file, read_yaml_file, to_json
from tmdeploy.log import configure_logging, map_logging_verbosity
from tmdeploy.config import Setup
from tmdeploy.errors import (
    CloudError, NoInstanceFoundError, MultipleInstancesFoundError,
    SetupDescriptionError, SetupEnvironmentError
)
from tmdeploy.inventory import (
    CONFIG_DIR, GROUP_VARS_DIR, HOST_VARS_DIR, build_inventory
)

# Module-level logger, named after this script's file name.
logger = logging.getLogger(os.path.basename(__file__))


def _get_host_ip_os(region, network, host_name, host_vars):
    """Look up the address of an OpenStack instance by its name.

    Credentials are taken from the ``OS_AUTH_URL``, ``OS_USERNAME``,
    ``OS_PASSWORD`` and ``OS_PROJECT_NAME`` environment variables.

    Parameters
    ----------
    region: str
        region name passed to the compute client
    network: str
        name of the network whose addresses should be considered
    host_name: str
        name of the instance to search for
    host_vars: dict
        variables of the host (not used by this function)

    Returns
    -------
    str
        a floating address if the instance has several addresses on the
        network and one of them is floating, otherwise the first address

    Raises
    ------
    MultipleInstancesFoundError
        when more than one instance matches `host_name`
    NoInstanceFoundError
        when no instance matches `host_name`
    CloudError
        when the instance has no addresses for `network`
    """
    import novaclient.client

    nova = novaclient.client.Client(
        '2',
        username=os.getenv('OS_USERNAME'),
        password=os.getenv('OS_PASSWORD'),
        project_name=os.getenv('OS_PROJECT_NAME'),
        auth_url=os.getenv('OS_AUTH_URL'),
        region_name=region
    )
    servers = nova.servers.list(
        search_opts={'name': host_name}, detailed=True
    )
    n_servers = len(servers)
    if n_servers > 1:
        raise MultipleInstancesFoundError(
            'More than one instance found with name "{0}"'.format(host_name)
        )
    if n_servers == 0:
        raise NoInstanceFoundError(
            'No instance found with name "{0}"'.format(host_name)
        )

    try:
        candidates = servers[0].addresses[network]
    except KeyError:
        raise CloudError(
            'No addresses found for network "{0}"'.format(network)
        )

    n_candidates = len(candidates)
    if n_candidates == 0:
        raise CloudError('No address found for network "{0}"'.format(network))
    if n_candidates == 1:
        return candidates[0]['addr']

    # Several addresses on the same network: prefer a floating one.
    logger.debug(
        'more than one address found for host "%s" and network "%s"',
        host_name, network
    )
    logger.debug('try to use floating address (if existing)')
    floating = [
        entry['addr'] for entry in candidates
        if entry['OS-EXT-IPS:type'] == 'floating'
    ]
    if floating:
        logger.debug('floating address found')
        return floating[0]
    logger.debug('no floating address found')
    return candidates[0]['addr']


def _get_host_ip_ec2(region, network, host_name, host_vars):
    """Look up the public address of a running EC2 instance by its name.

    Credentials are taken from the ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables. Instances are matched
    on their ``Name`` tag and must be in state ``running``.

    Parameters
    ----------
    region: str
        name of the EC2 region
    network: str
        name of the network (not used by this function)
    host_name: str
        value of the ``Name`` tag of the instance
    host_vars: dict
        variables of the host (not used by this function)

    Returns
    -------
    public IP address of the first matching instance

    Raises
    ------
    CloudError
        when the EC2 resource cannot be created
    NoInstanceFoundError
        when no running instance matches `host_name`
    """
    import boto3
    try:
        cloud_client = boto3.resource(
            'ec2',
            region_name=region,
            aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
            aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
        )
    except Exception:
        # Deliberately not a bare "except": KeyboardInterrupt and SystemExit
        # must propagate instead of being reported as a cloud error.
        raise CloudError(
            'Could not connect to cloud provider in region "{0}".'.format(
                region
            )
        )
    instances = cloud_client.instances.filter(
        Filters=[
            {'Name': 'tag:Name', 'Values': [host_name]},
            {'Name': 'instance-state-name', 'Values': ['running']}
        ]
    )
    # NOTE(review): "public_ip_address" is presumably None for instances
    # without a public address; callers would then receive None - confirm.
    addresses = [inst.public_ip_address for inst in instances]
    if not addresses:
        raise NoInstanceFoundError(
            'No instance found with name "{0}"'.format(host_name)
        )
    if len(addresses) > 1:
        # Ambiguous name: not treated as an error, the first match is used.
        logger.debug('more than one instance found with name "%s"', host_name)
    return addresses[0]


def _get_host_ip_gce(region, network, host_name, host_vars):
    """Look up the NAT address of a Google Compute Engine instance by name.

    The project is taken from the ``GCE_PROJECT`` environment variable and
    the credentials are loaded from the file referenced by
    ``GCE_CREDENTIALS_FILE_PATH``.

    Parameters
    ----------
    region: str
        passed as ``zone`` to the instance listing request
    network: str
        name of the network interface whose addresses should be considered
    host_name: str
        name of the instance to search for
    host_vars: dict
        variables of the host (not used by this function)

    Returns
    -------
    ``natIP`` of the first access configuration of the first matching
    network interface

    Raises
    ------
    CloudError
        when the listing request fails, when the instance has no network
        interface named `network` or when that interface has no NAT address
    MultipleInstancesFoundError
        when more than one instance matches `host_name`
    NoInstanceFoundError
        when no instance matches `host_name`
    """
    import googleapiclient.discovery
    from oauth2client.client import GoogleCredentials
    project = os.getenv('GCE_PROJECT')
    credentials_file = os.getenv('GCE_CREDENTIALS_FILE_PATH')
    credentials = GoogleCredentials.from_stream(credentials_file)
    cloud_client = googleapiclient.discovery.build(
        'compute', 'v1', credentials=credentials
    )
    try:
        response = cloud_client.instances().list(
            project=project, zone=region, filter='name eq {0}'.format(host_name)
        ).execute()
    except Exception:
        # Deliberately not a bare "except": KeyboardInterrupt and SystemExit
        # must propagate instead of being reported as a cloud error.
        raise CloudError(
            'Could not connect to cloud provider for '
            'project "{0}" in region "{1}".'.format(project, region)
        )
    # The response carries no "items" key when nothing matched; treat that
    # as an empty result so that NoInstanceFoundError is raised below
    # instead of a KeyError.
    instances = response.get('items', [])
    if len(instances) > 1:
        raise MultipleInstancesFoundError(
            'More than one instance found with name "{0}"'.format(host_name)
        )
    elif len(instances) == 0:
        raise NoInstanceFoundError(
            'No instance found with name "{0}"'.format(host_name)
        )
    networks = [
        net for net in instances[0]['networkInterfaces']
        if net['name'] == network
    ]
    if len(networks) == 0:
        raise CloudError('No network with name "{0}".'.format(network))
    elif len(networks) > 1:
        logger.debug(
            'more than one address found for host "%s" and network "%s"',
            host_name, network
        )
    # Interfaces without access configuration and access configurations
    # without NAT address are skipped, such that a missing external address
    # is reported as CloudError below rather than as a KeyError.
    addresses = [
        access['natIP'] for access in networks[0].get('accessConfigs', [])
        if 'natIP' in access
    ]
    if len(addresses) > 1:
        logger.debug(
            'more than one address found for host "%s" and network "%s"',
            host_name, network
        )
    elif len(addresses) == 0:
        raise CloudError('No address found for network "{0}"'.format(network))
    return addresses[0]


def get_host_ip(provider, region, network, host_name, host_vars):
    """Obtain the IP address of a host from the given cloud provider.

    Parameters
    ----------
    provider: str
        cloud provider: ``"os"``, ``"ec2"`` or ``"gce"``
    region: str
        region handed to the provider-specific lookup
    network: str
        network handed to the provider-specific lookup
    host_name: str
        name of the host
    host_vars: dict
        variables of the host

    Returns
    -------
    str
        string representation of the address returned by the lookup

    Note
    ----
    Logs an error and terminates the process with exit code 1 when
    `provider` is not one of the supported values.
    """
    # Pick the provider-specific lookup first, then call it in one place.
    if provider == 'os':
        resolve = _get_host_ip_os
    elif provider == 'ec2':
        resolve = _get_host_ip_ec2
    elif provider == 'gce':
        resolve = _get_host_ip_gce
    else:
        logger.error('unsupported cloud provider: %s', provider)
        sys.exit(1)
    return str(resolve(region, network, host_name, host_vars))


def update_inventory(inventory, setup):
    """Add connection variables to the inventory and sync the address cache.

    For every group other than ``"_meta"`` and ``"all"``, group variables
    are added for groups whose name starts with ``"ganglia"`` or ``"slurm"``
    and the variables of each host are extended by ``ansible_user``,
    ``ansible_ssh_private_key_file`` and (when available) ``ansible_host``.
    The address of a host is either obtained from the cloud and written to
    the host's file in ``HOST_VARS_DIR`` (when refreshing) or read back from
    that file. Finally, files in ``HOST_VARS_DIR`` that belong to hosts which
    are not part of the inventory are removed.

    Parameters
    ----------
    inventory: dict
        inventory with ``"_meta"`` and ``"all"`` entries and one entry per
        group, each having ``"hosts"`` and ``"vars"``
    setup: tmdeploy.config.Setup
        setup description; its ``architecture.name``,
        ``cloud.key_file_private``, ``cloud.provider``, ``cloud.region`` and
        ``cloud.network`` attributes are read

    Returns
    -------
    dict
        the same `inventory` object, modified in place

    Note
    ----
    The function reads the module-level ``args`` (parsed command line
    arguments) to decide whether addresses should be refreshed and
    terminates the process with exit code 1 when an address cannot be
    obtained from the cloud.
    """
    for group, group_content in inventory.items():
        if group in {'_meta', 'all'}:
            continue

        if group.startswith('ganglia'):
            inventory[group]['vars'].update({
                'ganglia_cluster_name': setup.architecture.name,
                'ganglia_grid_name': 'tissuemaps'
            })

        if group.startswith('slurm'):
            inventory[group]['vars'].update({
                'multiuser_cluster': 'no',
                'slurm_cluster_users': ['tissuemaps']
            })
            inventory['tissuemaps_server']['vars'].update({
                'cluster_mode': 'yes'
            })

        logger.info('update inventory for group "%s"', group)
        for host in group_content['hosts']:
            # Each host has a separate variable file in YAML format, which
            # serves as cache for the address of the host.
            host_vars_file = os.path.join(HOST_VARS_DIR, host)
            private_key_file = os.path.expandvars(
                os.path.expanduser(setup.cloud.key_file_private)
            )
            host_vars = dict()
            host_vars.update(inventory['_meta']['hostvars'][host])
            host_vars.update({
                'ansible_user': host_vars['ssh_user'],
                'ansible_ssh_private_key_file': private_key_file
            })

            # NOTE: --refresh is supposed to be triggered by ansible via
            # (meta: refresh_inventory), but it's not!
            # NOTE: "args" is the module-level namespace of parsed command
            # line arguments, which only exists when run as a script.
            if args.refresh or args.refresh_cache:
                try:
                    host_vars['ansible_host'] = get_host_ip(
                        setup.cloud.provider, setup.cloud.region,
                        setup.cloud.network, host,
                        inventory['_meta']['hostvars'][host]
                    )
                except NoInstanceFoundError:
                    # A missing instance is tolerated: the host simply
                    # doesn't get an "ansible_host" variable.
                    logger.warning('no instance found for host "%s"', host)
                    logger.debug('remove "ansible_host" variable')
                except CloudError as err:
                    logger.error(
                        'host address could not be obtained from cloud: %s',
                        str(err)
                    )
                    sys.exit(1)
                if 'ansible_host' in host_vars:
                    logger.debug('update variables for host "%s"', host)
                    # The IP address of the host gets cached on disk, such
                    # that it doesn't need to be retrieved from the cloud
                    # for each call.
                    write_yaml_file(
                        host_vars_file,
                        {'ansible_host': host_vars['ansible_host']}
                    )
            else:
                try:
                    pre_host_vars = read_yaml_file(host_vars_file)
                    host_vars['ansible_host'] = pre_host_vars['ansible_host']
                except (IOError, OSError):
                    # IOError is only an alias of OSError on Python 3; on
                    # Python 2 a file that cannot be opened raises IOError,
                    # which "except OSError" would not catch.
                    logger.warning(
                        'no variable file found for host "%s"', host
                    )
                except KeyError:
                    logger.warning(
                        'variable "ansible_host" is not defined for host "%s"',
                        host
                    )
            if 'web' in host_vars['tags']:
                # NOTE: We use an architecture where hosts are part of a private
                # network and have security group settings that prevent access
                # from outside. Only some hosts have a public IP (those tagged
                # with "web") and all other hosts have private IPs.
                # So when the deploying machine is not part of the private
                # network, Ansible can only connect to the "web" host directly.
                # We therefore use the "web" host as a bastion host through
                # which we can connect to other hosts.
                if 'ansible_host' in host_vars:
                    logger.info(
                        'use "%s" as bastion host', host_vars['ansible_host']
                    )
                    inventory['all']['vars']['ansible_ssh_common_args'] = \
                        '-o ProxyCommand="ssh {o} -i {f} -W %h:%p -q {u}@{h}"'.format(
                            o='-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no',
                            f=private_key_file,
                            u=host_vars['ansible_user'],
                            h=host_vars['ansible_host']
                        )
            # Update inventory with the collected host variables
            inventory['_meta']['hostvars'][host].update(host_vars)

    # Remove host_vars files when the host is no longer in the inventory.
    # This doesn't depend on the group and is thus done once after all
    # groups have been processed rather than once per group.
    for host in os.listdir(HOST_VARS_DIR):
        if host not in inventory['_meta']['hostvars']:
            host_vars_file = os.path.join(HOST_VARS_DIR, host)
            logger.debug('remove variables of host "%s"', host)
            os.remove(host_vars_file)

    return inventory


def main(args):
    """Build the inventory and print the requested part of it.

    The setup file is located via the ``TM_SETUP`` environment variable.
    The directory ``HOST_VARS_DIR`` is created if it doesn't exist yet.

    Parameters
    ----------
    args: argparse.Namespace
        parsed command line arguments; ``list`` and ``host`` are read

    Raises
    ------
    SetupEnvironmentError
        when the ``TM_SETUP`` environment variable is not set
    """
    if 'TM_SETUP' not in os.environ:
        # Raise the error type that the script entry point handles, such
        # that a missing variable is reported as an error message rather
        # than as an unhandled exception.
        raise SetupEnvironmentError(
            '"TM_SETUP" environment variable is not set.'
        )
    setup_file = os.environ['TM_SETUP']
    setup = Setup(setup_file)

    if not os.path.exists(HOST_VARS_DIR):
        # Also create missing parent directories.
        os.makedirs(HOST_VARS_DIR)

    logger.info('build inventory based on setup file: %s', setup_file)
    inventory = build_inventory(setup)

    logger.info('update inventory')
    inventory = update_inventory(inventory, setup)

    if args.list:
        print('---')
        print(to_json(inventory))
    elif args.host:
        print('---')
        # An unknown host yields an empty set of variables, as expected
        # from an inventory script, rather than a KeyError.
        print(to_json(inventory['_meta']['hostvars'].get(args.host, {})))


if __name__ == "__main__":
    # Script entry point: parse the command line, configure logging and
    # run main(), turning setup errors into an error message and exit code 1.
    parser = argparse.ArgumentParser(
        description='Dynamic Ansible inventory for TissueMAPS cloud instances.'
    )
    # parser.add_argument(
    #     '--private', action='store_true',
    #     help='use private address for ansible host'
    # )
    parser.add_argument(
        '--refresh', action='store_true', help='refresh cached information'
    )
    parser.add_argument(
        '--refresh-cache', dest='refresh_cache', action='store_true',
        help='refresh cached information'
    )
    parser.add_argument(
        '--debug', action='store_true', help='enable debug output'
    )
    # Either the whole inventory or the variables of one host get printed.
    output_mode = parser.add_mutually_exclusive_group()
    output_mode.add_argument(
        '--list', action='store_true', help='list active servers'
    )
    output_mode.add_argument(
        '--host', help='list details about the specific host'
    )
    # NOTE: "args" must remain a module-level name, because it is read by
    # update_inventory().
    args = parser.parse_args()

    configure_logging()
    # NOTE: Logging to STDOUT breaks inventory functionality!!!
    # Ansible will raise an ERROR: Syntax Error while loading YAML.
    # WARNING and ERROR messages are ok, because they get logged to STDERR.
    # Verbosity is therefore only raised when debugging is requested.
    level = map_logging_verbosity(3 if args.debug else 0)
    for configured_logger in (logger, logging.getLogger('tmdeploy')):
        configured_logger.setLevel(level)

    try:
        main(args)
    except (SetupDescriptionError, SetupEnvironmentError) as err:
        logger.error(str(err))
        sys.exit(1)
