#!/usr/bin/env python3

import sys
import argparse
import asyncio
import postfix_mta_sts_resolver.utils as utils
import postfix_mta_sts_resolver.defaults as defaults
import pynetstring
import yaml
import signal
from postfix_mta_sts_resolver.resolver import *
import collections
import time
import logging


def parse_args():
    """Parse the daemon's command line options.

    Returns:
        The ``argparse.Namespace`` produced by the parser, carrying the
        ``verbosity``, ``config`` and ``disable_uvloop`` attributes.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # The verbosity value given on the command line is looked up through
    # utils.LogLevel.__getitem__ before being checked against the choices.
    cli.add_argument(
        "-v", "--verbosity",
        type=utils.LogLevel.__getitem__,
        choices=list(utils.LogLevel),
        default=utils.LogLevel.info,
        help="logging verbosity",
    )
    cli.add_argument(
        "-c", "--config",
        metavar="FILE",
        default=defaults.CONFIG_LOCATION,
        help="config file location",
    )
    cli.add_argument(
        "--disable-uvloop",
        action="store_true",
        help="do not use uvloop even if it is available",
    )

    options = cli.parse_args()
    return options


def populate_cfg_defaults(cfg):
    """Fill in every setting that is missing from the loaded configuration.

    Args:
        cfg: Mapping loaded from the config file. A falsy value (for
            example ``None``) is replaced by a fresh empty dict;
            otherwise the mapping is updated in place.

    Returns:
        The configuration mapping with ``host``, ``port``, ``cache``,
        ``default_zone`` and ``zones`` populated. Values that are already
        present are never overwritten.
    """
    settings = cfg if cfg else {}

    # Listening endpoint
    settings.setdefault('host', defaults.HOST)
    settings.setdefault('port', defaults.PORT)

    # Cache backend; size option is only defaulted for the internal backend
    cache_cfg = settings.setdefault('cache', {})
    cache_cfg.setdefault('type', defaults.CACHE_BACKEND)
    if cache_cfg['type'] == 'internal':
        cache_opts = cache_cfg.setdefault('options', {})
        cache_opts.setdefault('cache_size', defaults.INTERNAL_CACHE_SIZE)

    def fill_zone(zone_cfg):
        # Per-zone settings shared by the default zone and the named zones
        zone_cfg.setdefault('timeout', defaults.TIMEOUT)
        zone_cfg.setdefault('strict_testing', defaults.STRICT_TESTING)

    fill_zone(settings.setdefault('default_zone', {}))
    for zone_cfg in settings.setdefault('zones', {}).values():
        fill_zone(zone_cfg)

    return settings


# Per-zone runtime settings: the zone's strict_testing flag and the
# resolver instance constructed for that zone.
ZoneEntry = collections.namedtuple('ZoneEntry', ['strict', 'resolver'])


# Cached policy record: time the entry was stored or refreshed, the policy
# id and the policy body.
CacheEntry = collections.namedtuple('CacheEntry', ['ts', 'pol_id', 'pol_body'])


class STSSocketmapResponder(object):
    """Serve socketmap lookups with answers derived from MTA-STS policies.

    A single instance handles every client connection: ``handle_msg`` is
    the per-connection callback, requests are resolved concurrently by
    ``process_request`` and written back in arrival order by ``sender``.
    """

    def __init__(self, cfg, loop):
        """Build per-zone resolvers and the policy cache.

        Args:
            cfg: Configuration mapping. Reads ``default_zone`` and every
                entry of ``zones`` (keys ``strict_testing`` and
                ``timeout``) plus ``cache`` (``type`` and, for the
                internal backend, ``options.cache_size``).
            loop: Event loop handed to the resolvers and used to schedule
                request tasks.

        Raises:
            NotImplementedError: If ``cfg["cache"]["type"]`` is not
                ``"internal"``.
        """
        self._loop = loop

        # Construct configurations and resolvers for every socketmap name
        default_cfg = cfg["default_zone"]
        self._default_zone = ZoneEntry(
            default_cfg["strict_testing"],
            STSResolver(loop=loop, timeout=default_cfg["timeout"]))

        self._zones = {
            name: ZoneEntry(zone["strict_testing"],
                            STSResolver(loop=loop, timeout=zone["timeout"]))
            for name, zone in cfg["zones"].items()
        }

        # Construct cache
        if cfg["cache"]["type"] == "internal":
            # Imported here so the backend module is only loaded when selected
            import postfix_mta_sts_resolver.internal_cache
            capacity = cfg["cache"]["options"]["cache_size"]
            self._cache = postfix_mta_sts_resolver.internal_cache.InternalLRUCache(capacity)
        else:
            raise NotImplementedError("Unsupported cache type!")

    async def sender(self, queue, writer):
        """Write responses to ``writer`` in the order they were queued.

        Args:
            queue: ``asyncio.Queue`` whose items are awaitable futures
                yielding the bytes to send, or ``None`` as a shutdown
                marker.
            writer: Stream writer of the client connection.

        The coroutine ends after closing ``writer`` when it receives the
        shutdown marker or when a queued future raises. If the coroutine
        is cancelled, the future being awaited and every future still in
        the queue are cancelled and the coroutine finishes normally.
        """
        logger = logging.getLogger("STS")
        # Bound before the loop so the cancellation handler can always
        # inspect it, even when cancelled during the very first get().
        fut = None
        try:
            while True:
                fut = await queue.get()

                # Check for shutdown
                if fut is None:
                    writer.close()
                    return

                logger.debug("Got new future from queue")
                try:
                    data = await fut
                except asyncio.CancelledError:
                    # On Python < 3.8 CancelledError derives from Exception;
                    # make sure it always reaches the cancellation handler.
                    raise
                except Exception as e:
                    logger.exception("Got exception from future: %s", e)
                    writer.close()
                    return
                logger.debug("Future await complete: data=%s", repr(data))
                writer.write(data)
                logger.debug("Wrote: %s", repr(data))
                await writer.drain()
        except asyncio.CancelledError:
            if fut is not None:
                fut.cancel()
            while not queue.empty():
                pending = queue.get_nowait()
                # The queue may also hold the None shutdown marker
                if pending is not None:
                    pending.cancel()

    async def process_request(self, raw_req):
        """Resolve one socketmap request into an encoded response.

        Args:
            raw_req: Request payload as bytes, ``b"<zone> <domain>"``.

        Returns:
            The netstring-encoded reply: ``"OK secure match=<mx>:<mx>..."``
            when a restrictive policy applies, ``"NOTFOUND "`` otherwise.

        Raises:
            ValueError: If a policy that has to be enforced carries an
                empty MX list.
        """
        # Parse request and canonicalize domain
        req_zone, _, req_domain = raw_req.decode('latin-1').partition(' ')
        domain = req_domain.lower().strip().rstrip('.')

        # Skip lookups for parent domain policies
        if domain.startswith('.'):
            return pynetstring.encode('NOTFOUND ')

        # Find appropriate zone config
        zone_cfg = self._zones.get(req_zone, self._default_zone)

        # Lookup for cached policy
        cached = self._cache.get(domain)

        # Check if newer policy exists or
        # retrieve policy from scratch if there is no cached one
        latest_pol_id = None if cached is None else cached.pol_id
        status, policy = await zone_cfg.resolver.resolve(domain, latest_pol_id)

        # Update local cache
        ts = time.time()
        if status is STSFetchResult.NOT_CHANGED:
            cached = CacheEntry(ts, cached.pol_id, cached.pol_body)
            self._cache.set(domain, cached)
        elif status is STSFetchResult.VALID:
            pol_id, pol_body = policy
            cached = CacheEntry(ts, pol_id, pol_body)
            self._cache.set(domain, cached)
        elif cached is None or cached.pol_body['max_age'] + cached.ts < ts:
            # Fetch did not succeed and there is no cached policy that is
            # still within its max_age
            return pynetstring.encode('NOTFOUND ')

        mode = cached.pol_body['mode']
        if mode == 'none' or (mode == 'testing' and not zone_cfg.strict):
            return pynetstring.encode('NOTFOUND ')

        # Explicit check instead of assert: asserts are stripped under -O,
        # which would let an empty "match=" list through.
        if not cached.pol_body['mx']:
            raise ValueError("Empty MX list for restrictive policy!")
        mxlist = [mx.lstrip('*') for mx in set(cached.pol_body['mx'])]
        resp = "OK secure match=" + ":".join(mxlist)
        return pynetstring.encode(resp)

    def enqueue_request(self, queue, raw_req):
        """Start processing ``raw_req`` and queue its future for the sender.

        Args:
            queue: Response-ordering queue of the connection.
            raw_req: Request payload as bytes.
        """
        fut = asyncio.ensure_future(self.process_request(raw_req), loop=self._loop)
        queue.put_nowait(fut)

    async def handle_msg(self, reader, writer):
        """Serve one client connection until it closes or misbehaves.

        Args:
            reader: Stream reader of the connection.
            writer: Stream writer of the connection.
        """
        logger = logging.getLogger("STS")

        # Construct netstring parser. It is kept local because its state
        # belongs to this connection only; storing it on the instance would
        # let concurrent connections overwrite each other's parser.
        decoder = pynetstring.Decoder()

        # Construct queue for responses ordering. No explicit loop argument:
        # it was removed in Python 3.10 and this coroutine already runs on
        # the loop the queue has to use.
        queue = asyncio.Queue()

        # Create coroutine which awaits for steady responses and sends them
        sender = asyncio.ensure_future(self.sender(queue, writer), loop=self._loop)

        def cleanup():
            sender.cancel()
            writer.close()

        while True:
            try:
                part = await reader.read(4096)
                logger.debug("Read: %s", repr(part))
            except (asyncio.CancelledError, ConnectionError):
                cleanup()
                return
            if not part:
                # End of stream
                cleanup()
                return

            try:
                requests = decoder.feed(part)
            except Exception:
                # Bad protocol. Do shutdown: let the sender flush what is
                # already queued and close the connection, then stop
                # reading instead of feeding a queue nobody consumes.
                queue.put_nowait(None)
                await sender
                return

            for req in requests:
                logger.debug("Enq request: %s", repr(req))
                self.enqueue_request(queue, req)


def main():
    """Run the daemon: load config, serve requests, shut down on SIGINT."""
    # Parse command line arguments and setup basic logging
    args = parse_args()
    main_logger = utils.setup_logger('MAIN', args.verbosity)
    utils.setup_logger('STS', args.verbosity)
    main_logger.info("MTA-STS daemon starting...")

    # Read config and populate with defaults
    with open(args.config, 'rb') as cfg_file:
        cfg = yaml.safe_load(cfg_file)
    cfg = populate_cfg_defaults(cfg)

    # Construct event loop
    main_logger.info("Starting eventloop...")
    if not args.disable_uvloop:
        if utils.enable_uvloop():
            main_logger.info("uvloop enabled.")
        else:
            main_logger.info("uvloop is not available. "
                             "Falling back to built-in event loop.")
    evloop = asyncio.get_event_loop()
    main_logger.info("Eventloop started.")

    # Construct request handler instance
    responder = STSSocketmapResponder(cfg, evloop)

    # Start server. The explicit ``loop`` argument was removed from
    # asyncio.start_server in Python 3.10; the coroutine picks up the loop
    # it is run on, which is evloop.
    start_server = asyncio.start_server(responder.handle_msg,
                                        cfg['host'],
                                        cfg['port'])
    server = evloop.run_until_complete(start_server)
    main_logger.info("Server started.")

    # Enter main loop. The future is bound to evloop explicitly instead of
    # relying on the implicit current loop.
    stop = asyncio.Future(loop=evloop)

    def request_stop():
        # A repeated SIGINT during shutdown must not set the result twice
        if not stop.done():
            stop.set_result(None)

    evloop.add_signal_handler(signal.SIGINT, request_stop)
    main_logger.debug("Signal handler defined.")
    evloop.run_until_complete(stop)

    # Handle interruption: shutdown properly
    main_logger.info("Got exit signal.")
    server.close()
    evloop.run_until_complete(server.wait_closed())
    main_logger.info("Server finished its work.")

    # asyncio.Task.all_tasks was removed in Python 3.9; asyncio.all_tasks
    # exists since 3.7. Support both.
    try:
        all_tasks = asyncio.all_tasks
    except AttributeError:
        all_tasks = asyncio.Task.all_tasks
    tasks = list(all_tasks(evloop))
    if tasks:
        for task in tasks:
            task.cancel()
        evloop.run_until_complete(asyncio.wait(tasks))
    evloop.close()


# Start the daemon only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
