#!/usr/bin/env python3

import sys
import os


def shellquote(args):
    """Escape a list of strings or bytestrings, returning a single string
    appropriate to pass to a shell. Works with bash and zsh.

    Each argument is quoted on its own and the results are joined with
    single spaces:

    * bytes arguments are first decoded with os.fsdecode();
    * an empty argument becomes '' so that it survives word splitting;
    * an argument containing control characters or surrogate escapes
      (bytes that could not be decoded) is written as a dollar-quoted
      string, $'...', using backslash escapes;
    * any other argument containing a character outside word characters
      and @%+=:,./- is wrapped in single quotes;
    * everything else is passed through unchanged.
    """

    import re

    # Characters that have a short backslash escape in a dollar-quoted
    # string on the shell:
    escape_codes = {'\\': R'\\',
                    "'": R"\'",
                    '\a': R'\a',
                    '\b': R'\b',
                    '\f': R'\f',
                    '\n': R'\n',
                    '\r': R'\r',
                    '\t': R'\t',
                    '\v': R'\v'}
    # Remaining control chars below byte 32 get a two-digit hex escape. NUL
    # is deliberately included here rather than written as \0: an octal
    # escape would absorb any octal digits that follow it in the argument.
    for i in range(32):
        escape_codes.setdefault(chr(i), r'\x{0:02x}'.format(i))
    # The delete char, byte 127:
    escape_codes['\x7f'] = r'\x7f'
    # Bytes > 127 that could not be decoded appear as surrogate escapes
    # U+DC80..U+DCFF; map each one back to the byte it stands for:
    for i in range(128, 256):
        escape_codes[chr(0xdc00 + i)] = r'\x{0:02x}'.format(i)

    # Control characters, or bytes that are invalid for the filesystem encoding:
    dollar_quote_required = re.compile(r'[\u007f]|[\udc80-\udcff]|[\u0000-\u001f]')

    # Chars we will escape if dollar quotes are used
    dollar_quote_to_escape = re.compile(r'[\\\']|' + dollar_quote_required.pattern)

    # Characters that need to be quoted normally:
    normal_quote_required = re.compile(r'[^\w@%+=:,./-]')

    # Loop over the args, quoting and escaping them as necessary:
    quoted_args = []
    for s in args:
        if isinstance(s, bytes):
            s = os.fsdecode(s)
        if not s:
            # An unquoted empty string would simply disappear from the
            # command line, shifting every later argument.
            s = "''"
        elif dollar_quote_required.search(s):
            s = "$'" + dollar_quote_to_escape.sub(lambda m: escape_codes[m.group()], s) + "'"
        elif normal_quote_required.search(s):
            # Close the quote, emit a double-quoted single quote, reopen:
            s = "'" + s.replace("'", "'\"'\"'") + "'"
        quoted_args.append(s)
    return ' '.join(quoted_args)


def get_command():
    """Work out the user's command from sys.argv and return it as bytes.

    With '--script' as the first argument and at least one argument after
    it, the argument following '--script' is taken verbatim: it is expected
    to be a single, already quoted command line. Otherwise every argument
    after the program name is quoted as necessary with shellquote() and
    joined into one command line.
    """
    argv = sys.argv
    already_quoted = len(argv) > 2 and argv[1] == '--script'

    if already_quoted:
        # Entered into the shell as-is:
        text = argv[2]
    else:
        # Raw arguments - quote them before they reach the shell:
        text = shellquote(argv[1:])

    return os.fsencode(text)


def run_command(command):
    """Launch a shell and inject the user's command into it. This is done with
    a pseudo-tty using pexpect, so that the shell does everything it would if
    the user typed the command themselves.

    command is the command line to send to the shell; if it is empty, the
    shell is started and nothing is sent.

    Returns the shell's exit status, or 128 plus the signal number if the
    shell has no exit status because it was terminated by a signal.
    """

    import fcntl
    import select
    import signal
    import struct
    import termios

    import pexpect

    # pexpect does not forward sigwinch, so we have to do that ourselves.
    # TIOCGWINSZ fills in four unsigned shorts; the first two are the row
    # and column counts.
    def sigwinch_passthrough(sig, data):
        s = struct.pack('HHHH', 0, 0, 0, 0)
        a = struct.unpack('HHHH', fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, s))
        child.setwinsize(a[0], a[1])

    # Start the shell process, falling back to /bin/sh if SHELL is unset or
    # empty rather than handing None to pexpect:
    child = pexpect.spawn(os.getenv('SHELL') or '/bin/sh')
    child.delaybeforesend = 0

    # The child pty does not start out with our terminal's dimensions, so
    # copy them across once now instead of waiting for the first resize.
    try:
        sigwinch_passthrough(None, None)
    except OSError:
        # stdout is not a terminal; leave the child at pexpect's default size.
        pass

    # Connect SIGWINCH:
    signal.signal(signal.SIGWINCH, sigwinch_passthrough)

    # Wait until the shell has output something, hopefully its prompt, so that
    # we don't write to it prior to the prompt being displayed:
    poll = select.poll()
    poll.register(child.child_fd, select.POLLIN)
    poll.poll()

    # Input the command, if any:
    if command:
        child.sendline(command)

    # Pass control to the user until they quit:
    child.interact(escape_character=None)
    child.close()
    if child.exitstatus is None:
        return 128 + child.signalstatus
    else:
        return child.exitstatus


def main():
    """Run the user's command in a shell and exit with the shell's status.

    If anything goes wrong, the traceback is printed and the program waits
    for a line of input so the user can read it, then exits with status 1
    so that callers can tell the run failed.
    """
    try:
        command = get_command()
        status = run_command(command)
    except Exception:
        import traceback
        traceback.print_exc()
        try:
            # Hold open so the user can see the traceback:
            input()
        except (Exception, KeyboardInterrupt):
            # Best effort only: stdin may be closed, or the user may
            # dismiss the prompt with Ctrl-C.
            pass
        status = 1
    sys.exit(status)


# Run main() only when this file is executed as a script, not when imported:
if __name__ == '__main__':
    main()
