#!/usr/bin/env python -u
# coding=utf-8
""" A Pythonic implementation of pdsh powered by sshreader
"""
from __future__ import print_function
from collections import defaultdict
from hashlib import md5
from hostlist import expand_hostlist, collect_hostlist
import click
import logging
import sshreader
import sys
# GLOBALS
__author__ = 'Jesse Almanrode'
# Version string reported by the --version option (passed to click.version_option on cli)
__version__ = '2.0.1'
# Usage examples handed to click.command as the epilog of the --help output.
# NOTE: the leading "\b" (a backspace character in this non-raw string) is, per click's
# documentation, the marker that keeps the following paragraph from being re-wrapped.
__examples__ = """\b
Examples:
    pydsh -w host1,host2,host3 "uname -r"
    pydsh -u root -k /root/.ssh/id_rsa -w host[1,3] "uname -r"
    pydsh -u root -P Password123 -w host[1-3] "uname -r"
"""


def output(thisjob):
    """ Echo a finished job's output line by line, prefixing every line with the job name

    The `name: line` format is the one `dshbak` expects on its input.

    :param thisjob: <ServerJob> object; its `status`, `results` and `name` are read
    :return: None
    """
    first = thisjob.results[0]
    # With status 255 the first result is itself the text to print (presumably the
    # ssh failure message -- confirm against sshreader); otherwise it carries `.stdout`.
    text = first if thisjob.status == 255 else first.stdout
    if len(text) == 0:
        return None
    prefix = str(thisjob.name) + ': '
    for row in text.split('\n'):
        sshreader.echo(prefix + str(row))
    return None


def dshbak(jobresults):
    """ Print each job's output as one block under a banner naming the host

    Mirrors what piping the raw output through `dshbak` would produce. Jobs
    whose output is empty are skipped.

    :param jobresults: List of <ServerJob> objects
    :return: None
    """
    rule = '-' * 16
    for job in jobresults:
        first = job.results[0]
        # Status 255: the first result is the text itself; otherwise read its `.stdout`.
        body = first if job.status == 255 else first.stdout
        if len(body) == 0:
            continue
        click.echo('\n'.join([rule, str(job.name), rule]))
        click.echo(body)
    return None


def coalesce(jobresults):
    """ Output the results of jobs coalescing identical output from hosts

    Similar to piping output to `dshbak -c`: hosts that produced exactly the
    same text are reported once, under a banner holding their collected
    hostlist expression. Empty output is not printed.

    Hosts are grouped by the output text itself rather than by a digest of
    it. That avoids `md5` (which can be blocked on FIPS-restricted Python
    builds) and the implicit `.encode()` step, which fails on Python 2 byte
    strings containing non-ASCII data.

    :param jobresults: List of <ServerJob> objects
    :return: None
    """
    hosts_by_output = defaultdict(list)
    for job in jobresults:
        first = job.results[0]
        # Status 255: the first result is the text itself; otherwise read its `.stdout`.
        text = first if job.status == 255 else first.stdout
        hosts_by_output[text].append(job.name)

    rule = '-' * 16
    for text, hosts in hosts_by_output.items():
        if len(text) != 0:
            click.echo(rule + '\n' + collect_hostlist(hosts) + '\n' + rule)
            click.echo(text)
    return None


def validate_hostlist(ctx, param, value):
    """ Callback for click to expand hostlist expressions or error

    :param ctx: Click context
    :param param: Click parameter the callback is attached to
    :param value: Hostlist expression to expand
    :return: List of expanded hosts
    :raises click.BadParameter: if the expression cannot be expanded
    """
    try:
        return expand_hostlist(value)
    except Exception:
        # BadParameter is the exception click documents for option callbacks and
        # takes the message as its first argument in every click release.
        # BadOptionUsage requires (option_name, message) since click 7.0, so the
        # former one-argument call raised TypeError instead of a usage error.
        raise click.BadParameter('Invalid hostlist expression', ctx=ctx, param=param)


@click.command(epilog=__examples__)
@click.version_option(version=__version__)
@click.option('--hostlist', '-w', metavar='EXPR', required=True, callback=validate_hostlist,
              help='Hostlist expression')
@click.option('--username', '-u', help='Override ssh username')
@click.option('--keyfile', '-k', type=click.Path(exists=True, dir_okay=False, readable=True), help='Override ssh key')
@click.option('--prompt', '-p', is_flag=True, help='Prompt for ssh password')
@click.option('--password', '-P', help='Supply ssh password')
@click.option('--timeout', '-T', default=600, help='Timeout for ssh commands')
@click.option('--dshbak', '-D', is_flag=True, help='Group output by host')
@click.option('--coalesce', '-C', is_flag=True, help='Coalesce similar output from hosts')
@click.option('--debug', '-d', is_flag=True, help='Enable debug output')
@click.option('--redline', is_flag=True, help='Run as many sub-processes as possible')
@click.argument('cmd', nargs=1, required=True)
def cli(**kwargs):
    """  Run ssh commands in parallel across hosts
    """
    # NOTE: the docstring above is shown by click as the --help text; keep it short.
    env = sshreader.envvars()

    # Username: the command line wins, then whatever sshreader found in the environment
    username = kwargs['username']
    if username is None:
        username = env.username
        if username is None:
            raise click.ClickException('Unable to determine ssh username. Please provide one using --username')

    keyfile = kwargs['keyfile']
    password = kwargs['password']
    prompt = kwargs['prompt']

    # Stage 1: settle on a key. An explicit --keyfile silences any password/prompt
    # options; otherwise fall back to the environment's RSA key, then its DSA key.
    if keyfile is not None:
        password = None
        prompt = False
    elif env.rsa_key is not None:
        keyfile = env.rsa_key
    elif env.dsa_key is not None:
        keyfile = env.dsa_key
    elif password is None and prompt is False:
        raise click.ClickException('Unable to find ssh key to use and password not supplied.')

    # Stage 2: a supplied or prompted password takes precedence over a key that
    # was only discovered from the environment.
    if password is not None:
        keyfile = None
    elif prompt is False:
        if keyfile is None:
            raise click.ClickException('Unable to find ssh key to use and password not supplied or prompt enabled.')
    else:
        keyfile = None
        while password is None:
            password = click.prompt(username + "'s Password", hide_input=True)

    if kwargs['debug']:
        logging.getLogger('sshreader').setLevel(logging.INFO)

    # Without -D/-C every job prints its own output through a post hook as it
    # finishes; with either flag the output is collected and printed at the end.
    streaming = kwargs['dshbak'] is False and kwargs['coalesce'] is False
    post = sshreader.Hook(target=output)

    # Exactly one credential is handed to each job: the key if one survived, else the password
    if keyfile is not None:
        credential = {'keyfile': keyfile}
    else:
        credential = {'password': password}

    jobs = []
    for host in kwargs['hostlist']:
        job = sshreader.ServerJob(host, kwargs['cmd'], username=username, timeout=kwargs['timeout'],
                                  combine_output=True, **credential)
        if streaming:
            job.posthook = post
        jobs.append(job)

    # --redline selects pcount=-1 instead of the default pcount=0
    pcount = -1 if kwargs['redline'] else 0
    if streaming:
        sshreader.sshread(jobs, pcount=pcount, tcount=0)
    else:
        finished = sshreader.sshread(jobs, pcount=pcount, tcount=0, progress_bar=True)
        if kwargs['coalesce']:
            coalesce(finished)
        else:
            dshbak(finished)
    sys.exit(0)

# Invoke the click command only when this file is executed as a script, not when imported
if __name__ == "__main__":
    cli()
