#!/usr/bin/env python3

import logging
import os
import socket
import psutil
import sys
import time
import utils.stringcalc
from utils.pre_process_configs import pre_process_configs
import signal
import pwd
import getpass
import select, pty, tty, termios

import utils.parse_configs

from zeroconf import ServiceInfo, Zeroconf
from utils.zc_config import load_config
from netifaces import interfaces, ifaddresses, AF_INET


# PID of the child command currently being supervised; -1 means there is no
# child to signal. Assigned by execute_process()/attach_tty(), reset to -1 by
# the main script just before the final unregister_services() call, and
# checked by unregister_services() before it sends SIGTERM.
child_pid=-1

def report_ids(msg):
    """Print this process's real uid and gid, followed by *msg*, and flush."""
    uid = os.getuid()
    gid = os.getgid()
    line = 'uid, gid = {:d}, {:d}; {}'.format(uid, gid, msg)
    print(line, flush=True)

def attach_tty(command, env=None, quit=True):
    """Run *command* on a fresh pseudo-terminal and relay I/O to it.

    ``sys.stdin`` is switched to raw mode so keystrokes are forwarded
    unchanged to the child. Everything the child writes to the pty is copied
    to ``sys.stdout``. The original terminal attributes are restored before
    returning, even if an error occurs.

    Args:
        command: Command line, split on whitespace into argv (no shell
            quoting is interpreted).
        env: Environment mapping for the child; ``None`` inherits ours.
        quit: When true, ``sys.exit`` with the child's return code if it is
            non-zero.

    Returns:
        The child's return code.

    Side effects:
        Sets the module globals ``child_pid`` and ``child_process``, which
        unregister_services() uses to signal and reap the child.
    """
    global child_pid
    global child_process
    # Save the original tty settings, then switch to raw mode.
    old_tty = termios.tcgetattr(sys.stdin)
    tty.setraw(sys.stdin.fileno())
    # Pseudo-terminal used to interact with the subprocess.
    master_fd, slave_fd = pty.openpty()
    try:
        # os.setsid runs the child in a new session / process group;
        # without it bash job control is not enabled.
        child_process = psutil.Popen(command.split(),
                                     preexec_fn=os.setsid,
                                     stdin=slave_fd,
                                     stdout=slave_fd,
                                     stderr=slave_fd,
                                     universal_newlines=True, env=env)
        child_pid = child_process.pid
        # The child owns its own copy of the slave end. Closing ours lets the
        # master side report EOF/EIO once the child side is gone, instead of
        # blocking forever.
        os.close(slave_fd)
        slave_fd = -1
        while child_process.poll() is None:
            # Short timeout so a child exit is noticed even with no I/O.
            readable, _, _ = select.select([sys.stdin, master_fd], [], [], 0.5)
            if sys.stdin in readable:
                data = os.read(sys.stdin.fileno(), 32)
                os.write(master_fd, data)
            elif master_fd in readable:
                try:
                    output = os.read(master_fd, 32)
                except OSError:
                    # Linux raises EIO on the master once the slave is closed.
                    break
                if not output:
                    break
                os.write(sys.stdout.fileno(), output)
        child_process.wait()
    finally:
        if slave_fd != -1:
            os.close(slave_fd)
        os.close(master_fd)
        # Restore the tty; otherwise all later output is printed in raw mode.
        termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)

    returncode = child_process.returncode
    if returncode:
        print("retcode=", returncode)
    if returncode and quit:
        sys.exit(returncode)
    return returncode

def execute_process(process_command, env=None, quit=True):
    """Run *process_command* to completion and return its exit status.

    Args:
        process_command: Command line, split on whitespace into argv (no
            shell quoting is interpreted; runs of spaces no longer produce
            empty arguments).
        env: Extra variables merged over a copy of ``os.environ``. When
            falsy, the child inherits the current environment unchanged.
        quit: When true, ``sys.exit`` with the child's return code if it is
            non-zero.

    Returns:
        The child's return code.

    Side effects:
        Sets the module globals ``child_pid`` and ``child_process``, which
        unregister_services() uses to signal and reap the child.
    """
    global child_pid
    global child_process
    if env:
        merged_env = os.environ.copy()
        merged_env.update(env)
        env = merged_env

    print("Starting child process command {}".format(process_command), flush=True)
    report_ids("")
    child_process = psutil.Popen(process_command.split(), env=env)
    child_pid = child_process.pid
    child_process.wait()
    returncode = child_process.returncode
    if returncode:
        print("retcode=", returncode)
    if returncode and quit:
        sys.exit(returncode)
    return returncode

# Interface-name prefixes whose IPv4 addresses are never advertised.
# Matching is by prefix, so "lo" also excludes any interface whose name
# merely starts with "lo".
prohibited_ifaces = ["docker", "resin-dns", "balena", "supervisor", "lo"]


def ip4_addresses():
    """Return the IPv4 addresses of every non-prohibited network interface.

    An interface is skipped when its name starts with any entry of
    ``prohibited_ifaces``. Results keep the order of ``interfaces()`` and,
    within an interface, the order of its AF_INET address list.
    """
    excluded_prefixes = tuple(prohibited_ifaces)
    return [
        entry['addr']
        for name in interfaces()
        if not name.startswith(excluded_prefixes)
        for entry in ifaddresses(name).get(AF_INET, [])
    ]

def get_service_type_and_name(service_name, service, domain_name):
    """Build the mDNS service type, instance name and transport label.

    Args:
        service_name: Service label, e.g. ``"http"``.
        service: Service config mapping. An optional ``"transport"`` entry
            is prefixed with ``_``; the default transport is ``"_tcp"``.
        domain_name: Domain label prefixed to the instance name.

    Returns:
        ``(full_type, full_name, transport)``, for example
        ``("_http._tcp.local.", "_dom._http._tcp.local.", "_tcp")``.
    """
    if 'transport' in service:
        transport = "_" + service['transport']
    else:
        transport = "_tcp"
    full_type = "_%s.%s.local." % (service_name, transport)
    full_name = "_%s.%s" % (domain_name, full_type)
    return full_type, full_name, transport


def unregister_services(exit_code=0):
    """Withdraw all mDNS advertisements, stop the child process tree and exit.

    Used both for normal shutdown and from the SIGINT/SIGTERM handler, so
    every module global it touches may not exist yet.

    Args:
        exit_code: Status passed to ``sys.exit``; 0 is reported as "clean".

    Raises:
        SystemExit: Always, with *exit_code*.
    """
    print("Unregistering all services ...", flush=True)
    # The signal handler can fire before the module-level Zeroconf instance
    # ``r`` exists; there is nothing to unregister in that case.
    zeroconf_instance = globals().get("r")
    if zeroconf_instance is not None:
        zeroconf_instance.unregister_all_services()
        zeroconf_instance.close()
    # Parenthesized so the prefix is kept for non-zero exit codes as well.
    status = 'clean' if not exit_code else 'with exit code {}'.format(exit_code)
    print("Zeroconf process finished " + status, flush=True)

    process = globals().get("child_process")
    if process is None:
        # No child was ever started.
        sys.exit(exit_code)

    if globals().get("child_pid", -1) != -1:
        try:
            for child in process.children(recursive=True):
                child.send_signal(signal.SIGTERM)
            process.send_signal(signal.SIGTERM)
        except Exception as error:
            print("Failed to send TERM signal to child process, response: {}".format(error), flush=True)
    # Best-effort reaping: processes that have already exited raise errors,
    # which are expected here and deliberately ignored.
    try:
        for child in process.children(recursive=True):
            try:
                child.wait()
            except Exception:
                pass
        process.wait()
    except Exception:
        pass
    sys.exit(exit_code)

class GracefulKiller:
    """Route SIGINT and SIGTERM to a clean shutdown via unregister_services().

    Instantiating the class replaces the process-wide handlers for both
    signals.
    """

    def __init__(self):
        for signum in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signum, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        """Signal handler: call unregister_services() with its default code.

        *signum* and *frame* are required by the ``signal`` API and unused.
        """
        unregister_services()


if __name__ == '__main__':
    # SKIP_CONF: run the wrapped command directly, without any mDNS handling.
    if os.environ.get("SKIP_CONF"):
        execute_process(sys.argv[2], env=os.environ)
        sys.exit(0)
    if not os.environ.get("SKIP_AVAHI", None):
        # Strip login/sudo-related variables before anything is launched;
        # processes started below inherit os.environ.
        added_keys = ['USERNAME', 'SUDO_COMMAND', 'TERM', 'SHELL', 'SUDO_UID', 'SUDO_GID', 'LOGNAME', 'USER', 'SUDO_USER', 'MAIL']
        for key in added_keys:
            os.environ.pop(key, None)
        print("Running as {}, launch environment {}\n".format(getpass.getuser(), os.environ), flush=True)
        print("Starting MDNS resolution services...", flush=True)
        try:
            execute_process("sudo service dbus start")
            execute_process("sudo service avahi-daemon start")
        except Exception as error:
            print("Failed to start system MDNS resolving services Avahi/DBus:\n{}".format(error), flush=True)

    # From here on SIGINT/SIGTERM unregister services and stop the child.
    killer = GracefulKiller()
    print("Pre-processing configuration files and and copying them to target location...", flush=True)
    pre_process_configs()
    params = {}
    logging.basicConfig(level=logging.DEBUG)
    if len(sys.argv) > 3:
        assert sys.argv[3:] == ['--debug']
        logging.getLogger('zeroconf').setLevel(logging.DEBUG)

    print("Checking MDNS environment variables...", flush=True)
    container_name = os.environ['MDNS_HOSTNAME']
    domain_name = os.environ['MDNS_DOMAIN']
    print("Getting system IP addresses...", flush=True)
    ip_addresses = ip4_addresses()

    print("Zeroconf - Multicast DNS Service Discovery Utility started", flush=True)
    print("Container MDNS host is {}.{}".format(container_name, domain_name), flush=True)
    print("System IP addresses are: {}".format(ip_addresses), flush=True)
    r = Zeroconf()

    config_path = sys.argv[1] if len(sys.argv) > 1 else "zeroconf.yaml"
    with open(config_path) as config_file:
        zc_conf = load_config(config_file)
    # Empty YAML sections load as None, so fall back to {} for them.
    common_conf = zc_conf.get("common") or {}
    services = zc_conf.get('services') or {}
    resolve_timeout = common_conf.get('resolve-timeout', None) or 900

    print("Registering container name resolution service", flush=True)
    host_service_type = "_host._tcp.local."
    host_service_address = "_{container}._{domain}._host._tcp.local.".format(container=container_name, domain=domain_name)
    info = ServiceInfo(
        host_service_type,
        host_service_address,
        addresses=[socket.inet_aton(addr) for addr in ip_addresses],
        port=0
    )
    params['host'] = host_service_address.strip(".")
    r.register_service(info)

    for service_name, service in services.items():
        if service_name == 'host':
            continue
        print("Service {} with {} allocation started".format(service_name, service), flush=True)
        max_port_tries = service['max-tries']
        port_step = service['port-step']
        full_type, full_name, transport = get_service_type_and_name(service_name, service, domain_name)
        socket_type = socket.SOCK_STREAM if transport == "_tcp" else socket.SOCK_DGRAM
        port = service['port']
        tries = 0
        # Probe ports starting at service['port'], advancing by port-step,
        # until one can be bound.
        while True:
            sock = None
            try:
                sock = socket.socket(socket.AF_INET, socket_type)
                sock.bind(('0.0.0.0', port))
            except Exception:
                # Do not leak the descriptor of a socket that failed to bind.
                if sock is not None:
                    sock.close()
                print("Port {} allocation failed, trying next...".format(port), flush=True)
                tries += 1
                port += port_step
                if tries > max_port_tries:
                    print("FATAL: Maximal port allocation tries failed, exiting", flush=True)
                    sys.exit(1)
                continue
            if 'socket-keepalive' not in service:
                sock.close()
            else:
                # Hold the port until the configured release point
                # ('start' or 'finish' of the main process, see below).
                service['allocated-socket'] = sock
            print("Port {} allocation successfull, release strategy {}".format(port, service.get('socket-keepalive', "None")), flush=True)
            break

        print("Registration of a service {} started".format(service_name), flush=True)
        desc = service.get('description', {})
        info = ServiceInfo(
            full_type,
            full_name,
            addresses=[socket.inet_aton(addr) for addr in ip_addresses],
            port=port,
            properties=desc,
        )
        print("   Registering service...", flush=True)
        r.register_service(info)
        print("   Registration done.", flush=True)

    print("All services registered and advertized", flush=True)
    if 'require' in zc_conf:
        print("Discovering necessary services, timeout {} seconds".format(resolve_timeout), flush=True)
        required = zc_conf['require'] or {}
        not_resolved = list(required.keys())
        tries = 0
        while not_resolved and tries < resolve_timeout:
            for service_name in list(not_resolved):
                service = required[service_name]
                full_type, full_name, transport = get_service_type_and_name(service_name, service, domain_name)
                print("\tQueriering service {}: looking for MDNS name {}".format(service_name, full_type), flush=True)
                queried_info = r.get_service_info(full_type, full_name)
                if queried_info:
                    not_resolved.remove(service_name)
                    if os.environ.get("RESOLVE_NAMES"):
                        host = '{}.{}.{}.{}'.format(*bytearray((queried_info.addresses[0])))
                    else:
                        host = full_name.strip(".")
                    params["{}_host".format(service_name)] = host
                    params["{}_port".format(service_name)] = queried_info.port
            print("Still not resolved: {}".format(not_resolved), flush=True)
            if not_resolved:
                time.sleep(1)
            tries += 1
        # Judge by what is still missing, not by the attempt counter: the
        # last permitted attempt may have resolved everything.
        if not_resolved:
            print("FATAL: Required services {} was not discovered, exiting".format(not_resolved), flush=True)
            sys.exit(1)
        # NOTE(review): config files are only parsed when a 'require' section
        # exists — confirm this is intended.
        print("Parsing configuration files....", flush=True)
        # Only the package name ``utils`` is bound by ``import utils.parse_configs``.
        configs_processed = utils.parse_configs.parse_configs(params, zc_conf)
        if configs_processed:
            os.environ['ZEROCONF_CFS'] = ",".join(configs_processed)
    if 'set-environment' in zc_conf:
        print("Preparing environment variables...", flush=True)
        os.environ.update({var_name: str(utils.parse_configs.process_mask(var_mask, params)) for var_name, var_mask in zc_conf['set-environment'].items()})
    print("Zeroconf setup done, launching main process...", flush=True)

    print("Child process environment: {}\n".format(os.environ), flush=True)
    for service_name, service in services.items():
        if service_name == 'host':
            continue
        if service.get('socket-keepalive', None) == 'start' and 'allocated-socket' in service:
            print("Releasing socket for service {}".format(service_name), flush=True)
            service['allocated-socket'].close()
    tty_needed = common_conf.get('tty', None)
    print("Starting process{}".format(" with tty" if tty_needed else ""), flush=True)
    # quit=False: a failing child must not exit before services are
    # unregistered; unregister_services() exits with the same code.
    if tty_needed:
        exit_code = attach_tty(sys.argv[2], os.environ, quit=False)
    else:
        exit_code = execute_process(sys.argv[2], env=os.environ, quit=False)
    for service_name, service in services.items():
        if service_name == 'host':
            continue
        if service.get('socket-keepalive', None) == 'finish' and 'allocated-socket' in service:
            print("Releasing socket for service {}".format(service_name), flush=True)
            service['allocated-socket'].close()
    child_pid = -1
    unregister_services(exit_code)


