#!/usr/bin/env python
'''
    Copyright (c) 2015 Tim Savannah under GPLv3.

    See LICENSE for more details.

    This is a program that copies a script to any number of remote hosts and executes it, using various parameters
'''

# vim: set ts=4 sw=4 expandtab

import datetime
import os
import random
import subprocess2 as subprocess
import threading
import time
import sys

if bytes == str:
    # Python 2: "str" is already the byte-string type.
    def tostr(x):
        '''
            Convert a value to the native "str" type (Python 2 flavour).

            unicode objects are encoded as UTF-8; anything else is passed
            through str().
        '''
        if isinstance(x, unicode):
            return x.encode('utf-8')
        return str(x)
else:
    # Python 3: "str" is text, so bytes have to be decoded.
    def tostr(x):
        '''
            Convert a value to the native "str" type (Python 3 flavour).

            bytes objects are decoded as UTF-8; anything else is passed
            through str().
        '''
        if isinstance(x, bytes):
            # Program output is arbitrary bytes. Use the "replace" error
            # handler so invalid UTF-8 sequences become U+FFFD instead of
            # raising UnicodeDecodeError while results are being printed.
            return x.decode('utf-8', 'replace')
        return str(x)

def printThreadResults(thread, omitEmpty, isQuiet):
    '''
        Write the results held by one host thread to stdout.

        thread    - Object providing "hostname" and "results" attributes.
        omitEmpty - When True, print nothing at all if the results are empty
                     or whitespace-only.
        isQuiet   - When False, surround the results with a hostname header
                     and a separator line; otherwise print only the results.
    '''
    if omitEmpty is True and not thread.results.strip():
        return

    decorate = isQuiet is False
    out = sys.stdout

    if decorate:
        # Hostname underlined with dashes of the same length
        underline = '-' * len(thread.hostname)
        out.write(''.join([thread.hostname, '\n', underline, '\n']))

    out.write(tostr(thread.results))

    if decorate:
        out.write('\n' + '-' * 60 + '\n\n')

    out.flush()


def printUsage():
    '''
        Write the command-line usage/help text to stderr.

        The --rcae-at-a-time description states that the value is the number
        of hosts handled per chunk, which is how the value is used when the
        host list is split.
    '''
    sys.stderr.write("""Usage: remote_copy_and_execute [program] [args] (--) [hostname1] [hostnameN]
Copies a script and executes on multiple hosts simultaneously. Use "--" after the args and before the list of host names.
Script must be executable by the running user

remote_copy_and_execute arguments:

  --rcae-timeout=#seconds         to use a timeout.
  --rcae-omit-empty               Omit printing empty results
  --rcae-at-a-time=#              Run on at most # hosts at a time (hosts are processed in chunks of #)
  --rcae-hide-date                Do not show runtime date
  --rcae-skip-bad-hosts           Skip bad hosts. Default is to terminate.
  --rcae-quiet                    Omit all output except that from the script. Implies hide-date
  --rcae-as-user=username         Perform copy and execute as given user. Default is root.
  --rcae-print-on-host-complete   Print right after each host completes execution. Default is to print at end of each set. For now, does not print date.

  --rcae-batch                    Sets defaults [listed below]. Sane defaults for batch
                                   executions. This directive is evaluated first, so you can override
                                   the ones that take a parameter.
                                      timeout=2
                                      omit-empty
                                      at-a-time=15
                                      hide-date
                                      skip-bad-hosts

  --rcae-help                     Show this message

Example: remote_copy_and_execute get_traffic 2 -- host1 host2 host3
""")


def getRandPortion():
    '''
        Return a random integer between 1 and 999999 (inclusive) as a
        six-character, zero-padded decimal string.
    '''
    number = random.randint(1, 999999)
    return '%06d' % (number,)

def getRandPath(command):
    '''
        Build a randomized path under /tmp for the given command.

        The result is "/tmp/<basename of command>" followed by three
        dot-separated random portions from getRandPortion().
    '''
    scriptName = os.path.basename(command)
    randomParts = [getRandPortion() for _ in range(3)]
    return '/tmp/' + '.'.join([scriptName] + randomParts)

# Options placed on every ssh/scp command line built by this program.
SSH_ARGS = ["-q", "-o", "StrictHostKeyChecking=no"]


class RunProgramThread(threading.Thread):
    '''
        Thread that runs an already-transferred script on a single remote host
        over ssh, stores the combined stdout/stderr in "results", and afterwards
        removes the script from the remote host.
    '''

    def __init__(self, hostname, asUser, printOnHostComplete, omitEmpty, isQuiet, scriptLocation, args):
        '''
            hostname            - Remote host to connect to.
            asUser              - User name used for the ssh connection.
            printOnHostComplete - When True, results are printed from within
                                   this thread as soon as the command finishes.
            omitEmpty           - Forwarded to printThreadResults when printing
                                   from this thread.
            isQuiet             - Forwarded to printThreadResults when printing
                                   from this thread.
            scriptLocation      - Path of the script on the remote host.
            args                - List of arguments appended after the script path.
        '''
        threading.Thread.__init__(self)

        self.hostname = hostname
        self.scriptLocation = scriptLocation
        self.args = args
        self.asUser = asUser
        self.printOnHostComplete = printOnHostComplete
        self.omitEmpty = omitEmpty
        self.isQuiet = isQuiet

        self.connectStr = asUser + '@' + hostname

        self.cmd = ['/usr/bin/ssh'] + SSH_ARGS + [self.connectStr, scriptLocation] + args

        # Filled in by run() with whatever the remote command wrote.
        self.results = ''

    def run(self):
        '''
            Execute the remote command, capture its output into self.results,
            optionally print it, and always attempt to delete the remote script.
        '''
        pipe = subprocess.Popen(self.cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        try:
            self.results = pipe.stdout.read()
            pipe.wait()
            if self.printOnHostComplete is True:
                # Use the options given to this thread rather than relying on
                #  same-named module-level globals happening to exist.
                printThreadResults(self, self.omitEmpty, self.isQuiet)
        finally:
            # Clean up the remote copy even if reading or printing failed.
            self._removeRemoteScript()

    def _removeRemoteScript(self):
        '''
            Delete the transferred script from the remote host, discarding any
            output of the removal command.
        '''
        rmCmd = ['/usr/bin/ssh'] + SSH_ARGS + [self.connectStr, '/bin/rm', '-f', self.scriptLocation]
        with open('/dev/null', 'w') as devnull:
            subprocess.Popen(rmCmd, shell=False, stdout=devnull, stderr=devnull).wait()

if __name__ == '__main__':

    # Overall flow:
    #   1. Parse and strip the --rcae-* options from sys.argv.
    #   2. Split the remaining arguments into command, command args and hostnames.
    #   3. Resolve the command to a local file.
    #   4. For each chunk of hosts: scp the command over, run it through one
    #      RunProgramThread per host, wait (optionally with a timeout), print.
    #
    # NOTE: The option variables below are deliberately plain module-level
    #  names (this block is not wrapped in a function), because code outside
    #  this block may read some of them (e.g. omitEmpty, isQuiet) as globals.

    timeout = None
    rcaeArgs = [arg for arg in sys.argv if arg.startswith('--rcae-')]
    omitEmpty = False
    hideDate = False
    skipBadHosts = False
    isQuiet = False
    asUser = 'root'

    rcaeAtATime = None

    printOnHostComplete = False


    # --rcae-batch is handled first so the explicit options below can override it.
    if '--rcae-batch' in rcaeArgs:
        timeout = 2
        omitEmpty = True
        hideDate  = True
        skipBadHosts = True
        rcaeAtATime = 15
        rcaeArgs.remove('--rcae-batch')
        sys.argv.remove('--rcae-batch')


    for arg in rcaeArgs:
        if arg.startswith('--rcae-timeout='):
            try:
                timeout = int(arg[len('--rcae-timeout='):])
            except ValueError:
                sys.stderr.write('--rcae-timeout expects an integer value\n')
                sys.exit(1)
        elif arg == '--rcae-omit-empty':
            omitEmpty = True
        elif arg.startswith('--rcae-at-a-time='):
            try:
                rcaeAtATime = int(arg[len('--rcae-at-a-time='):])
            except ValueError:
                sys.stderr.write('--rcae-at-a-time expects an integer value\n')
                sys.exit(1)
            if rcaeAtATime < 1:
                # A chunk size below 1 can never consume the host list.
                sys.stderr.write('--rcae-at-a-time expects a value of at least 1\n')
                sys.exit(1)
        elif arg == '--rcae-hide-date':
            hideDate = True
        elif arg == '--rcae-skip-bad-hosts':
            skipBadHosts = True
        elif arg.startswith('--rcae-as-user='):
            asUser = arg[len('--rcae-as-user='):]
        elif arg.startswith('--rcae-help'):
            printUsage()
            sys.exit(1)
        elif arg == '--rcae-quiet':
            isQuiet = True
            hideDate = True
        elif arg == '--rcae-print-on-host-complete':
            printOnHostComplete = True
        else:
            sys.stderr.write('Unknown arg: %s\n\n' %(arg,))
            printUsage()
            sys.exit(1)

        # Consumed; what is left in sys.argv is command, its args, "--" and hosts.
        sys.argv.remove(arg)


    # Minimum is: program name, command, "--", one hostname.
    if len(sys.argv) < 4:
        sys.stderr.write('Not enough arguments.\n\n')
        printUsage()
        sys.exit(1)

    argv = sys.argv[1:]


    command = argv.pop(0)
    args = []
    hostnames = []

    doingHostnames = False

    # Everything before "--" is an argument to the command, everything after is a host.
    for arg in argv:
        if arg == '--':
            doingHostnames = True
        elif doingHostnames is False:
            args.append(arg)
        else:
            hostnames.append(arg)


    if doingHostnames is False or len(hostnames) == 0:
        sys.stderr.write('No hostnames provided. Please use "--" followed by a list of hostnames (one per argument).\n Example: remote_copy_and_execute /bin/ls -l /home -- host1 host2 host3\n')
        sys.exit(1)

    if command.startswith(('/', './', '../')):
        # Explicit path: use as given.
        if not os.path.exists(command):
            sys.stderr.write('Cannot find executable at %s\n' %(command,))
            sys.exit(1)
    else:
        # Bare name: take the first match on PATH.
        allPath = os.environ.get('PATH', '/bin:/sbin:/usr/bin:/usr/sbin').split(':')
        for path in allPath:
            if not path:
                continue

            cmdPath = path.rstrip('/') + '/' + command
            if os.path.isfile(cmdPath):
                command = cmdPath
                break
        else:
            sys.stderr.write('Cannot find %s in PATH.\n' %(command,))
            sys.exit(1)


    # Optionally split the hosts into chunks of at most rcaeAtATime hosts each
    if rcaeAtATime is not None:
        chunks = [hostnames[start:start + rcaeAtATime] for start in range(0, len(hostnames), rcaeAtATime)]
    else:
        chunks = [hostnames]

    # The same remote path is used on every host.
    endpointLocation = getRandPath(command)

    with open('/dev/null', 'w') as devnull:
        for chunkHosts in chunks:
            threads = []

            # Start all transfers of this chunk before waiting on any of them,
            #  so they proceed in parallel.
            pipes = []
            for hostname in chunkHosts:
                pipes.append(subprocess.Popen(['/usr/bin/scp'] + SSH_ARGS + ['-o', 'ConnectTimeout=%d' %(timeout or 2,), command, '%s@%s:%s' %(asUser, hostname, endpointLocation)], shell=False, stdout=devnull))

            for hostname, pipe in zip(chunkHosts, pipes):
                # Each transfer gets up to 2 seconds from the moment we start
                #  waiting on it; a failed or terminated transfer marks the host as bad.
                # NOTE(review): this wait is fixed at 2 seconds while ConnectTimeout
                #  above follows --rcae-timeout -- confirm whether that is intended.
                if pipe.waitOrTerminate(2)['returnCode'] != 0:
                    if skipBadHosts is True:
                        sys.stderr.write('Failed to transfer application to %s. Skipping host.\n' %(hostname,))
                        continue
                    else:
                        sys.stderr.write('Failed to transfer application to %s. Cannot continue.\nUse --rcae-skip-bad-hosts to continue on failed host.\n' %(hostname,))
                        sys.exit(4)
                threads.append(RunProgramThread(hostname, asUser, printOnHostComplete, omitEmpty, isQuiet, endpointLocation, args))

            if hideDate is False:
                startCtime = datetime.datetime.now().ctime()

            for thread in threads:
                if timeout is not None:
                    # A thread cannot be forcibly stopped (the private
                    #  Thread._Thread__stop hook only exists on Python 2), so
                    #  timed-out threads are abandoned instead. Marking them as
                    #  daemon threads keeps them from blocking interpreter exit.
                    thread.daemon = True
                thread.start()

            if timeout is None:
                for thread in threads:
                    thread.join()

            else:
                # Poll about twice a second until all threads finished or
                #  roughly "timeout" seconds have passed.
                numPolls = timeout * 2
                runningThreads = threads[:]
                for i in range(numPolls):
                    stillRunningThreads = []
                    for thread in runningThreads:
                        if thread.is_alive():
                            thread.join(.01)
                            if thread.is_alive():
                                stillRunningThreads.append(thread)
                    runningThreads = stillRunningThreads
                    if not runningThreads or i + 1 == numPolls:
                        break
                    time.sleep(.5)

                if runningThreads:
                    # The ssh process behind a timed-out thread is not
                    #  terminated here; the thread is simply no longer waited on.
                    for thread in runningThreads:
                        sys.stderr.write('Operation timed-out on %s\n' %(thread.hostname,))
                    # Brief last chance for threads that are just finishing.
                    time.sleep(.1)
                    for thread in runningThreads:
                        thread.join(.01)

            # With print-on-host-complete the threads already printed their own results.
            if printOnHostComplete is False:
                if hideDate is False:
                    endCtime = datetime.datetime.now().ctime()
                    sys.stdout.write('\nResults from %s to %s:\n\n' %(startCtime, endCtime))
                for thread in threads:
                    printThreadResults(thread, omitEmpty, isQuiet)

