#!/usr/bin/env python
import errno
import httplib
import optparse
import os
import socket
import shlex
import sys
import time
import urllib
import urlparse
import uuid
import zipfile

from StringIO import StringIO

try:
    import json
except ImportError:
    # older Pythons ship without the json module; simplejson offers the same
    # interface and is used as a drop-in replacement when it is installed.
    try:
        import simplejson as json
    except ImportError:
        # terminate the message with a newline so the shell prompt does not
        # end up on the same line as the error.
        sys.stderr.write("Older versions of Python require simplejson.\n")
        sys.exit(1)

# base url that every API path is joined onto (see `api_request`). point it
# at e.g. "http://127.0.0.1:8000/" to talk to a local development server.
API_HOST = "https://www.hydna.com/"

# name given to the in-memory zip archive that `push` uploads.
DEFAULT_UPLOAD_NAME = 'behavior.zip'
# configuration file looked up in the current working directory when no
# explicit configuration file is passed on the command line.
DEFAULT_CONFIG_NAME = '.hydna.conf'
# file that must exist in the behavior directory before `push` proceeds.
DEFAULT_SETUP_NAME = 'setup.be'
# upper limit for the upload archive, in megabytes.
MAX_UPLOAD_SIZE = 2

# exceptions

class APIRequestError(Exception):
    """Raised when an API request fails or its response cannot be used.

    `url` is the address that was requested and `msg` a human-readable
    description of the failure. Both are available as attributes, and
    `str()` of the exception yields `msg`.

    """
    def __init__(self, url, msg):
        # pass both values on so `args` stays (url, msg), which keeps the
        # exception copyable and picklable.
        Exception.__init__(self, url, msg)
        self.url = url
        self.msg = msg

    def __str__(self):
        # without this, str(e) is the repr of the (url, msg) tuple.
        return str(self.msg)


class APIAuthenticationError(Exception):
    """Raised when the API rejects the supplied credentials.

    The fixed description is available as the `msg` attribute, and `str()`
    of the exception yields the same text.

    """
    def __init__(self):
        Exception.__init__(self)
        self.msg = u"Unable to authenticate."

    def __str__(self):
        # without this, str(e) is an empty string.
        return str(self.msg)


# configuration

class Config(dict):
    """Dictionary whose items can also be read and written as attributes.

    `config.name` is equivalent to `config['name']`, both for lookup and for
    assignment. Looking up a name that is not a key raises `AttributeError`,
    which is what `getattr()` with a default and `hasattr()` expect.

    """
    def __getattr__(self, name):
        # only invoked when regular attribute lookup has failed.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        # store in the dictionary; a plain instance attribute would be
        # invisible to items(), keys() and friends.
        self[name] = value


def load_config(config_file):
    """Read settings from `config_file` and return them as a `Config`.

    `config_file` is a path supplied by the user, or `None` to look for
    `DEFAULT_CONFIG_NAME` in the current working directory. The file holds
    `name = value` pairs, where name is one of `path`, `domain_name` and
    `api_key`; values may be quoted using shell syntax. `path` defaults to
    the directory that contains the configuration file.

    A missing default file is not an error: the defaults are returned. A
    missing user-supplied file, an unreadable file or a syntax error is
    reported on stderr and terminates the process with exit status 1.

    """
    if config_file is None:
        config_file = os.path.join(os.getcwd(), DEFAULT_CONFIG_NAME)
        user_supplied_config = False
    else:
        user_supplied_config = True

    # make the path absolute so that a bare file name ("my.conf") yields the
    # working directory as `path` rather than an empty string.
    config_file = os.path.abspath(os.path.expanduser(config_file))
    path = os.path.dirname(config_file)

    config = Config(path=path, domain_name=None, api_key=None)

    if not os.path.isfile(config_file):
        if user_supplied_config:
            sys.stderr.write("error: Specified configuration file does not "
                             "exist.\n")
            sys.exit(1)

        # return the default config here as the user hasn't actively told us
        # to read any configuration.
        return config

    try:
        handle = open(config_file, 'r')
        try:
            content = handle.read()
        finally:
            # always release the file handle, also when reading fails.
            handle.close()
    except IOError:
        sys.stderr.write("error: Unable to read configuration file.\n")
        sys.exit(1)

    lexer = shlex.shlex(content, posix=True)
    # let paths such as ~/dir/file.ext come through as a single token.
    lexer.wordchars += './~\\'

    try:
        for token in lexer:
            if token not in config:
                # TODO: use lexer.error_leader when we fix command output.
                sys.stderr.write("error: Invalid configuration parameter "
                                 "name '%s' on line %s\n"
                                 % (token, lexer.lineno))
                sys.exit(1)

            if lexer.get_token() != '=':
                sys.stderr.write("error: Expected '=' on line %s.\n"
                                 % lexer.lineno)
                sys.exit(1)

            value = lexer.get_token()
            if value is lexer.eof:
                sys.stderr.write("error: Expected value on line %s\n"
                                 % lexer.lineno)
                sys.exit(1)

            config[token] = value
    except ValueError:
        # fetch the exception through sys.exc_info() so that the handler
        # is valid syntax on old and new Python versions alike.
        e = sys.exc_info()[1]
        # the lexer can fail before it has collected any text (e.g. a quote
        # directly followed by end of file), so guard against an empty token.
        lines = (lexer.token or '').splitlines()
        error_line = lines[0] if lines else ''
        sys.stderr.write("error: Unable to parse configuration file.\n")
        sys.stderr.write("  %s following '%s'.\n" % (str(e), error_line))
        sys.exit(1)

    # assign through the item, not an attribute: only items are copied onto
    # the parsed options by the caller.
    config['path'] = os.path.expanduser(config['path'])

    return config

# command framework

class Option(object):
    """Declaration of one command-line switch for a `Command`.

    `dest` names the attribute the parsed value is stored under (also
    available as `name`). `opt_str` and `opt_str_long` are the short and
    long spellings of the switch; at least one of them has to be given,
    otherwise `ValueError` is raised. `required` marks the option as
    mandatory.

    """
    def __init__(self, dest, opt_str=None, opt_str_long=None, required=False):
        has_switch = opt_str is not None or opt_str_long is not None
        if not has_switch:
            raise ValueError("You must specify at least one of opt_str and "
                             "opt_str_long")

        self.dest = self.name = dest
        self.opt_str, self.opt_str_long = opt_str, opt_str_long
        self.required = required


class Command(object):
    """Base class for subcommands.

    Subclasses set `name` (the word used on the command line), optionally
    `description` (shown in help output) and `options` (a list of `Option`
    instances), and implement `run()`.

    """
    name = None
    description = None
    options = []

    def __init__(self):
        usage = "%%prog %s [options]" % self.name

        self.parser = optparse.OptionParser(usage,
                                            description=self.description)
        self._required_options = []

        for option in self.options:
            self.add_option(option)
            if option.required:
                self._required_options.append(option.dest)

        # all commands can specify a path to the config file to use.
        self.add_option(Option('config_file', '-c', '--config-file'))

    def add_option(self, option):
        """Register `option` with the command's option parser."""
        self.parser.add_option(option.opt_str, option.opt_str_long,
                               dest=option.dest)

    def _missing_switches(self, opts):
        """Return the switch of every required option that is unset in
        `opts`, preferring the long spelling, in declaration order.

        """
        missing = []
        for option in self.options:
            if option.required and getattr(opts, option.name) is None:
                missing.append(option.opt_str_long or option.opt_str)
        return missing

    def validate(self, arguments):
        """Parse `arguments` and merge in values from the configuration.

        Values given on the command line take precedence; configuration
        values only fill in what is missing. Returns a tuple of the parsed
        options and the remaining positional arguments. If any required
        option is still unset, all of them are listed on stderr and the
        process exits with status 1.

        """
        opts, bits = self.parser.parse_args(arguments)
        config = load_config(opts.config_file)

        for k, v in config.items():
            if getattr(opts, k, None) is None:
                setattr(opts, k, v)

        missing = self._missing_switches(opts)
        if missing:
            sys.stderr.write("error: Missing one or more required "
                             "options:\n\n")
            # report every missing option at once, so the user does not
            # have to fix them one run at a time.
            for switch in missing:
                sys.stderr.write("  '%s' must be specified.\n" % switch)
            sys.exit(1)

        return opts, bits

    def error(self, message):
        """Report `message` through the option parser's error handling."""
        self.parser.error(message)

    def run(self, opts, bits):
        """Execute the command; has to be implemented by subclasses."""
        raise NotImplementedError


def load_command(command_name):
    """Return the class in `COMMANDS` whose `name` equals `command_name`.

    When no command matches, a hint is written to stderr and the process
    exits with status 1.

    """
    matching = [cls for cls in COMMANDS if cls.name == command_name]
    if matching:
        # the first match wins, as the commands are searched in order.
        return matching[0]

    sys.stderr.write("'%s' is not a hydna command. "
                     "See 'hydna help'\n" % command_name)
    sys.exit(1)

def usage(code):
    """Print the top-level usage text and exit with status `code`.

    The text lists every command in `COMMANDS`. It is written to stderr
    when `code` is non-zero and to stdout otherwise.

    """
    parser = optparse.OptionParser("%prog <subcommand> [subcommand options]")

    text = ["%s\n" % parser.get_usage(), "Available subcommands:\n\n"]
    for command_cls in COMMANDS:
        summary = command_cls.description
        if summary is None:
            summary = ""
        text.append("  %s%s\n" % (command_cls.name.ljust(10), summary))
    text.append("\nSee '%s help SUBCOMMAND' for specific help.\n"
                % parser.get_prog_name())

    if code:
        stream = sys.stderr
    else:
        stream = sys.stdout

    stream.write("".join(text))
    sys.exit(code)

def dispatch():
    """Run the subcommand named on the command line.

    Without a subcommand the usage text is printed and the process exits
    with status 1. API errors raised while the command runs are reported on
    stderr followed by exit status 1; an interrupt from the keyboard while
    the command runs only emits a newline.

    """
    if len(sys.argv) <= 1:
        usage(1)
        return

    command_name, arguments = sys.argv[1], sys.argv[2:]

    command = load_command(command_name)()
    opts, bits = command.validate(arguments)

    try:
        command.run(opts, bits)
    except KeyboardInterrupt:
        # leave the cursor on a fresh line after the echoed ^C.
        sys.stdout.write("\n")
    except (APIRequestError, APIAuthenticationError):
        failure = sys.exc_info()[1]
        sys.stderr.write("error: %s\n" % failure.msg)
        sys.exit(1)

# commands

class CommandWithAPIKey(Command):
    """Baseclass that adds a switch to supply an API key, and a method
    to validate the key against the API. Keys are validated on every request,
    the method is merely supposed to be a convenience to facilitate
    displaying error messages.

    """
    def __init__(self, *args, **kwargs):
        option = Option('api_key', '-a', '--api-key', required=True)
        # build a list for this instance instead of inserting into the
        # class-level one: mutating the shared list would add the switch
        # again on every instantiation (and leak it into other commands
        # that share the list), which the option parser rejects.
        self.options = [option] + list(self.options)
        super(CommandWithAPIKey, self).__init__(*args, **kwargs)

    def is_valid_api_key(self, api_key):
        """Return True if a ping request made with `api_key` succeeds, and
        False if it raises `APIRequestError` or `APIAuthenticationError`.

        """
        try:
            api_request('/api/1/ping/', method='GET', api_key=api_key)
        except (APIRequestError, APIAuthenticationError):
            return False
        return True


class PushBehaviorCommand(CommandWithAPIKey):
    """Upload behaviors in current directory to Hydna for validation and
    deployment.

    """
    name = 'push'
    description = "Validate and push behavior files to Hydna instance."

    options = [
        Option('domain_name', '-d', '--domain-name', required=True),
    ]

    def run(self, opts, bits):
        """Archive the files below `opts.path` and upload the archive.

        Exits with status 1 when `DEFAULT_SETUP_NAME` is missing from the
        directory, the API key is rejected, the archive exceeds
        `MAX_UPLOAD_SIZE` megabytes or the API reports a failure.

        """
        if not os.path.isfile(os.path.join(opts.path, DEFAULT_SETUP_NAME)):
            sys.stderr.write("No '%s' in '%s'.\n"
                             % (DEFAULT_SETUP_NAME, opts.path))
            sys.exit(1)

        if not self.is_valid_api_key(opts.api_key):
            sys.stderr.write("error: Invalid api key.\n")
            sys.exit(1)

        archive = self.archive(opts.path)

        # measure the buffer itself rather than relying on where the zip
        # writer happened to leave the file position.
        if len(archive.getvalue()) > MAX_UPLOAD_SIZE * 1024 * 1024:
            sys.stderr.write("Behaviors may not be larger than %s MB "
                             "in size\n" % MAX_UPLOAD_SIZE)
            sys.exit(1)

        response = self.push_behavior(opts.domain_name, opts.api_key, archive)

        # a non-zero status signals failure.
        if response['status']:
            if 'be_validation_error' in response:
                sys.stderr.write("error: Files did not pass validation:\n\n")
                sys.stderr.write("  %s\n" % response['be_validation_error'])
                sys.exit(1)

            sys.stderr.write("fatal: Unexpected error.\n")
            sys.exit(1)

        print("Behaviors deployed.")

    def push_behavior(self, domain_name, api_key, archive):
        """Upload `archive` as the behaviors of `domain_name` and return the
        decoded API response.

        """
        path = '/api/1/domains/%s/behaviors/' % domain_name
        files = { 'file': archive }
        def upload_callback():
            print("Deploying to domain ...")

        return api_request(path, 'POST', files=files,
                           accepted_callback=upload_callback, api_key=api_key)

    def archive(self, path):
        """Create a zip-archive of `path` and return it as an in-memory
        buffer named `DEFAULT_UPLOAD_NAME`.

        Files and directories whose name starts with '.' or '_' are left
        out, including everything below such directories. Entries are
        stored under their path relative to `path`.

        """
        buf = StringIO()
        buf.name = DEFAULT_UPLOAD_NAME

        separators = os.sep + (os.altsep or '')

        files = []
        for root, dirnames, filenames in os.walk(path):
            # prune in place so os.walk never descends into hidden
            # directories; merely skipping the directory itself would
            # still visit (and archive) its subdirectories.
            dirnames[:] = [d for d in dirnames if d[0] not in '._']

            # path of `root` relative to `path`, correct whether or not
            # `path` ends with a separator.
            relative_root = root[len(path):].lstrip(separators)

            for filename in filenames:
                if filename[0] in '._':
                    continue

                files.append((os.path.join(root, filename),
                              os.path.join(relative_root, filename)))

        zf = zipfile.ZipFile(buf, 'w')
        try:
            compressed, total = 0, len(files)
            self.show_progress(compressed, total)
            for full_path, archive_name in files:
                zf.write(full_path, archive_name)
                compressed = compressed + 1
                self.show_progress(compressed, total)

            # finish the progress line.
            sys.stdout.write("\n")
        finally:
            # closing writes the central directory; do it even on errors so
            # the writer is not left dangling.
            zf.close()
        return buf

    def show_progress(self, compressed, total):
        """Redraw the single-line compression progress indicator."""
        msg = "\rCompressing files: %s/%s" % (compressed, total)
        # pad to 80 columns to wipe leftovers of a longer previous line.
        sys.stdout.write('%s%s\r' % (msg, ' ' * (80 - len(msg))))
        sys.stdout.flush()


class RestartInstanceCommand(CommandWithAPIKey):
    """Restart an instance, dropping all connections and clearing state."""
    name = 'restart'
    description = "Restart and reset Hydna instance."

    options = [
        Option('domain_name', '-d', '--domain-name', required=True),
    ]

    def run(self, opts, bits):
        """Ask the API to restart the instance of `opts.domain_name`.

        Exits with status 1 when the API key is rejected or the API
        answers with a non-zero status.

        """
        key = opts.api_key
        if not self.is_valid_api_key(key):
            sys.stderr.write("error: Invalid api key.\n")
            sys.exit(1)

        result = api_request('/api/1/domains/%s/restart/' % opts.domain_name,
                             'POST', api_key=key)

        if not result['status']:
            print("Instance restarted.")
            return

        sys.stderr.write("fatal: Unexpected error.\n")
        sys.exit(1)


class DummyCommand(CommandWithAPIKey):
    """Dummy command used while debugging. Please Ignore."""
    name = 'dummy'
    description = 'Will be removed after beta.'

    options = [
        Option('domain_name', '-d', '--domain-name', required=True),
    ]

    def run(self, opts, bits):
        """Print the resolved domain name, path and API key, one per line."""
        for attribute in ('domain_name', 'path', 'api_key'):
            print(getattr(opts, attribute))


class HelpCommand(Command):
    """Show the general usage text, or the help of a single subcommand."""
    name = 'help'

    description = "Display help for subcommand."

    def run(self, opts, bits):
        """Print help for the subcommand named by `bits[0]`; without any
        positional argument, print the general usage text and exit.

        """
        if len(bits) == 0:
            usage(0)

        requested = load_command(bits[0])
        requested().parser.print_help()


# all available subcommands. `load_command` resolves names against this
# tuple and `usage` lists the commands in this order.
COMMANDS = (PushBehaviorCommand, RestartInstanceCommand, DummyCommand,
            HelpCommand)

# http stuff

def http_connection_cls(for_https=False, for_file_transfer=False):
    """Generate and return a class that is tailored for the needs indicated by
    `for_https` and `for_file_transfer`.

    `for_https` selects `httplib.HTTPSConnection` over
    `httplib.HTTPConnection`. With `for_file_transfer` set, a subclass is
    returned that sends the request body in chunks of `CHUNK_SIZE` bytes
    and prints upload progress to stdout; otherwise the plain httplib class
    is returned.

    """
    if for_https:
        base_cls = httplib.HTTPSConnection
    else:
        base_cls = httplib.HTTPConnection

    if not for_file_transfer:
        return base_cls

    class HTTPConnectionWithProgress(base_cls):
        CHUNK_SIZE = 2048

        def endheaders(self, message_body=None):
            # end headers, but do _not_ send the `message_body`
            base_cls.endheaders(self)

            # override the default send-fun with our local version that
            # displays a progress bar. this happens after the headers are
            # out, so only the body is sent with progress.
            self.send = self.send_with_progress

            # now we can send the body
            if message_body is not None:
                self.send(message_body)

        def send_with_progress(self, data):
            """Send `data` chunk by chunk, updating the progress display
            after each chunk. Returns silently if the peer closes the
            connection (EPIPE) before everything has been sent.

            """
            total_size = len(data)

            if self.sock is None:
                self.connect()

            for start in range(0, total_size, self.CHUNK_SIZE):
                end = min(start + self.CHUNK_SIZE, total_size)
                try:
                    self.sock.sendall(data[start:end])
                except socket.error:
                    e = sys.exc_info()[1]
                    # read the error code from `args`; indexing the
                    # exception object itself is deprecated.
                    if e.args and e.args[0] == errno.EPIPE:
                        # the server closed the connection before the entire
                        # request was sent. this is probably because the
                        # server discarded the request on merits of lacking
                        # credentials.
                        return
                    raise

                self.show_progress(end, total_size)

            self.done(total_size)

        def show_progress(self, bit, total):
            """Redraw the single-line upload indicator for `bit` of `total`
            bytes sent.

            """
            percentage = int(100 * (bit / float(total)))

            msg = "\rUploading files: %s%%" % str(percentage).rjust(3)
            # pad to 80 columns to wipe leftovers of a previous line.
            sys.stdout.write('%s%s\r' % (msg, ' ' * (80 - len(msg))))
            sys.stdout.flush()

        def done(self, total_size):
            """Finish the progress line once the whole body is sent."""
            sys.stdout.write("\n")

    return HTTPConnectionWithProgress

def encode_multipart_formdata(params):
    """Return a tuple holding a content-type, including the boundary, and the
    values encoded for use as multipart form data.

    `params` maps field names to values. A value that offers `getvalue()`
    or `read()` is sent as a file upload: its content is taken from
    `getvalue()` when available (the whole buffer, regardless of the
    current position) and from `read()` otherwise, and the base name of its
    `name` attribute is used as filename, falling back to the field name.
    Any other value is sent as a plain field, formatted with `%s`.

    """
    # hex digits and dashes only, so the boundary can be put in the
    # Content-Type header without quoting.
    boundary = '----BOUNDARY%s' % uuid.uuid4().hex
    parts = []

    for key, value in params.items():
        # every part opens with the delimiter line, terminated by CRLF.
        parts.append('--%s\r\n' % boundary)

        getvalue = getattr(value, 'getvalue', None)
        read = getattr(value, 'read', None)

        if getvalue is not None or read is not None:
            filename = os.path.basename(getattr(value, 'name', key))
            if getvalue is not None:
                data = getvalue()
            else:
                data = read()

            parts.append('Content-Disposition: form-data; name="%s"; '
                         'filename="%s"\r\n' % (key, filename))
            parts.append('Content-Type: application/octet-stream\r\n')
            parts.append('\r\n')
            parts.append(data)
            parts.append('\r\n')
        else:
            parts.append('Content-Disposition: form-data; name="%s"' % key)
            parts.append('\r\n\r\n%s\r\n' % value)

    # the closing delimiter ends the whole body, so it is written exactly
    # once, after the last part.
    parts.append('--%s--\r\n' % boundary)

    content_type = 'multipart/form-data; boundary=%s' % boundary

    return content_type, ''.join(parts)

def request(url, method='GET', params=None, files=None, headers=None):
    """Perform a single HTTP(S) request and return the response object.

    `params` are sent as multipart form data together with `files` when
    files are given, as an urlencoded body for POST, PUT and PATCH, and as
    part of the query string for every other method. `files` maps field
    names to file-like objects and requires `method` to be POST or PUT
    (`ValueError` otherwise). The dictionaries passed in are not modified.

    The entire response body is read and stored on the returned response as
    the attribute `data`; the connection is closed before returning.

    """
    o = urlparse.urlparse(url)

    # work on copies so that the caller's dictionaries stay untouched.
    params = dict(params or {})
    files = dict(files or {})
    headers = dict(headers or {})

    # keep whatever query string the url already carries.
    query = o.query

    if files:
        if method not in ('POST', 'PUT'):
            raise ValueError("Files can only be sent using POST or PUT.")
        params.update(files)
        ctype, body = encode_multipart_formdata(params)
        headers['Content-type'] = ctype
        headers['Content-length'] = str(len(body))
    elif method in ('POST', 'PUT', 'PATCH'):
        body = urllib.urlencode(params)
        has_ctype = any(name.lower() == 'content-type' for name in headers)
        if body and not has_ctype:
            # tell the server how the body is encoded.
            headers['Content-type'] = 'application/x-www-form-urlencoded'
    else:
        # methods without a body carry their parameters in the query string.
        body = ''
        encoded = urllib.urlencode(params)
        if encoded:
            if query:
                query = '%s&%s' % (query, encoded)
            else:
                query = encoded

    # request target: path, path parameters and query, without scheme/host.
    target = urlparse.urlunparse(('', '', o.path or '/', o.params, query, ''))

    c = http_connection_cls(o.scheme == 'https', bool(files))(o.netloc)
    try:
        c.request(method, target, body, headers=headers)

        response = c.getresponse()
        response.data = response.read()
    finally:
        # release the socket even when the request fails midway.
        c.close()

    return response

def _decode_api_response(url, data):
    """Decode the JSON document `data`, raising `APIRequestError` for `url`
    when it is malformed.

    """
    try:
        return json.loads(data)
    except ValueError:
        raise APIRequestError(url, "Invalid response from API.")


def api_request(path, method, params=None, files=None, headers=None,
                task_id=None, accepted_callback=None, api_key=None,
                poll_interval=5.0):
    """Simple wrapper around `request()` that handles long-running jobs
    through task polling. `accepted_callback` is a function that is called
    with no arguments the first time a long-running process has been accepted.

    `path` is joined onto `API_HOST`. `api_key` and `task_id`, when given,
    are sent as the `X-Api-Key` and `X-Task-Id` headers. As long as the API
    answers 202 Accepted, the request is repeated for the returned task id,
    without the files, every `poll_interval` seconds.

    Returns the decoded JSON body of a 200 response. Raises
    `APIAuthenticationError` on 403, and `APIRequestError` on any other
    status, on a malformed JSON body and on an accepted response that lacks
    a task id.

    """
    url = urlparse.urljoin(API_HOST, path)
    # copy, so the caller's dictionary is never modified.
    headers = dict(headers) if headers else {}

    if api_key is not None:
        headers['X-Api-Key'] = api_key

    # poll in a loop rather than by recursion, so that a long-running task
    # cannot exhaust the interpreter's recursion limit.
    while True:
        if task_id is not None:
            # purge everything associated with multipart submissions to avoid
            # multiple uploads when polling a task by id.
            files = None
            for key in ('Content-type', 'Content-length'):
                headers.pop(key, None)

            headers['X-Task-Id'] = task_id

        r = request(url, method, params, files, headers)

        if r.status != httplib.ACCEPTED:
            break

        obj = _decode_api_response(url, r.data)

        task_id = None
        if isinstance(obj, dict):
            task_id = obj.get('task_id')

        if task_id is None:
            raise APIRequestError(url, "No task id present.")

        if accepted_callback is not None:
            accepted_callback()

            # remove the callback as it's only supposed to be called when the
            # task has been accepted.
            accepted_callback = None

        time.sleep(poll_interval)

    if r.status == httplib.OK:
        return _decode_api_response(url, r.data)

    if r.status == httplib.FORBIDDEN:
        raise APIAuthenticationError()

    raise APIRequestError(url, "API request error.")

# run the command-line interface only when executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    dispatch()
