#!/usr/bin/env python2

import os
import sys
import signal
import time

import ArgumentParser

from socket_gatekeeper.Listener import Listener
from socket_gatekeeper.Handler import DEFAULT_CLIENT_BUFFER_LEN, DEFAULT_ENDPOINT_BUFFER_LEN, handlerFilterQuit

from socket_gatekeeper.MappingsParser import MappingsFileParser, ParseMappingException


def printUsage():
    '''
        printUsage - Write the full usage/help text to stderr.

        The text is filled in with the program name (sys.argv[0]) and the two
        default buffer lengths imported from socket_gatekeeper.Handler.

        Returns nothing and does not exit; callers decide the exit status.
    '''
    # The option spellings below match the long names registered with the
    # argument parser in the __main__ section ('client-buffer', 'endpoint-buffer').
    #
    # NOTE(review): `echo -n "abc" | sha256sum` yields
    #   ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
    # which is not the digest shown in the text below (that one looks like the
    # digest of "abc" followed by a newline). Confirm whether the handler hashes
    # the trailing newline, then correct either the command or the digest.
    usageText = '''Usage: %s [arguments]
Starts a socket gatekeeper instance. Description of application at bottom.

 Required:

     --mappings=/path/to/file
        or
     -m /path/to/file                 Path to the mappings file. See MAPPING FORMAT below for details.

     --bind=addr:port
        or
     -b addr:port                     Listen on this interface and port.
                                       Example, listen on all interfaces port 51000:
                                       --bind=0.0.0.0:51000


 Options:

      --help                          Show this message and quit
      --client-buffer=X               Use X bytes max buffer in data to/from the client. Defaults to %d
      --endpoint-buffer=X             Use X bytes max buffer in data to/from the endpoint. Defaults to %d
      --enable-quit                   Enable intercepting the messages "quit" and "exit" to terminate connection.



DESCRIPTION
-----------

Socket gatekeeper instance listens on a local interface/port. Any connections are silently
prompted for a password. That password is read, hashed, and checked against a provided database of
one-way encrypted passwords. Each password specifies an endpoint address and port.

You can use this to open a single port on a machine which routes privileged users to their
appropriate services. You can also use this as a "front door" for services that don't
support authentication out-of-the-box to add a layer of security.


MAPPING FORMAT
--------------

The mapping file is defined like this:

sha256sum_password=ADDR:port

You can use a procedure like the following to generate a sha256sum of your password:


Example, for password "abc" to resolve to localhost port 80:

echo -n "abc" | sha256sum | awk {'print $1'}

Returns edeaaff3f1774ad2888673770c6d64097e391bc362d7d6fb34982ddf0efd18cb

So the mapping file would contain

edeaaff3f1774ad2888673770c6d64097e391bc362d7d6fb34982ddf0efd18cb = 0.0.0.0:80

You can have several mappings in the same file.
You can have duplicates of the endpoints, but you can not have duplicate passwords.

''' % (sys.argv[0], DEFAULT_CLIENT_BUFFER_LEN, DEFAULT_ENDPOINT_BUFFER_LEN)

    sys.stderr.write(usageText)

def errorUsageAndExit(msg):
    '''
        errorUsageAndExit - Report an error on stderr, print the usage text,
          and terminate with exit status 1.

        @param msg <str> - Error text; written to stderr followed by a blank line.

        Does not return: always ends by raising SystemExit(1).
    '''
    sys.stderr.writelines((msg, '\n\n'))

    # Pause before the usage text -- presumably so the error can be read
    # before the long help output scrolls past it.
    time.sleep(1.5)

    printUsage()
    raise SystemExit(1)

def parseBindParam(bindValue):
    '''
        parseBindParam - Split an "addr:port" bind parameter into its parts.

        @param bindValue <str> - Value given to --bind / -b, e.g. "0.0.0.0:51000"

        @return tuple<str, int> - (address, port)

        @raises ValueError - If the value is not exactly two ':'-separated parts
          with an all-digit port. The message names the offending value.
    '''
    bindValueSplit = bindValue.split(':')
    if len(bindValueSplit) != 2 or not bindValueSplit[1].isdigit():
        raise ValueError('Invalid bind param: "%s".\nMust be in format Addr:port like 0.0.0.0:51000' % (bindValue,))
    return (bindValueSplit[0], int(bindValueSplit[1]))


if __name__ == '__main__':
    # Entry point: parse and validate the command line, load the mappings
    # file, then start the listener.

    # The first three tuples are parallel (result key, short option, long option).
    # NOTE(review): positional meaning is inferred from how the results are read
    # below and from the usage text -- confirm against the ArgumentParser package.
    parser = ArgumentParser.ArgumentParser(
        ('mappingsFilename', 'clientBufferLen', 'endpointBufferLen', 'bind' ),
        ('m', None, None, 'b' ),
        ('mappings', 'client-buffer', 'endpoint-buffer', 'bind' ),
        ['--help', '--enable-quit'],
        {},
        False
    )
    parseResults = parser.parse(sys.argv[1:])
    args = parseResults['result']

    # --help is honoured before parse errors are reported.
    if args['--help'] is True:
        printUsage()
        sys.exit(0)

    if parseResults['errors']:
        errorUsageAndExit('Error in arguments:\n%s' %('\n'.join(parseResults['errors']), ))

    # Optional buffer-length overrides; None is passed to the Listener when
    # the option was not given.
    if 'clientBufferLen' in args:
        try:
            overrideClientBufferLen = int(args['clientBufferLen'])
            if overrideClientBufferLen <= 0:
                raise ValueError
        except ValueError:
            errorUsageAndExit('Client buffer length must be an integer > 0.')
    else:
        overrideClientBufferLen = None

    if 'endpointBufferLen' in args:
        try:
            overrideEndpointBufferLen = int(args['endpointBufferLen'])
            if overrideEndpointBufferLen <= 0:
                raise ValueError
        except ValueError:
            errorUsageAndExit('Endpoint buffer length must be an integer > 0.')
    else:
        overrideEndpointBufferLen = None

    # Required: bind address and port. The short option for bind is -b
    # (-m is the mappings file).
    if 'bind' not in args:
        errorUsageAndExit('A bind param is required. Specify one with --bind=addr:port or -b addr:port.\nEx: 0.0.0.0:51000')

    try:
        bindAddr, bindPort = parseBindParam(args['bind'])
    except ValueError as e:
        sys.stderr.write('%s\n\n' % (str(e),))
        sys.exit(1)

    # Required: mappings file, which must exist and be a regular file.
    if 'mappingsFilename' not in args:
        errorUsageAndExit('A mappings file is required. Specify one with --mappings=/path/to/file or -m /path/to/file')

    mappingsFilename = args['mappingsFilename']

    if not mappingsFilename or not os.path.isfile(mappingsFilename):
        sys.stderr.write('Cannot find specified mappings file: "%s" or is not a file.\n Check the path and try again.\n\n' % (mappingsFilename,))
        sys.exit(1)

    mappingsParser = MappingsFileParser(mappingsFilename)

    try:
        mappings = mappingsParser.getMappings()
    except ParseMappingException as e:
        sys.stderr.write('Error parsing mappings file. See --help for format information: \n%s\n\n' %(str(e),))
        sys.exit(1)

    listener = Listener(bindAddr, bindPort, mappings, overrideClientBufferLen, overrideEndpointBufferLen)

    # With --enable-quit, every handler gets the quit filter added to its
    # incoming filters.
    if args['--enable-quit'] is True:
        listener.setApplyFiltersToHandlerFunction( lambda handler, sha256, mapping : handler.addIncomingFilter(handlerFilterQuit) )

    listener.start()

    # Set by the first handled termination signal so that repeat signals
    # during shutdown are ignored.
    globalIsTerminating = False

    def handleSigTerm(*args):
        '''
            handleSigTerm - Signal handler that shuts down the listener and then
              terminates this process.

            Sends SIGTERM to the listener, waits for it to exit, escalates to
            SIGKILL if it is still alive, then restores the default signal
            dispositions and re-sends SIGTERM to this process.

            @param args - Handler arguments supplied by the signal module; unused.

            NOTE(review): relies on listener exposing pid / join() / is_alive()
            (looks like a multiprocessing.Process-style object -- confirm).
        '''
        global listener
        global globalIsTerminating
        if globalIsTerminating is True:
            return # Already terminating
        globalIsTerminating = True
        sys.stdout.write('Caught signal, shutting down listeners...\n')
        try:
            os.kill(listener.pid, signal.SIGTERM)
        except Exception:
            # Best effort: the listener may already be gone.
            pass
        sys.stderr.write('Sent signal to children, waiting up to 4 seconds then trying to clean up\n')
        time.sleep(1)
        listener.join(4)

        if listener.is_alive():
            # Graceful shutdown did not finish in time; force it.
            try:
                os.kill(listener.pid, signal.SIGKILL)
            except Exception:
                # Best effort: it may have exited between the check and the kill.
                pass
            time.sleep(.1)
            sys.stderr.write('Starting final join\n')
            sys.stderr.flush()
            listener.join()
            sys.stderr.write('Done\n')
            sys.stderr.flush()
        # NOTE(review): unconditional grace period before exiting -- presumably
        # to let remaining child processes finish; confirm it is still needed.
        time.sleep(5)
        # Restore default handlers so the SIGTERM below actually terminates
        # this process instead of re-entering this function.
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        os.kill(os.getpid(), signal.SIGTERM)
        return 0
        

    # Route both termination signals through the same shutdown handler.
    for shutdownSignal in (signal.SIGTERM, signal.SIGINT):
        signal.signal(shutdownSignal, handleSigTerm)

    # Idle in the parent process. Anything that interrupts the sleep with an
    # exception is turned into a SIGTERM to this process, which runs the
    # handler registered above (hence the deliberately broad except).
    while True:
        try:
            time.sleep(2)
        except:
            os.kill(os.getpid(), signal.SIGTERM)

# vim: ts=4 sw=4 expandtab
