#!python
#
# Copyright 2019 Ricardo Branco <rbranco@suse.de>
# MIT License
#
"""
Show all instances created on cloud providers
"""

import argparse
import datetime
import logging
import sys

from itertools import groupby
from os.path import basename
from operator import itemgetter
from threading import Thread
from dateutil import parser

import jmespath
from jmespath.exceptions import JMESPathError

import timeago

from cachetools import cached, TTLCache
from pytz import utc

from providers.amazon import AWS
from providers.az import Azure
from providers.gcp import GCP
from providers.exceptions import FatalError
from output import Output


# Help text printed by -h/--help. It is written by hand because the argument
# parser in the __main__ block is created with add_help=False, so keep it in
# sync with the add_argument() calls there.
USAGE = "Usage: " + basename(sys.argv[0]) + """ [OPTIONS]
Options:
    -h, --help                          show this help message and exit
    -l, --log debug|info|warning|error|critical
    -o, --output text|html|json|JSON    output type
    -p, --port PORT                     run a web server on port PORT
    -r, --reverse                       reverse sort
    -s, --sort name|time|status         sort type
    -S, --status stopped|running|all    filter by instance status
    -T, --time TIME_FORMAT              time format as used by strftime(3)
    -V, --version                       show version and exit
    -v, --verbose                       be verbose
Filter options:
    --filter-aws NAME VALUE             may be specified multiple times
    --filter-azure FILTER               Filter for Azure
    --filter-gcp FILTER                 Filter for GCP
"""

# Version string printed by -V/--version
__version__ = "0.1.20"


def fix_date(date):
    """
    Format an instance creation date for display.

    ``date`` may be a datetime, a string that dateutil can parse, or anything
    else (e.g. None), in which case "" is returned.

    With --verbose the date is converted to local time (or the timezone
    specified by the TZ environment variable) and formatted with the --time
    strftime(3) format. Otherwise a relative "time ago" string is returned.
    """
    if isinstance(date, str):
        # The parser returns datetime objects
        date = parser.parse(date)
    if not isinstance(date, datetime.datetime):
        return ""
    if date.tzinfo is None:
        # utc.normalize() raises ValueError on naive datetimes (e.g. a string
        # without an offset), so treat them as UTC.
        # NOTE(review): assumes providers report offset-less times in UTC.
        date = utc.localize(date)
    else:
        # GCP doesn't return UTC dates
        date = utc.normalize(date)
    if args.verbose:
        # From Python 3.3 we can omit tz in datetime.astimezone()
        return date.astimezone().strftime(args.time)
    return timeago.format(date, datetime.datetime.now(tz=utc))


def perror(msg, err):
    """
    Log "msg: err" at error level, then terminate with exit status 1.

    When ``err`` is an exception it is rendered as "ClassName: text".
    """
    detail = err
    if isinstance(err, Exception):
        detail = "%s: %s" % (type(err).__name__, err)
    logging.error("%s: %s", msg, detail)
    sys.exit(1)


def print_amazon_instances():
    """
    Print information about AWS EC2 instances

    Builds EC2 filters from --status and --filter-aws, fetches the instances
    through AWS(), optionally sorts them (--sort, --reverse) and hands each
    one to the global output object.
    """
    # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-lifecycle.html
    # "running" means not stopped or terminated and "stopped" means not
    # pending or running, so the transitional states appear in both lists
    state_names = {
        'running': 'pending running stopping shutting-down',
        'stopped': 'stopping stopped shutting-down terminated',
    }
    pairs = []
    if args.status in state_names:
        pairs = [['instance-state-name', state]
                 for state in state_names[args.status].split()]
    user_pairs = args.filter_aws
    if user_pairs:
        # An explicit instance-state-name filter overrides --status
        if any(pair[0] == 'instance-state-name' for pair in user_pairs):
            pairs = user_pairs
        else:
            pairs.extend(user_pairs)
    # EC2 expects one {'Name': ..., 'Values': [...]} entry per filter name
    first = itemgetter(0)
    ec2_filters = []
    for name, group in groupby(sorted(pairs, key=first), first):
        ec2_filters.append(
            {'Name': name, 'Values': [value for _, value in group]})
    aws = AWS()
    instances = aws.get_instances(filters=ec2_filters)
    if args.sort:
        sort_keys = {
            'name': itemgetter('InstanceId'),
            'time': itemgetter('LaunchTime', 'InstanceId'),
            'status': lambda inst: (aws.get_status(inst), inst['InstanceId']),
        }
        instances = sorted(instances, key=sort_keys[args.sort],
                           reverse=args.reverse)
    for instance in instances:
        if args.output == "JSON":
            output.info(item=instance)
            continue
        output.info(
            cloud="AWS",
            name=instance['InstanceId'],
            type=instance['InstanceType'],
            status=aws.get_status(instance),
            created=fix_date(instance['LaunchTime']),
            location=instance['Placement']['AvailabilityZone'])


def print_azure_instances():  # pylint: disable=too-many-branches
    """
    Print information about Azure Compute instances

    The status filter (--status) and --filter-azure are combined into one
    JMESPath expression, compiled with jmespath and handed to
    Azure.get_instances(). Results are optionally sorted (--sort,
    --reverse) and each instance is passed to the global output object.

    Raises FatalError if the combined filter is not valid JMESPath.
    """
    filters = None
    # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/states-lifecycle
    if args.status in ('running', 'stopped'):
        # Consider an instance "running" if not stopped / deallocated
        # and "stopped" if not starting or running, hence the overlap
        if args.status == "running":
            statuses = 'starting running stopping'
        else:
            statuses = 'stopping stopped deallocating deallocated'
        filters = " || ".join(
            "instance_view.statuses[1].code == 'PowerState/%s'" % status
            for status in statuses.split())
    # If status was specified in the filter (or --status all produced no
    # status filter), use the user's filter as-is
    if args.filter_azure:
        if filters is None or "status" in args.filter_azure:
            filters = args.filter_azure
        else:
            # In JMESPath && binds tighter than ||, so the status alternatives
            # must be grouped or the user filter would only apply to the last
            filters = "(%s) && (%s)" % (filters, args.filter_azure)
    if filters:
        try:
            filters = jmespath.compile(filters)
        except JMESPathError as exc:
            # Do not continue with an uncompiled string filter
            raise FatalError("Azure", exc) from exc
    azure = Azure()
    instances = list(azure.get_instances(filters=filters))
    if args.sort:
        keys = {
            'name': itemgetter('name'),
            'time': itemgetter('date', 'name'),
            'status': itemgetter('status', 'name')
        }
        try:
            # sorted() leaves the original order intact if comparison fails,
            # unlist list.sort() which may leave it partially reordered
            instances = sorted(instances, key=keys[args.sort],
                               reverse=args.reverse)
        except TypeError as exc:
            # instance['date'] may be None, which does not compare with dates
            logging.warning("Azure: cannot sort by %s: %s", args.sort, exc)
    for instance in instances:
        if args.output == "JSON":
            output.info(item=instance)
        else:
            output.info(
                cloud="Azure",
                name=instance['name'],
                type=instance['hardware_profile']['vm_size'],
                status=instance['status'],
                created=fix_date(instance['date']),
                location=instance['location'])


def print_google_instances():
    """
    Print information about Google Compute instances

    The status filter (--status) and --filter-gcp are combined into one
    filter string handed to GCP.get_instances(). Results are optionally
    sorted (--sort, --reverse) and each instance is passed to the global
    output object.
    """
    filters = None
    # https://cloud.google.com/compute/docs/instances/instance-life-cycle
    # NOTE: The above list is incomplete. The API returns more statuses
    if args.status in ('running', 'stopped'):
        # Consider an instance "running" if not stopped / terminated
        # and "stopped" if not starting, running, hence the overlap
        if args.status == "running":
            statuses = 'provisioning staging running stopping suspending suspended'
        else:
            statuses = 'stopping stopped terminated'
        filters = " OR ".join(
            "status: %s" % status
            for status in statuses.split())
    # If status was specified in the filter (or --status all produced no
    # status filter), use the user's filter as-is
    if args.filter_gcp:
        if filters is None or "status" in args.filter_gcp:
            filters = args.filter_gcp
        else:
            # NOTE(review): relies on OR binding tighter than AND, as in
            # Google's AIP-160 filter grammar -- confirm for GCP.get_instances()
            filters += " AND (%s)" % args.filter_gcp
    gcp = GCP()
    instances = gcp.get_instances(filters=filters)
    if args.sort:
        instances = list(instances)
        keys = {
            'name': itemgetter('name'),
            'time': itemgetter('creationTimestamp', 'name'),
            'status': itemgetter('status', 'name'),
        }
        instances.sort(key=keys[args.sort], reverse=args.reverse)
    for instance in instances:
        if args.output == "JSON":
            output.info(item=instance)
        else:
            output.info(
                cloud="GCP",
                name=instance['name'],
                # machineType and zone are resource URLs; keep the last segment
                type=instance['machineType'].rsplit('/', 1)[-1],
                status=gcp.get_status(instance),
                created=fix_date(instance['creationTimestamp']),
                location=instance['zone'].rsplit('/', 1)[-1])


def _print_report():
    """
    Write header, every provider's instances and footer to the global output
    """
    output.header()
    # Each provider is queried in its own thread; the footer waits for all
    threads = [
        Thread(target=printer)
        for printer in (print_amazon_instances,
                        print_azure_instances,
                        print_google_instances)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    output.footer()


@cached(cache=TTLCache(maxsize=2, ttl=120))
def main():
    """
    Main function

    Print the full report. In web-server mode (--port) the report is
    captured instead and returned as a string; otherwise None is returned.
    The result is cached for 120 seconds, so calls within that window reuse
    the same report.
    """
    if not args.port:
        _print_report()
        return None
    from io import StringIO
    # Output presumably prints to sys.stdout, so swap in a buffer and always
    # put the real stream back, even if printing the report fails
    saved_stdout = sys.stdout
    sys.stdout = buffer = StringIO()
    try:
        _print_report()
        return buffer.getvalue()
    finally:
        sys.stdout = saved_stdout
        buffer.close()


def handle_requests(request):
    """
    Pyramid view: log the incoming request and reply with the report
    returned by main()
    """
    logging.info(request)
    return Response(main())


def web_server():
    """
    Setup the WSGI server

    Requests to "/" are routed to handle_requests(). The server listens on
    0.0.0.0:args.port until it is shut down or interrupted.
    """
    with Configurator() as config:
        config.add_route('handle_requests', '/')
        config.add_view(handle_requests, route_name='handle_requests')
        app = config.make_wsgi_app()
    # Serve outside the configurator context so configuration is finished
    # before requests arrive, and always release the listening socket
    server = make_server('0.0.0.0', args.port, app)
    try:
        server.serve_forever()
    finally:
        server.server_close()


def setup_logging():
    """
    Configure root logging to stderr at the --log level.

    Timestamps and level names are only added in web-server mode (--port);
    otherwise format=None is passed, i.e. logging's plain message format.
    """
    log_format = None
    if args.port:
        log_format = "%(asctime)s %(levelname)-8s %(message)s"
    logging.basicConfig(stream=sys.stderr, format=log_format,
                        level=args.log.upper())


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(usage=USAGE, add_help=False)
    argparser.add_argument('-h', '--help', action='store_true')
    argparser.add_argument('-l', '--log', default='error',
                           choices='debug info warning error critical'.split())
    argparser.add_argument('-o', '--output', default='text',
                           choices=['text', 'html', 'json', 'JSON'])
    argparser.add_argument('-p', '--port', type=int)
    argparser.add_argument('-r', '--reverse', action='store_true')
    argparser.add_argument('-s', '--sort', choices=['name', 'status', 'time'])
    argparser.add_argument('-S', '--status', default='running',
                           choices=['all', 'running', 'stopped'])
    argparser.add_argument('-T', '--time', default="%a %b %d %H:%M:%S %Z %Y")
    argparser.add_argument('-v', '--verbose', action='count')
    argparser.add_argument('-V', '--version', action='store_true')
    argparser.add_argument('--filter-aws', nargs=2, action='append')
    argparser.add_argument('--filter-azure', type=str)
    argparser.add_argument('--filter-gcp', type=str)
    # args and output are module globals read by the functions above
    args = argparser.parse_args()

    if args.help or args.version:
        print(USAGE if args.help else __version__)
        sys.exit(0)

    setup_logging()

    _keys = "cloud name type status created location"
    _fmt = '{d[cloud]}\t{d[name]:32}\t{d[type]:>23}\t{d[status]:>16}\t{d[created]:30}\t{d[location]}'

    # The web server always serves HTML
    if args.port:
        args.output = "html"
    output = Output(type=args.output.lower(), keys=_keys, fmt=_fmt)

    if args.port:
        # Only needed in web-server mode; these become module globals
        from io import StringIO
        from wsgiref.simple_server import make_server
        from pyramid.config import Configurator
        from pyramid.response import Response

    # Exit quietly with status 1 on Ctrl-C in both CLI and web-server mode
    try:
        if args.port:
            web_server()
            # web_server() only returns if the server stops
            sys.exit(1)
        main()
    except KeyboardInterrupt:
        sys.exit(1)
