#!/usr/bin/env python

# Copyright (c) 2011. All Right Reserved, http://chart.io/
#
# THIS CODE AND INFORMATION ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY 
# KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A
# PARTICULAR PURPOSE.

'''SSH process connection providing a remote tunnel to a database.

    With limited support for restarting the SSH process.

    Basically, runs
        ssh -N -R <remote_port>:localhost:<local_port> <user>@<host>
    e.g. ssh -N -R 12345:localhost:3306 user@amz123.chartio.com

    Designed for use on POSIX systems only (signals, forking, etc.).

'''

import ConfigParser
import logging
import logging.handlers
import optparse
import os
import signal
import socket
import subprocess
import sys
import time

# Shared module state. logger is configured by logging_init() and
# CONFIG_PATHS is set by config_paths_set(); both start out unset.
logger = CONFIG_PATHS = None
# Default install prefix; "~" is expanded by ConfigPaths when it is used.
PREFIX_DEFAULT = '~/.chartio.d'

def config_paths_set(prefix):
    '''Build a ConfigPaths for *prefix* and publish it as CONFIG_PATHS.

    Arguments
        prefix -- install prefix, or None to fall back to PREFIX_DEFAULT

    '''
    global CONFIG_PATHS
    paths = ConfigPaths(prefix)
    CONFIG_PATHS = paths

class ConfigPaths(object):
    '''Filesystem locations used by chartio_connect, derived from one prefix.

    Yuck. Global config class

    Double Yuck. Needs to be kept in sync with chartio_setup to limit
    PYTHONPATH dependencies.

    Constructing an instance exits the process (status 1) if any of the
    required directories is missing.

    '''
    def __init__(self, prefix=None):
        '''Compute every path under *prefix* and verify the directories exist.

        Arguments
            prefix -- install prefix; a falsy value means PREFIX_DEFAULT
                      with "~" expanded

        '''
        # Directories
        self.PREFIX = prefix or os.path.expanduser(PREFIX_DEFAULT)
        self.LOG_DIRECTORY = os.path.join(self.PREFIX, 'logs')
        self.RUN_DIRECTORY = os.path.join(self.PREFIX, 'run')
        self.SSH_DIRECTORY = os.path.join(self.PREFIX, 'sshkey')
        # Files
        self.CONFIG_FILE = os.path.join(self.PREFIX, 'chartio.cfg')
        self.LOG_FILE = os.path.join(self.LOG_DIRECTORY, 'chartio_connect.log')
        self.PID_FILE = os.path.join(self.RUN_DIRECTORY, 'chartio_connect.pid')
        self.SSH_KEY = os.path.join(self.SSH_DIRECTORY, 'id_rsa')
        # Config file sections
        self.SSHTUNNEL_SECTION = 'SSHTunnel'
        # Ensure required directories exist. Use an explicit loop rather than
        # map(): map() is lazy on Python 3, so the checks would silently never
        # run there.
        for directory in (self.PREFIX, self.LOG_DIRECTORY,
                          self.RUN_DIRECTORY, self.SSH_DIRECTORY):
            self.dir_exists(directory)

    def dir_exists(self, dir):
        '''Determine whether a directory exists. Exit gracefully if it does not.

        Arguments
            dir -- the directory path to check

        Writes an explanation to stderr and calls sys.exit(1) when *dir* is
        not an existing directory; returns None otherwise.

        '''
        if not os.path.isdir(dir):
            sys.stderr.write('The directory %r does not exist.\n'
                             'Please run chartio_setup or rerun chartio_connect and specify --prefix\n'
                             'Exiting\n' % (dir))
            sys.exit(1)


class SSHTunnel(object):
    '''Run and supervise an ``ssh -N -R`` reverse-tunnel subprocess.

    Port ``remote_port`` on the remote host is forwarded back to
    ``local_port`` on this machine. When the ssh child exits, the SIGCHLD
    handler restarts it with exponential backoff and gives up after RETRIES
    consecutive failures. The daemon's pid is stored in CONFIG_PATHS.PID_FILE
    so that a later invocation can stop it.

    '''
    # Number of consecutive times to retry before exiting
    RETRIES = 15
    # Initial number of seconds between retries. Doubles for backoff.
    RETRY_DELAY = 10
    # Seconds between pings to keep tunnel alive
    PING_DELAY = 10

    def __init__(self, remoteuser, remotehost, remote_port, local_port):
        """ Typical values:
                host: amz123.chartio.com
                remote_port: 12345 (some big number)
                local_port: 3306 (mysql/pg port)

            remotehost is resolved to an IP address here, so this raises
            socket.error if it cannot be resolved. The ports are converted
            with int(), so non-numeric values raise ValueError.
        """
        self.host = socket.gethostbyname(remotehost)
        self.user = remoteuser
        self.remote_port = int(remote_port)
        self.local_port = int(local_port)

        self._pid = None
        self._ssh_process = None
        self._retries = self.RETRIES
        self._shutdown = False
        self._retry_delay = self.RETRY_DELAY

    # Store PID in file, get and set as property.
    # This lets a parent process set the pid and the child load it in.
    def _get_pid(self):
        '''Return the daemon pid, reading it from the pid file on first use.'''
        if self._pid is None:
            f = open(CONFIG_PATHS.PID_FILE)
            try:
                self._pid = int(f.read())
            finally:
                f.close()
        return self._pid

    def _set_pid(self, pid):
        '''Write *pid* to the pid file and cache it.'''
        f = open(CONFIG_PATHS.PID_FILE, 'w')
        try:
            f.write(str(pid))
        finally:
            f.close()
        self._pid = pid

    def _del_pid(self):
        '''Truncate the pid file; kill() treats an empty file as "not running".'''
        logger.info("Clearing pid file")
        try:
            open(CONFIG_PATHS.PID_FILE, 'w').close()
        except (IOError, OSError):
            # Best effort: failing to clear the pid file must not block shutdown.
            pass

    pid = property(_get_pid, _set_pid, _del_pid)

    def _sig_chld(self, signum, frame):
        """ Handle subprocess dying: reap it, back off, then reconnect """
        if self._shutdown:
            return
        # Reap the dead ssh child (non-blocking) so it does not linger as a
        # zombie before being replaced below.
        if self._ssh_process is not None:
            self._ssh_process.poll()
        logger.warning("subprocess died, about to restart")
        if self._retries == 0:
            logger.error("Number of retries has been exceeded, exiting")
            sys.exit(1)
        logger.info("waiting %s seconds to retry connection" %
                self._retry_delay)
        # If no further SIGCHLD re-arms this alarm within ~120s of the
        # reconnect, _sig_alarm assumes the tunnel is healthy and resets
        # the backoff.
        signal.signal(signal.SIGALRM, self._sig_alarm)
        signal.alarm(self._retry_delay + 120)
        time.sleep(self._retry_delay)
        self._retries -= 1
        self._retry_delay *= 2
        self._make_connection()

    def _sig_alarm(self, signum, frame):
        """ Since we reached this, connection likely succeeded """
        logger.info("Connection appears to be working, resetting retries")
        self._retries = self.RETRIES
        self._retry_delay = self.RETRY_DELAY

    def _make_connection(self, debug=False):
        """ Create subprocess SSH connection.

            With debug=True, return the command line as a string instead of
            running it.
        """
        logger.info("Making ssh tunnel connection")
        args = ['ssh',
            '-N', # No command, just forwarding
            '-R', # Remote forward
            '%s:localhost:%s' % (self.remote_port, self.local_port),
            '%s@%s' % (self.user, self.host),
            '-g', # Allow remote connect
            '-i', CONFIG_PATHS.SSH_KEY,
            # Keep alive. If fails, disconnect and let this script reconnect
            '-o', 'ServerAliveInterval=%s' % self.PING_DELAY,
            '-o', 'ServerAliveCountMax=1',
            # Ignore known_hosts_file feedback
            '-o', 'UserKnownHostsFile=/dev/null',
            '-o', 'StrictHostKeyChecking=no',
        ]
        if debug:
            return " ".join(args)

        signal.signal(signal.SIGCHLD, self._sig_chld)
        # XXX write stdout?
        self._ssh_process = subprocess.Popen(args,
                stdout=None, stderr=subprocess.STDOUT)

    def daemonize(self):
        """ Fork into the background. Return True in the daemon, otherwise exit.

            The parent records the child's pid in the pid file and exits
            with status 0. The child starts a new session and points fds
            0-2 at /dev/null. If fork fails, log the error and exit with
            status 1.
        """
        try:
            pid = os.fork()
        except OSError:
            # os.fork() reports failure by raising, never by returning < 0.
            logger.error("Error forking, cannot daemonize.")
            sys.exit(1)
        if pid > 0:
            print("Chartio connect daemonized at process %d" % pid)
            logger.info("Daemon process running: %d" % pid)
            self.pid = pid
            sys.exit(0)
        # Child: set session id to disconnect from tty
        os.setsid()
        # Point fds 0,1,2 at /dev/null instead of merely closing them.
        # Otherwise a file opened later would be assigned one of those low
        # descriptors and receive whatever is written to stdout/stderr.
        devnull = os.open(os.devnull, os.O_RDWR)
        for fd in range(3):
            os.dup2(devnull, fd)
        if devnull > 2:
            os.close(devnull)
        return True

    def cleanup(self, signum, frame):
        """ SIGTERM/SIGINT handler: kill ssh, clear the pid file, exit(0) """
        self._shutdown = True
        logger.info("Running cleanup")
        # Turn off special sigchld handling so killing ssh does not trigger
        # a reconnect.
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)
        process = self._ssh_process
        if process is not None and process.poll() is None:
            logger.debug('killing process')
            try:
                os.kill(process.pid, signal.SIGKILL)
                # Reap it so the kill has taken effect before we exit.
                process.wait()
            except OSError:
                # ssh already exited and was reaped elsewhere; nothing to kill.
                pass
        del self.pid
        logging.shutdown()
        sys.exit(0)

    @classmethod
    def kill(cls, silent=False):
        """ Try to stop a running daemon by sending SIGINT to the pid on file.

            An empty pid file means nothing is running. Progress messages are
            printed unless *silent* is true.
        """
        if not os.path.exists(CONFIG_PATHS.PID_FILE):
            if not silent:
                print("No pid file found, cannot kill daemon")
            return
        if not silent:
            print("Trying to stop previous running instance")
        try:
            f = open(CONFIG_PATHS.PID_FILE)
            try:
                pid = f.read()
            finally:
                f.close()
            if not pid.strip():
                return
            pid = int(pid)
            if not silent:
                print("Stopping daemon at %s" % pid)
            os.kill(pid, signal.SIGINT)
            # Give the daemon's cleanup() time to run.
            time.sleep(1)
        except (OSError, IOError, ValueError):
            # No such process, unreadable pid file, or garbage in it.
            if not silent:
                print("Previous instance probably isn't running anymore")

    def main(self, daemonize=True):
        """ Start the tunnel (optionally daemonized) and block forever.

            All further work happens in signal handlers: SIGCHLD reconnects,
            SIGALRM resets the backoff, SIGTERM/SIGINT exit via cleanup().
        """
        logger.info("Starting up ssh tunnel")
        if daemonize:
            logger.info("daemonized mode")
            self.daemonize()
        else:
            logger.info("non-daemonized mode")
            print("Starting in non-daemon mode. Press ctrl-c to stop.")

        # Install exit handlers before spawning ssh so a signal arriving
        # during startup still clears the pid file (cleanup tolerates a
        # missing process).
        signal.signal(signal.SIGTERM, self.cleanup)
        signal.signal(signal.SIGINT, self.cleanup)

        self._make_connection()

        # Wait for signals; cleanup() is the only way out of this loop.
        while True:
            signal.pause()


def opt_args_gather():
    '''Parse options and arguments.

    The optional positional argument is the action passed to main():
    start (default), restart, stop or debug. It is now shown in --help.

    Return
     (optparse options object, argument list)

    '''
    parser = optparse.OptionParser(
        usage='%prog [options] [start|restart|stop|debug]')
    parser.add_option('-d', '--daemonize', dest='daemonize', action="store_true", default=False,
                      help='Disassociate from the terminal and run in the background')
    parser.add_option('--prefix',
                      help=('Prefix argument (if any) used during chartio_setup configuration.'
                            ' Specifies where to find/place run-time data.'
                            ' Defaults to %r' % (PREFIX_DEFAULT)))
    return parser.parse_args()


def logging_init():
    '''Configure and start logging.

    Sets the module-level ``logger`` to the 'ssh_tunnel' logger at DEBUG level,
    writing to CONFIG_PATHS.LOG_FILE with size-based rotation.

    '''
    global logger
    logger = logging.getLogger('ssh_tunnel')
    fmt = logging.Formatter('%(levelname)s @ %(asctime)s %(filename)s:%(lineno)d %(message)s')
    # Rotate log at 1 MB, keeping one old file. RotatingFileHandler never
    # rolls over when backupCount is 0 (its default), which let the log grow
    # without bound.
    log_handler = logging.handlers.RotatingFileHandler(
        CONFIG_PATHS.LOG_FILE, maxBytes=2**20, backupCount=1)
    log_handler.setFormatter(fmt)
    logger.addHandler(log_handler)
    logger.setLevel(logging.DEBUG)


def main():
    '''Command-line entry point.

    Actions (first positional argument, default "start"):
      start/restart -- stop any previous instance, then run the tunnel
      stop          -- signal the running daemon to exit
      debug         -- print the ssh command that would be run
    Exits with status 1 on an unknown action or a configuration problem.

    '''
    options, args = opt_args_gather()
    config_paths_set(options.prefix)

    action = 'start'
    if args:
        action = args[0]
    if action not in ('start', 'restart', 'stop', 'debug'):
        print('Incorrect usage')
        sys.exit(1)

    if action == 'stop':
        # kill() only needs the pid file. Skip reading the config and the DNS
        # lookup done by SSHTunnel() so stopping works even when those fail.
        SSHTunnel.kill()
        return

    # Get configuration
    config_file = CONFIG_PATHS.CONFIG_FILE
    if not os.path.isfile(config_file):
        sys.stderr.write('Config file not found: %s.\n' % (config_file))
        sys.stderr.write('Please re-run installation of chartio or specify a different prefix.\n')
        sys.stderr.write('Exiting.\n')
        sys.exit(1)
    parser = ConfigParser.ConfigParser()
    parser.read(config_file)
    try:
        conf = dict(parser.items(CONFIG_PATHS.SSHTUNNEL_SECTION))
    except ConfigParser.NoSectionError:
        print('Config file appears empty. Please run chartio_setup')
        sys.exit(1)

    logging_init()

    try:
        tunnel = SSHTunnel(conf['remoteuser'], conf['remotehost'],
                conf['remoteport'], conf['localport'])
    except KeyError:
        sys.stderr.write('Config file is missing option %s. Please run chartio_setup\n'
                         % (sys.exc_info()[1],))
        sys.exit(1)
    except socket.error:
        # Raised by the gethostbyname() lookup in SSHTunnel.__init__.
        sys.stderr.write('Could not resolve remote host %r.\n' % (conf['remotehost'],))
        sys.exit(1)

    if action == 'debug':
        print("This program would have run the following command:")
        print(tunnel._make_connection(debug=True))
    else:
        # start / restart
        tunnel.kill(silent=True)
        tunnel.main(options.daemonize)


if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script,
    # not when imported as a module.
    main()
