#!/usr/bin/python

# Copyright (c) 2011. All Right Reserved, http://chart.io/
#
# THIS CODE AND INFORMATION ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY 
# KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A
# PARTICULAR PURPOSE.

""" SSH tunnel, to be run by client to tunnel mysql (initially) over ssh,
    and keep connection alive.

    Basically, runs:
    ssh -v -R 3603:localhost:12345 

    For use on unix-y systems only (signals, forking, etc)
"""

import ConfigParser, logging, logging.handlers, os, socket, signal, \
        subprocess, sys, time

# Setup logging: all records (DEBUG and up) go to a size-capped log file.
logger = logging.getLogger('ssh_tunnel')
fmt = logging.Formatter("%(levelname)s @ %(asctime)s %(filename)s:%(lineno)d %(message)s")
# RotatingFileHandler only rotates when backupCount > 0; with the default of 0
# the file is never rolled over and grows without bound. Keeping one backup
# caps disk use at roughly 2 MB (current file + chartio_connect.log.1).
log_handler = logging.handlers.RotatingFileHandler('/var/log/chartio_connect.log',
                maxBytes=2**20, # Roll at 1 MB
                backupCount=1)
log_handler.setFormatter(fmt)
logger.addHandler(log_handler)
logger.setLevel(logging.DEBUG)


class SSHTunnel(object):
    """ Maintain a reverse SSH port forward, restarting ssh when it dies.

        The tunnel runs `ssh -N -R remote_port:localhost:local_port` as a
        subprocess. A SIGCHLD handler restarts ssh with exponential backoff;
        after each restart an alarm, set ~120 s past the retry delay, resets
        the retry budget if it fires. SIGINT/SIGTERM run `cleanup`, which
        stops ssh, clears the pid file and exits.
    """

    CONF_FILE = '/etc/chartio/chartio.cfg'
    PIDFILE = '/var/run/chartio_connect.pid'
    SSH_KEY = '/etc/chartio/sshkey/id_rsa'
    RETRIES = 15 # How many times to retry before giving up
    RETRY_DELAY = 10 # Seconds between retries
    PING_DELAY = 10 # Time between pings to keep tunnel alive
    STOP_TIMEOUT = 5 # Seconds `kill` waits for an old daemon to exit

    def __init__(self, remoteuser, remotehost, remote_port, local_port):
        """ Typical values:
                remoteuser: login name used on the remote host
                remotehost: amz123.chartio.com (resolved to an IP once, here)
                remote_port: 12345 (some big number)
                local_port: 3306 (mysql/pg port)
        """
        self.host = socket.gethostbyname(remotehost)
        self.user = remoteuser
        self.remote_port = int(remote_port)
        self.local_port = int(local_port)

        self._pid = None
        self._ssh_process = None
        self._retries = self.RETRIES
        self._shutdown = False
        self._retry_delay = self.RETRY_DELAY

    # Store pid in file, get and set as property.
    # This lets a parent process set the pid, and the child load it in.
    def _get_pid(self):
        """ Return the daemon pid, reading PIDFILE on first access.

            Raises IOError if PIDFILE is missing, ValueError if it does not
            contain an integer (e.g. after it was cleared by `_del_pid`).
        """
        if self._pid is None:
            with open(self.PIDFILE) as f:
                self._pid = int(f.read())
        return self._pid

    def _set_pid(self, pid):
        """ Write `pid` to PIDFILE and cache it on the instance. """
        with open(self.PIDFILE, 'w') as f:
            f.write(str(pid))
        self._pid = pid

    def _del_pid(self):
        """ Truncate PIDFILE (best effort) and forget the cached pid. """
        logger.info("Clearing pid file")
        try:
            open(self.PIDFILE, 'w').close()
        except (IOError, OSError) as e:
            # Best effort: a leftover pid file must not block shutdown.
            logger.warning("Could not clear pid file: %s" % e)
        self._pid = None

    pid = property(_get_pid, _set_pid, _del_pid)

    def _sig_chld(self, signum, frame):
        """ SIGCHLD handler: restart ssh after a backoff delay.

            Each restart consumes one retry and doubles the delay. When no
            retries remain, the pid file is cleared and the process exits
            with status 1. Does nothing once `cleanup` has started, because
            `cleanup` kills ssh itself.
        """
        if self._shutdown:
            return
        logger.warning("subprocess died, about to restart")
        if self._retries == 0:
            logger.error("Number of retries has been exceeded, exiting")
            # Don't leave a stale pid behind: a later `kill` would signal
            # whatever unrelated process ends up reusing that pid.
            del self.pid
            sys.exit(1)
        logger.info("waiting %s seconds to retry connection" %
                self._retry_delay)
        # If the restarted ssh survives ~120 s past the delay, the alarm
        # fires and resets retries/backoff. A new failure re-arms the alarm.
        signal.signal(signal.SIGALRM, self._sig_alarm)
        signal.alarm(self._retry_delay + 120)
        time.sleep(self._retry_delay)
        self._retries -= 1
        self._retry_delay *= 2
        self._make_connection()

    def _sig_alarm(self, signum, frame):
        """ Since we reached this, connection likely succeeded """
        logger.info("Connection appears to be working, resetting retries")
        self._retries = self.RETRIES
        self._retry_delay = self.RETRY_DELAY

    def _ssh_args(self):
        """ Return the ssh argv list used to open the reverse tunnel. """
        return ['ssh',
            '-N', # No command, just forwarding
            '-R', # Remote forward
            '%s:localhost:%s' % (self.remote_port, self.local_port),
            '%s@%s' % (self.user, self.host),
            # NOTE(review): per ssh(1), -g applies to local (-L) forwards;
            # remote binding for -R is governed by the server's GatewayPorts.
            '-g', # Allow remote connect
            '-i', self.SSH_KEY,
            # Keep alive. If fails, disconnect and let this script reconnect
            '-o', 'ServerAliveInterval=%s' % self.PING_DELAY,
            '-o', 'ServerAliveCountMax=1',
            # Exit (and so get retried) if the remote forward can't be set
            # up, e.g. while the previous tunnel still holds the port.
            # Otherwise ssh stays connected with no forwarding at all.
            '-o', 'ExitOnForwardFailure=yes',
            # Ignore known_hosts_file feedback.
            # NOTE(review): this disables host-key verification (MITM risk);
            # consider pinning the server's host key instead.
            '-o', 'UserKnownHostsFile=/dev/null',
            '-o', 'StrictHostKeyChecking=no',
        ]

    def _make_connection(self, debug=False):
        """ Start the ssh subprocess, (re)installing the SIGCHLD handler.

            With debug=True nothing is started; the command line is returned
            as one space-joined string instead.
        """
        logger.info("Making ssh tunnel connection")
        args = self._ssh_args()
        if debug:
            return " ".join(args)

        signal.signal(signal.SIGCHLD, self._sig_chld)
        # stdout is inherited from this process; stderr is merged into it.
        self._ssh_process = subprocess.Popen(args,
                stdout=None, stderr=subprocess.STDOUT)

    def daemonize(self):
        """ Fork into the background.

            Returns True in the daemon (child) process. The parent records
            the child's pid in PIDFILE and exits with status 0. If the fork
            fails, the error is logged and the process exits with status 1.
        """
        try:
            pid = os.fork()
        except OSError:
            # os.fork() reports failure by raising, never by returning < 0.
            logger.exception("Error forking, cannot daemonize.")
            sys.exit(1)
        if pid > 0:
            print("SSHTunnel watcher daemonized at process %d" % pid)
            logger.info("Daemon process running: %d" % pid)
            self.pid = pid
            sys.exit(0)
        # Child: set session id to disconnect from tty
        os.setsid()
        # Point fds 0, 1 and 2 at /dev/null instead of leaving them closed.
        # Otherwise the next file opened (e.g. a rotated log) would get one
        # of these low fds, and ssh, which inherits our stdout, would write
        # into it.
        devnull = os.open(os.devnull, os.O_RDWR)
        for fd in range(3):
            os.dup2(devnull, fd)
        if devnull > 2:
            os.close(devnull)
        return True

    def cleanup(self, signum, frame):
        """ SIGINT/SIGTERM handler: stop ssh, clear the pid file, exit 0. """
        # Set first so the SIGCHLD caused by killing ssh below is ignored
        # by `_sig_chld` instead of triggering a reconnect.
        self._shutdown = True
        logger.info("Running cleanup")
        logger.debug('killing process')
        proc = self._ssh_process
        if proc is not None:
            try:
                proc.kill()
            except OSError:
                # ssh is already gone; nothing left to kill.
                pass
        del self.pid
        logging.shutdown()
        sys.exit(0)

    @staticmethod
    def _wait_for_exit(pid, timeout, interval=0.1):
        """ Poll until process `pid` is gone or `timeout` seconds pass.

            Returns True if the process disappeared, False on timeout.
        """
        deadline = time.time() + timeout
        while True:
            try:
                os.kill(pid, 0)  # signal 0: existence check only
            except OSError:
                return True
            if time.time() >= deadline:
                return False
            time.sleep(interval)

    @classmethod
    def kill(cls, silent=False, timeout=None):
        """ Try to stop a running daemon using the pid stored in PIDFILE.

            Sends SIGINT (handled by the daemon's `cleanup`), then waits up
            to `timeout` seconds (default STOP_TIMEOUT) for it to exit. This
            keeps a restart from racing the old daemon for the remote port
            or the pid file. A missing, empty or unparsable pid file and an
            already-dead process are reported (unless `silent`), not raised.
        """
        if not os.path.exists(cls.PIDFILE):
            if not silent:
                print("No pid file found, cannot kill daemon")
            return
        if not silent:
            print("Trying to stop previous running instance")
        try:
            with open(cls.PIDFILE) as f:
                pid = f.read()
            if not pid.strip():
                return
            pid = int(pid)
            if not silent:
                print("Stopping daemon at %s" % pid)
            os.kill(pid, signal.SIGINT)
        except (IOError, OSError, ValueError):
            if not silent:
                print("Previous instance probably isn't running anymore")
            return
        if timeout is None:
            timeout = cls.STOP_TIMEOUT
        cls._wait_for_exit(pid, timeout)

    def main(self, daemonize=True):
        """ Run the tunnel until SIGINT/SIGTERM; never returns normally. """
        logger.info("Starting up ssh tunnel")
        if daemonize:
            logger.info("daemonized mode")
            self.daemonize()
        else:
            logger.info("non-daemonized mode")
            print("Starting in non-daemon mode. Press ctrl-c to stop.")

        # Install shutdown handlers before starting ssh, so there is no
        # window where SIGTERM/SIGINT would kill us without stopping ssh.
        signal.signal(signal.SIGTERM, self.cleanup)
        signal.signal(signal.SIGINT, self.cleanup)

        self._make_connection()

        # Sleep until a signal arrives; the handlers do all the work
        # (restarting ssh, resetting retries, shutting down and exiting).
        while True:
            signal.pause()


def main():
    """ Command-line entry point.

        Usage: [-d] [start|restart|stop|debug]   (default action: start)

        `stop` only needs the pid file, so it is handled before the config
        file is read. Unknown actions print usage and exit with status 2.
    """
    # Read options
    from optparse import OptionParser
    parse = OptionParser(usage="%prog [-d] [start|restart|stop|debug]")
    parse.add_option('-d', '--daemonize', dest='daemonize',
        action="store_true", default=False,
        help="run the tunnel in the background")

    options, args = parse.parse_args()

    action = 'start'
    if args:
        action = args[0]
    if action not in ('start', 'restart', 'stop', 'debug'):
        # parse.error prints usage to stderr and exits with status 2.
        parse.error("Incorrect usage: unknown action '%s'" % action)

    if action == 'stop':
        # Stopping needs only the pid file. Don't require a readable config
        # or a working DNS lookup of the remote host (done by SSHTunnel()).
        SSHTunnel.kill()
        return

    # Check conf file is there
    if not os.path.exists(SSHTunnel.CONF_FILE):
        print("Config file not found. Please re-run installation of chartio")
        sys.exit(1)

    # Read in values from cfg file
    conf = ConfigParser.ConfigParser()
    conf.read(SSHTunnel.CONF_FILE)
    try:
        conf = dict(conf.items('SSHTunnel'))
    except ConfigParser.NoSectionError:
        print("Config file appears empty. Please run chartio_setup")
        sys.exit(1)
    try:
        settings = [conf[key] for key in
                ('remoteuser', 'remotehost', 'remoteport', 'localport')]
    except KeyError as e:
        print("Config file is missing option %s. Please run chartio_setup" % e)
        sys.exit(1)
    tunnel = SSHTunnel(*settings)

    if action in ('start', 'restart'):
        tunnel.kill(silent=True)
        tunnel.main(options.daemonize)
    else:  # action == 'debug'
        print("This program would have run the following command:")
        print(tunnel._make_connection(debug=True))


if __name__ == "__main__":
    # Dispatch the command-line action only when run as a script,
    # not when this module is imported.
    main()
