#!python
import os
import collections
import argparse
import sys
import logging
try:
    from ConfigParser import SafeConfigParser
except ImportError:
    import configparser
    SafeConfigParser = configparser.ConfigParser

from tmdeploy.utils import write_yaml_file, read_yaml_file, to_json
from tmdeploy.log import configure_logging, map_logging_verbosity
from tmdeploy.config import Setup
from tmdeploy.errors import (
    CloudError, NoInstanceFoundError, MultipleInstancesFoundError
)
from tmdeploy.inventory import (
    CONFIG_DIR, GROUP_VARS_DIR, HOST_VARS_DIR,
    build_inventory_information, load_inventory, save_inventory
)

# Module-level logger, named after this script's file name.
logger = logging.getLogger(os.path.basename(__file__))


def _get_host_ip_os(region, network, host_name, host_vars):
    """Look up the address of an OpenStack instance on a given network.

    Credentials and endpoint are read from the ``OS_AUTH_URL``,
    ``OS_USERNAME``, ``OS_PASSWORD`` and ``OS_PROJECT_NAME`` environment
    variables.

    Parameters
    ----------
    region: str
        name of the region passed to the Nova client
    network: str
        name of the network whose addresses should be inspected
    host_name: str
        name of the instance to search for
    host_vars: dict
        host variables (currently unused; kept for a uniform signature)

    Returns
    -------
    str
        address of the instance; when several addresses exist, the first
        one of type "floating" is preferred, otherwise the first listed

    Raises
    ------
    MultipleInstancesFoundError
        when more than one instance matches `host_name`
    NoInstanceFoundError
        when no instance matches `host_name`
    CloudError
        when the instance has no address on `network`
    """
    import novaclient.client
    cloud_client = novaclient.client.Client(
        '2',
        username=os.getenv('OS_USERNAME'),
        password=os.getenv('OS_PASSWORD'),
        project_name=os.getenv('OS_PROJECT_NAME'),
        auth_url=os.getenv('OS_AUTH_URL'),
        region_name=region
    )
    instances = cloud_client.servers.list(
        search_opts={'name': host_name}, detailed=True
    )
    if len(instances) > 1:
        raise MultipleInstancesFoundError(
            'More than one instance found with name "{0}"'.format(host_name)
        )
    if not instances:
        raise NoInstanceFoundError(
            'No instance found with name "{0}"'.format(host_name)
        )

    try:
        addresses = instances[0].addresses[network]
    except KeyError:
        raise CloudError(
            'No addresses found for network "{0}"'.format(network)
        )
    if not addresses:
        raise CloudError('No address found for network "{0}"'.format(network))

    if len(addresses) > 1:
        # Which one to choose? OS-EXT-IPS:type?
        logger.warning(
            'more than one address found for host "%s" and network "%s"',
            host_name, network
        )
        logger.info('try to use floating address (if existing)')
        # Use .get() so that entries lacking the extension attribute are
        # simply treated as non-floating instead of raising a KeyError.
        floating_addresses = [
            addr['addr'] for addr in addresses
            if addr.get('OS-EXT-IPS:type') == 'floating'
        ]
        if floating_addresses:
            logger.info('floating address found')
            return floating_addresses[0]
        logger.info('no floating address found')
    return addresses[0]['addr']


def _get_host_ip_ec2(region, host_name, host_vars):
    """Look up the public IP address of an EC2 instance by its "Name" tag.

    Credentials are read from the ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables.

    Parameters
    ----------
    region: str
        name of the AWS region
    host_name: str
        value of the "Name" tag of the instance to search for
    host_vars: dict
        host variables (currently unused; kept for a uniform signature)

    Returns
    -------
    value of ``public_ip_address`` of the first matching instance
    (a warning is logged when several instances match)

    Raises
    ------
    CloudError
        when the EC2 resource cannot be created
    NoInstanceFoundError
        when no instance matches `host_name`
    """
    import boto3
    try:
        cloud_client = boto3.resource(
            'ec2',
            region_name=region,
            aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
            aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
        )
    except Exception:
        # Catch "Exception" rather than everything, so that SystemExit and
        # KeyboardInterrupt are not converted into a CloudError.
        raise CloudError(
            'Could not connect to cloud provider in region "{0}".'.format(
                region
            )
        )
    instances = cloud_client.instances.filter(
        Filters=[{'Name': 'tag:Name', 'Values': [host_name]}]
    )
    # NOTE(review): public_ip_address is presumably None for instances
    # without a public address - confirm whether those should be skipped.
    addresses = [inst.public_ip_address for inst in instances]
    if not addresses:
        raise NoInstanceFoundError(
            'No instance found with name "{0}"'.format(host_name)
        )
    if len(addresses) > 1:
        logger.warning(
            'more than one instance found with name "%s"', host_name
        )
    return addresses[0]


def _get_host_ip_gce(region, network, host_name, host_vars):
    """Look up the NAT IP address of a Google Compute Engine instance.

    The project is read from the ``GCE_PROJECT`` environment variable and
    credentials are loaded from the file referenced by
    ``GCE_CREDENTIALS_FILE_PATH``.

    Parameters
    ----------
    region: str
        passed to the API as the ``zone`` of the instance
    network: str
        name of the network interface whose address should be returned
    host_name: str
        name of the instance to search for
    host_vars: dict
        host variables (currently unused; kept for a uniform signature)

    Returns
    -------
    ``natIP`` of the first access config of the first matching interface

    Raises
    ------
    MultipleInstancesFoundError
        when more than one instance matches `host_name`
    NoInstanceFoundError
        when no instance matches `host_name`
    CloudError
        when the list request fails, no interface is named `network`, or
        the interface has no access config
    """
    import googleapiclient.discovery
    from oauth2client.client import GoogleCredentials
    project = os.getenv('GCE_PROJECT')
    credentials_file = os.getenv('GCE_CREDENTIALS_FILE_PATH')
    credentials = GoogleCredentials.from_stream(credentials_file)
    cloud_client = googleapiclient.discovery.build(
        'compute', 'v1', credentials=credentials
    )
    try:
        response = cloud_client.instances().list(
            project=project, zone=region, filter='name eq {0}'.format(host_name)
        ).execute()
    except Exception:
        # Catch "Exception" rather than everything, so that SystemExit and
        # KeyboardInterrupt are not converted into a CloudError.
        raise CloudError(
            'Could not connect to cloud provider for '
            'project "{0}" in region "{1}".'.format(project, region)
        )
    # A response without an "items" key is treated as "no instance found"
    # instead of failing with a KeyError.
    instances = response.get('items', [])
    if len(instances) > 1:
        raise MultipleInstancesFoundError(
            'More than one instance found with name "{0}"'.format(host_name)
        )
    if not instances:
        raise NoInstanceFoundError(
            'No instance found with name "{0}"'.format(host_name)
        )
    networks = [
        net for net in instances[0]['networkInterfaces']
        if net['name'] == network
    ]
    if not networks:
        raise CloudError('No network with name "{0}".'.format(network))
    if len(networks) > 1:
        logger.warning(
            'more than one address found for host "%s" and network "%s"',
            host_name, network
        )
    addresses = [addr['natIP'] for addr in networks[0]['accessConfigs']]
    if not addresses:
        raise CloudError('No address found for network "{0}"'.format(network))
    if len(addresses) > 1:
        logger.warning(
            'more than one address found for host "%s" and network "%s"',
            host_name, network
        )
    return addresses[0]


def get_host_ip(provider, region, network, host_name, host_vars):
    """Obtain the IP address of a host from the cloud provider.

    Parameters
    ----------
    provider: str
        cloud provider: ``"os"``, ``"ec2"`` or ``"gce"``
    region: str
        cloud region forwarded to the provider specific lookup
    network: str
        name of the network (not used for the ``"ec2"`` provider)
    host_name: str
        name of the host
    host_vars: dict
        variables of the host

    Returns
    -------
    str
        string representation of the address returned by the provider
        specific lookup

    Note
    ----
    Logs an error and exits the process with status 1 when `provider` is
    not supported.
    """
    if provider == 'os':
        ip = _get_host_ip_os(region, network, host_name, host_vars)
    elif provider == 'ec2':
        # The EC2 lookup has no "network" parameter; passing it would
        # raise a TypeError.
        ip = _get_host_ip_ec2(region, host_name, host_vars)
    elif provider == 'gce':
        ip = _get_host_ip_gce(region, network, host_name, host_vars)
    else:
        logger.error('unsupported cloud provider: %s', provider)
        sys.exit(1)
    return str(ip)


def _remove_stale_inventory_entries(inventory_info, persistent_inventory):
    """Remove variable files and inventory sections of groups and hosts
    that are no longer part of `inventory_info`.

    The variable file of the "all" group is always kept.
    """
    for group in os.listdir(GROUP_VARS_DIR):
        if group == 'all':
            continue
        if group not in inventory_info:
            logger.debug('remove variables of group "%s"', group)
            os.remove(os.path.join(GROUP_VARS_DIR, group))
    for host in os.listdir(HOST_VARS_DIR):
        if host not in inventory_info['_meta']['hostvars']:
            logger.debug('remove variables of host "%s"', host)
            os.remove(os.path.join(HOST_VARS_DIR, host))
    for group in persistent_inventory.sections():
        if group not in inventory_info:
            logger.debug('remove group "%s" from inventory', group)
            persistent_inventory.remove_section(group)


def update_static_inventory(inventory_info, setup):
    """Update the static Ansible inventory files from `inventory_info`.

    For each group other than "_meta" and "all", the group variables (if
    any) are written to a YAML file in ``GROUP_VARS_DIR``, the group and its
    hosts are registered in the persistent inventory, and a YAML variable
    file is written for each host in ``HOST_VARS_DIR``. The "ansible_host"
    variable is either looked up in the cloud (when ``--refresh`` or
    ``--refresh-cache`` was given) or carried over from the existing host
    variable file. A host tagged "web" with a known address is configured as
    bastion host in the variables of the "all" group. Finally, files and
    inventory sections of groups and hosts that are no longer part of the
    inventory are removed and the persistent inventory is saved.

    Parameters
    ----------
    inventory_info: dict
        inventory description with group names as keys plus a "_meta" entry
        holding "hostvars" and an "all" entry holding "vars"; host variables
        and the variables of the "all" group are updated in place
    setup:
        setup description providing ``cloud.key_file_private``,
        ``cloud.provider``, ``cloud.region`` and ``cloud.network``

    Note
    ----
    Reads the module-level ``args`` namespace, which only exists when this
    file is executed as a script. Exits the process with status 1 when an
    address cannot be obtained from the cloud.
    """
    # Load the main ansible inventory file
    persistent_inventory = load_inventory()
    for group, group_content in inventory_info.items():
        if group in ('_meta', 'all'):
            continue

        logger.info('update static inventory files for group "%s"', group)
        # Create a separate variable file in YAML format for each group
        group_vars_file = os.path.join(GROUP_VARS_DIR, group)
        if group_content.get('vars', None) is not None:
            logger.debug('update variables for group "%s"', group)
            write_yaml_file(group_vars_file, group_content['vars'])

        if not persistent_inventory.has_section(group):
            persistent_inventory.add_section(group)
        for host in group_content['hosts']:
            persistent_inventory.set(group, host)
            # Create a separate variable file in YAML format for each host
            host_vars_file = os.path.join(HOST_VARS_DIR, host)
            private_key_file = os.path.expandvars(
                os.path.expanduser(setup.cloud.key_file_private)
            )
            host_vars = dict()
            host_vars.update(inventory_info['_meta']['hostvars'][host])
            host_vars.update({
                'ansible_user': 'ubuntu',
                'ansible_ssh_private_key_file': private_key_file,
                'timeout': 60
            })

            # NOTE: --refresh is supposed to be triggered by ansible via
            # (meta: refresh_inventory), but it's not!
            if args.refresh or args.refresh_cache:
                try:
                    host_vars['ansible_host'] = get_host_ip(
                        setup.cloud.provider, setup.cloud.region,
                        setup.cloud.network, host,
                        inventory_info['_meta']['hostvars'][host]
                    )
                except NoInstanceFoundError:
                    # Leave "ansible_host" undefined for this host.
                    logger.warning('no instance found for host "%s"', host)
                    logger.debug('remove "ansible_host" variable')
                except CloudError as err:
                    logger.error(
                        'host address could not be obtained from cloud: %s',
                        str(err)
                    )
                    sys.exit(1)
            else:
                # Carry over the address stored by a previous run.
                try:
                    pre_host_vars = read_yaml_file(host_vars_file)
                    host_vars['ansible_host'] = pre_host_vars['ansible_host']
                except (OSError, IOError):
                    # IOError is a separate class on Python 2, where it is
                    # raised for a missing file (presumably what
                    # read_yaml_file propagates - verify).
                    logger.warning(
                        'no variable file found for host "%s"', host
                    )
                except KeyError:
                    logger.warning(
                        'variable "ansible_host" is not defined for host "%s"',
                        host
                    )
            if 'web' in host_vars['tags']:
                # NOTE: We use an architecture where hosts are part of a private
                # network and have security group settings that prevent access
                # from outside. Only some hosts have a public IP (those tagged
                # with "web") and all other hosts have private IPs.
                # So when the deploying machine is not part of the private
                # network, Ansible can only connect to the "web" host directly.
                # We therefore use the "web" host as a bastion host through
                # which we can connect to other hosts.
                if 'ansible_host' in host_vars:
                    logger.info(
                        'use "%s" as bastion host', host_vars['ansible_host']
                    )
                    all_vars = inventory_info['all']['vars']
                    all_vars['ansible_ssh_pipelining'] = True
                    all_vars['ansible_ssh_common_args'] = (
                        '-o UserKnownHostsFile=/dev/null '
                        '-o ProxyCommand="ssh -o StrictHostKeyChecking=no -i '
                        '{key_file} -W %h:%p -q {user}@{host}"'.format(
                            key_file=host_vars['ansible_ssh_private_key_file'],
                            user=host_vars['ansible_user'],
                            host=host_vars['ansible_host']
                        )
                    )
                    all_group_vars_file = os.path.join(GROUP_VARS_DIR, 'all')
                    logger.debug('update variables for group "all"')
                    write_yaml_file(all_group_vars_file, all_vars)
            # Update inventory with host and group variables
            inventory_info['_meta']['hostvars'][host].update(host_vars)
            logger.debug('update variables for host "%s"', host)
            write_yaml_file(host_vars_file, host_vars)

    # Remove group_vars and host_vars files when the respective group or
    # host is no longer part of the inventory. This is done once after all
    # groups have been processed, so that it also happens when the inventory
    # does not contain any regular group.
    _remove_stale_inventory_entries(inventory_info, persistent_inventory)

    # Create (or update) the main ansible inventory file in INI format
    save_inventory(persistent_inventory)


def main(args):
    """Build the inventory, update the static files and print the result.

    Parameters
    ----------
    args: argparse.Namespace
        parsed command line arguments; ``list`` and ``host`` select what
        gets printed to standard output

    Note
    ----
    Exits the process with status 1 when ``CONFIG_DIR`` does not exist.
    """
    if not os.path.exists(CONFIG_DIR):
        logger.error('configuration directory does not exist: %s', CONFIG_DIR)
        sys.exit(1)

    setup = Setup()

    # Make sure the directories for the variable files are available.
    for vars_dir in (GROUP_VARS_DIR, HOST_VARS_DIR):
        if not os.path.exists(vars_dir):
            os.mkdir(vars_dir)

    logger.info('build inventory')
    inventory_info = build_inventory_information(setup)

    logger.info('update static inventory files')
    update_static_inventory(inventory_info, setup)

    # Nothing gets printed unless --list or --host was requested.
    if not (args.list or args.host):
        return

    print('---')
    if args.list:
        print(to_json(inventory_info))
    else:
        print(to_json(inventory_info['_meta']['hostvars'][args.host]))


if __name__ == "__main__":
    # Command line interface of the inventory script.
    parser = argparse.ArgumentParser(
        description='Dynamic Ansible inventory for TissueMAPS cloud instances.'
    )
    # parser.add_argument(
    #     '--private', action='store_true',
    #     help='use private address for ansible host'
    # )

    # Plain on/off switches; argparse derives the attribute name from the
    # option string ("--refresh-cache" is stored as "refresh_cache").
    for flag, help_text in (
        ('--refresh', 'refresh cached information'),
        ('--refresh-cache', 'refresh cached information'),
        ('--debug', 'enable debug output'),
    ):
        parser.add_argument(flag, action='store_true', help=help_text)

    # --list and --host cannot be combined.
    query_group = parser.add_mutually_exclusive_group()
    query_group.add_argument(
        '--list', action='store_true',
        help='list active servers'
    )
    query_group.add_argument(
        '--host',
        help='list details about the specific host'
    )
    # NOTE: the name "args" must stay global, because
    # update_static_inventory() reads it.
    args = parser.parse_args()

    configure_logging()
    # NOTE: Logging to STDOUT breaks inventory functionality.
    # Ansible will raise an ERROR: Syntax Error while loading YAML.
    # WARNING and ERROR messages are fine, because they get logged to
    # STDERR. Verbosity is therefore only raised on explicit request.
    log_level = map_logging_verbosity(3 if args.debug else 0)
    for target_logger in (logger, logging.getLogger('tmdeploy')):
        target_logger.setLevel(log_level)

    main(args)
