#!/usr/bin/env python

# Copyright (c) 2011-2016 Timothy Savannah under GPLv3, All Rights Reserved. See LICENSE for more information
"""
Disttask is a utility which provides the ability to distribute a task across a fixed number of processes, for better utilization of multiprocessing.

Use it with existing single-threaded/process tools and scripts to take full advantage of your computer's resources.

"""

import os
import sys
import select
import signal
import subprocess
import threading
import time

from collections import deque

__version__ = '2.1.1'

__version_tuple__ = (2, 1, 1)

# Make sure the name "bytes" exists so the rest of the module can refer to it
#  regardless of interpreter version.
try:
    bytes
except NameError:
    bytes = str # Python < 2.6

if bytes is str:
    # Python 2: "str" already is the byte-string type, no additional decoding necessary.
    tostr = str
else:
    # Python 3: byte strings must be decoded to become "str".
    try:
        defaultEncoding = sys.getdefaultencoding()
    except Exception:
        defaultEncoding = 'utf-8'

    def tostr(x):
        """
        tostr - Convert a value to the native "str" type.

        @param x - Any object.

        @return <str> - "x" itself when it is already a str; the text decoded
          using "defaultEncoding" when it is a bytes or bytearray object;
          otherwise str(x).
        """
        if isinstance(x, str):
            return x
        # bytearray is decoded as well: str() on it would yield its repr
        #  ("bytearray(b'...')") instead of the text it holds.
        if isinstance(x, (bytes, bytearray)):
            return x.decode(defaultEncoding)
        return str(x)


class StdoutWriter(threading.Thread):
    """
    StdoutWriter - Thread which owns all writing to stdout.

    Other threads queue chunks through addData; this thread pops them in
    order and writes them, so one queued chunk is never split by another.

    The thread keeps running until "keepGoing" is set to False AND the queue
    has been drained.
    """

    # FLUSH_EVERY - Explicitly flush after this many items.
    FLUSH_EVERY = 1

    def __init__(self, *args, **kwargs):
        """
        All arguments are passed through to threading.Thread.
        """
        threading.Thread.__init__(self, *args, **kwargs)

        # Pending chunks, oldest on the left.
        self.stdoutData = deque()

        # Set to False (by the owner) to make run() return once the queue is empty.
        self.keepGoing = True

    def addData(self, data):
        """
        addData - Queue a chunk of output to be written to stdout.

        @param data <bytes/str> - The chunk to write.
        """
        self.stdoutData.append(data)

    def setFlushEvery(self, nWrites):
        """
        setFlushEvery - Set how many items are written between explicit flushes.

        run() reads this value once when it begins, so call this before start().

        @param nWrites <int> - Number of items per flush.
        """
        self.FLUSH_EVERY = nWrites

    def _getWriteFunction(self):
        """
        _getWriteFunction - Pick the function used to emit one queued item.

        @return <callable> - A function taking a single queued item.
        """
        try:
            rawWrite = sys.stdout.buffer.write
        except AttributeError:
            # No binary layer on this stdout (e.g. Python 2): write items as-is.
            return sys.stdout.write

        encoding = getattr(sys.stdout, 'encoding', None) or sys.getdefaultencoding()

        def writeOutput(item):
            # The binary layer rejects text with a TypeError, which would end
            #  this thread and silently drop everything queued afterwards.
            if not isinstance(item, bytes):
                item = item.encode(encoding, 'replace')
            rawWrite(item)

        return writeOutput

    def run(self):
        """
        run - Thread body. Drains the queue to stdout until told to stop.
        """
        time.sleep(.001) # Block immediately whilst setup happens
        stdoutData = self.stdoutData

        flushEvery = self.FLUSH_EVERY

        writeOutput = self._getWriteFunction()

        while self.keepGoing is True or len(stdoutData) > 0:
            numSinceFlush = 0
            while len(stdoutData) > 0:
                writeOutput(stdoutData.popleft())
                numSinceFlush += 1
                if numSinceFlush >= flushEvery:
                    numSinceFlush = 0
                    sys.stdout.flush()

            # Push out any partial batch, then yield briefly before polling again.
            sys.stdout.flush()
            time.sleep(.0005)

class Runner(threading.Thread):
    """
    Runner - Thread which runs one shell command and forwards its output.

    The command is run through the shell with stderr merged into stdout.
    Output is handed to the given writer object through its addData method.
    """

    def __init__(self, cmd, stdoutWriter, thisItem, collateOutput=True):
        """
        @param cmd <str> - The complete command line to execute (via the shell).

        @param stdoutWriter - Object providing addData(data), which receives the output.

        @param thisItem <str> - The argument this command was built from; used
          as the "[item] " line prefix when output is not collated.

        @param collateOutput <bool> - If True, hold all output and queue it as
          one chunk when the command finishes. If False, queue each line as it
          arrives, prefixed with "[thisItem] ".
        """
        threading.Thread.__init__(self)
        self.cmd = cmd
        self.stdoutWriter = stdoutWriter

        self.thisItem = thisItem
        self.collateOutput = collateOutput

        # Cleared to stop the read loop; run() clears it itself on error.
        self.keepGoing = True

    def run(self):
        """
        run - Thread body. Executes the command and forwards its output.
        """
        pipe = subprocess.Popen(self.cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        stdoutWriter = self.stdoutWriter

        if self.collateOutput is True:
            output = []
            handleLine = output.append
        else:
            thisItem = self.thisItem

            if sys.version_info[0] >= 3:
                def handleLine(line):
                    # Lines are bytes on Python 3, so the prefix must be bytes too.
                    prefix = ('[%s] ' %(thisItem,)).encode(defaultEncoding)
                    stdoutWriter.addData(prefix + line)
            else:
                def handleLine(line):
                    stdoutWriter.addData('[%s] %s' %(thisItem, line))

        pipeStdout = pipe.stdout
        try:
            while self.keepGoing is True and not pipeStdout.closed:
                try:
                    (rlist, wlist, errors) = select.select([pipeStdout], [], [pipeStdout], .004)
                    if errors:
                        try:
                            pipeStdout.close()
                        except Exception:
                            pass
                        break

                    if not rlist:
                        # Nothing to read yet.
                        time.sleep(.002)
                        continue

                    line = pipeStdout.readline()

                    if line == b'':
                        # End of stream: the command closed its output.
                        break

                    handleLine(line)
                except Exception as e:
                    self.keepGoing = False
                    pipe.terminate()
                    sys.stderr.write('Got exception: %s\n' %(str(e),))
                    break
            pipe.wait()
        finally:
            # Release the pipe's file descriptor now rather than whenever the
            #  Popen object happens to be garbage collected.
            try:
                pipeStdout.close()
            except Exception:
                pass

        if self.collateOutput is True:
            # Lines read from the pipe are byte strings, so they must be joined
            #  with a byte-string separator. Queueing a single chunk is what
            #  keeps this command's output from being interleaved with others.
            stdoutWriter.addData(b''.join(output))

class DistTask(object):
    """
    DistTask - Runs a command once per queued item, with a fixed number of
      command slots running at the same time.
    """

    def __init__(self, cmd, concurrent_tasks, argset, stdoutWriter, endWhenDone=True, collateOutput=True):
        """
        @param cmd <str> - Command template. Each "%s" is replaced with the
          item and each "%d" with the number of the slot running it.

        @param concurrent_tasks <int> - Number of slots, i.e. the maximum
          number of commands running at once.

        @param argset <iterable<str>> - Initial items to run the command with.

        @param stdoutWriter - Writer handed to every Runner. Its "keepGoing"
          attribute is set to False when run() ends.

        @param endWhenDone <bool> - If True, run() returns as soon as no
          command is running. If False, run() additionally waits until the
          "keepGoing" attribute of this object has been set to False.

        @param collateOutput <bool> - Passed through to every Runner.
        """
        self.cmd = cmd
        self.concurrent_tasks = concurrent_tasks
        self.argset = deque(argset)
        self.stdoutWriter = stdoutWriter
        self.endWhenDone = endWhenDone

        # keepGoing is an attribute only when endWhenDone is False. The owner
        #  sets it to False once no more items will be added.
        if self.endWhenDone is False:
            self.keepGoing = True

        self.collateOutput = collateOutput

    def addToArgset(self, items):
        """
        addToArgset - Queue several additional items.

        @param items <iterable<str>> - Items to append.
        """
        self.argset += items

    def addItemToArgset(self, item):
        """
        addItemToArgset - Queue one additional item.

        @param item <str> - Item to append.
        """
        self.argset.append(item)

    def _startRunner(self, slotNum, item):
        """
        _startRunner - Build the command for "item" and start it in a new Runner.

        @param slotNum <int> - Slot number, substituted for "%d".

        @param item <str> - Item, substituted for "%s".

        @return <Runner> - The started thread.
        """
        # The slot number goes in first, so that a literal "%d" which is part
        #  of the item text is left alone.
        cmd = self.cmd.replace('%d', str(slotNum)).replace('%s', item)
        runner = Runner(cmd, self.stdoutWriter, item, self.collateOutput)
        runner.start()
        return runner

    def run(self):
        """
        run - Process items until finished (see "endWhenDone" on __init__).

        Always sets stdoutWriter.keepGoing to False before returning, also
        when an error is raised, so the writer is never left waiting forever.
        """
        argset = self.argset
        numSlots = self.concurrent_tasks

        # One entry per slot: None until used, afterwards the latest Runner.
        pipes = [None] * numSlots

        # -1 means "nothing counted yet", which forces a first pass of the loop.
        pipesRunning = -1

        if self.endWhenDone is True:
            def shouldKeepGoing():
                return pipesRunning != 0
        else:
            def shouldKeepGoing():
                return self.keepGoing is True or len(self.argset) > 0 or pipesRunning > 0

        try:
            while shouldKeepGoing():
                pipesRunning = 0
                for i in range(numSlots):
                    current = pipes[i]
                    if current is not None and current.is_alive():
                        # Slot is busy.
                        pipesRunning += 1
                        continue

                    if len(argset) == 0:
                        # Slot is free but there is nothing to give it.
                        continue

                    nextItem = argset.popleft()
                    if current is not None:
                        current.join() # cleanup
                    pipes[i] = self._startRunner(i, nextItem)
                    pipesRunning += 1

                time.sleep(.0002)
        finally:
            self.stdoutWriter.keepGoing = False

if (__name__ == "__main__"):
    # Command-line entry point:  disttask [cmd] [concurrent tasks] [argset]
    args = sys.argv[1:]

    # The no-collate flags may appear anywhere; strip them before the
    #  positional arguments are interpreted.
    collateOutput = True
    for noCollateFlag in ('-nc', '--no-collate'):
        if noCollateFlag in args:
            args.remove(noCollateFlag)
            collateOutput = False

    if '--version' in args:
        sys.stderr.write('disttask version %s by Tim Savannah\n' %(__version__,))
        sys.exit(0)

    if len(args) < 3 or '--help' in args:
        sys.stderr.write("Usage: " + os.path.basename(sys.argv[0]) + " [cmd] [concurrent tasks] [argset]\n\n")
        sys.stderr.write("Use a %s in [cmd] where you want the args to go. use %d for the pipe number.\nTo run a list of commands, make '%s' be your full command.\n\n")
        sys.stderr.write("If argset is '--', the items will be read from stdin instead of providing the arguments to disttask.\n  Execution will start immediately, so you can have disttask manage processing items that another program is feeding in.\n")
        sys.stderr.write('''
    Options:

       -nc or --no-collate          By default, the output will be held until the task is completed, so output is not intermixed.
                                       By providing "-nc" or "--no-collate", instead each line that comes in from any running task
                                       is printed, prefixed with the argset in square-brackets (e.x.  "[arg1] Some message"

 Example:  disttask "ssh root@%s hostname" 3 host1 host2 host3 host4 host5 host6 # Connect and get hostname on 6 hosts, 3 at a time.
''')

        sys.stderr.write("\ndisttask version " + __version__ + "\n")
        sys.exit(1)


    # Kept at module level because code elsewhere in this file may look up
    #  "pipes" as a global.
    pipes = []


    cmd = args.pop(0)
    if '%s' not in cmd:
        sys.stderr.write("No %s in command!\n")
        sys.exit(1)

    concurrentTasksArg = args.pop(0)
    try:
        concurrent_tasks = int(concurrentTasksArg)
    except ValueError:
        sys.stderr.write('Number of concurrent tasks must be an integer, not "%s"\n' %(concurrentTasksArg, ))
        sys.exit(1)

    if concurrent_tasks < 1:
        # With no slot to run anything in, queued items could never be processed.
        sys.stderr.write('Number of concurrent tasks must be at least 1, not "%s"\n' %(concurrentTasksArg, ))
        sys.exit(1)

    argset = args


    stdoutWriter = StdoutWriter()
    if collateOutput is False:
        stdoutWriter.setFlushEvery(10)
    stdoutWriter.start()

    try:
        if len(argset) == 1 and argset[0] == '--':
            # Items are streamed in on stdin, one per line, while processing
            #  runs in a separate thread.
            runner = DistTask(cmd, concurrent_tasks, [], stdoutWriter, endWhenDone=False, collateOutput=collateOutput)
            runnerThread = threading.Thread(target=runner.run)
            runnerThread.start()

            try:
                while not sys.stdin.closed:
                    try:
                        nextItem = sys.stdin.readline()
                    except (KeyboardInterrupt, Exception):
                        # Any failure to read (including Ctrl+C) ends input;
                        #  the items already queued are still processed.
                        break
                    if nextItem == '':
                        # End of input.
                        break
                    # Drop the line terminator only. A final line without a
                    #  trailing newline must keep its last character.
                    if nextItem.endswith('\n'):
                        nextItem = nextItem[:-1]
                    runner.addItemToArgset(nextItem)
            finally:
                # Tell the processing thread that no more items are coming.
                runner.keepGoing = False

            runnerThread.join()
        else:
            runner = DistTask(cmd, concurrent_tasks, argset, stdoutWriter, endWhenDone=True, collateOutput=collateOutput)
            runner.run()
    finally:
        # Always release the writer thread; it is not a daemon thread, so
        #  leaving it running would keep the process from exiting.
        stdoutWriter.keepGoing = False

# vim: set ts=4 sw=4 expandtab
