#!/usr/bin/env python
"""Customizable TCP fuzzer."""

from __future__ import print_function

import re
import sys
import argparse
import socket
from time import sleep


# -------------------------------------------------------------------------------------------------
# GLOBALS
# -------------------------------------------------------------------------------------------------

# Defaults for the command line options defined in get_args().
# Character that is repeated to build the fuzzing buffer (-c/--char).
CHAR = "A"
# String placed in front of the fuzzing buffer (-p/--prefix).
PREFIX = ""
# String placed behind the fuzzing buffer (-s/--suffix).
SUFFIX = ""
# Number of characters in the buffer of the first round (-l/--length).
INIT_MULTIPLIER = 100
# Number of characters the buffer grows by after every round (-m/--multiply).
ROUND_MULTIPLIER = 100
# Seconds to wait for data from the endpoint (-t/--timeout).
TIMEOUT = 30.0
# Seconds to pause between two rounds (-d/--delay).
DELAY = 1.0


# -------------------------------------------------------------------------------------------------
# FUNCTIONS
# -------------------------------------------------------------------------------------------------


def b2str(data):
    """Convert bytes into string type.

    The data is decoded as UTF-8. If it is not valid UTF-8, it is decoded as
    Latin-1 instead, which maps every byte to a character and therefore
    never fails.

    Args:
        data: Bytes to decode.

    Returns:
        The decoded text. A leading UTF-8 byte order mark is kept as the
        character U+FEFF.
    """
    # No "utf-8-sig" or "ascii" attempts in between: whatever is rejected by
    # the strict UTF-8 decoder is rejected by those two codecs as well, so
    # they could never succeed after UTF-8 has failed.
    try:
        return data.decode("utf-8")
    except UnicodeDecodeError:
        return data.decode("latin-1")


def print_crashlog(prefix, suffix, char, buff):
    """Print the crash report for the payload that was sent.

    Args:
        prefix: String that was placed in front of the buffer.
        suffix: String that was placed behind the buffer.
        char: Character(s) the buffer was built from.
        buff: The fuzzing buffer itself; its length is reported.
    """
    crash_size = len(buff)
    print('\nRemote service (most likely) crashed at %s bytes of "%s"' % (crash_size, char))
    sent = prefix + buff + suffix
    print("Payload sent:\n%s" % sent)


def receive(s, timeout, bufsize=1024):
    """Read one newline terminated line from a connected socket.

    Data is read until a chunk ends with "\\n" or "\\r", or until the peer
    closes the connection after having sent something.

    Args:
        s: Connected socket to read from.
        timeout: Seconds to wait for each chunk; applied via s.settimeout().
        bufsize: Maximum number of bytes requested per recv() call.

    Returns:
        A tuple (True, text) with the received text, stripped of trailing
        "\\r" and "\\n" characters, or a tuple (False, reason) where reason
        is the socket error (this includes a timeout) or a message saying
        that the connection was closed before any data arrived. Data that
        was received before a socket error is discarded.
    """
    s.settimeout(timeout)
    chunks = []

    while True:
        try:
            chunk = s.recv(bufsize)
        except socket.error as err:
            return (False, err)
        if not chunk:
            # An empty read means the peer has closed the connection.
            if not chunks:
                return (False, "upstream connection is gone while receiving")
            # Sometimes a newline is missing at the end: use what we have.
            break
        chunks.append(chunk)
        # Newline terminates the read request
        if chunk.endswith((b"\n", b"\r")):
            break

    # Decode only once all bytes are available, so that a multi-byte character
    # that was split across two recv() calls is still decoded correctly.
    data = b2str(b"".join(chunks))
    # Remove trailing newlines (any mix of "\r" and "\n")
    return (True, data.rstrip("\r\n"))


# -------------------------------------------------------------------------------------------------
# ARGS
# -------------------------------------------------------------------------------------------------


def _args_check_init(value):
    """Check argument for valid init value."""
    for comm in value.split(","):
        if comm.find(":") == -1:
            raise argparse.ArgumentTypeError('"%s" is an invalid init value.' % value)
    return value


def _args_check_port(value):
    """Check argument for valid port number."""
    min_port = 1
    max_port = 65535

    try:
        intvalue = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError('"%s" is an invalid port number.' % value)

    if intvalue < min_port or intvalue > max_port:
        raise argparse.ArgumentTypeError('"%s" is an invalid port number.' % value)
    return intvalue


def _args_check_min(cast, minimum):
    """Build an argparse type callable for numbers that must be >= minimum.

    Args:
        cast: Conversion applied to the raw argument string (int or float).
        minimum: Smallest accepted value.

    Returns:
        A callable that converts its argument with cast and raises
        argparse.ArgumentTypeError if the conversion fails or the result is
        smaller than minimum.
    """

    def _check(value):
        try:
            number = cast(value)
        except ValueError:
            number = None
        # "not >=" instead of "<" so that a float NaN is rejected as well.
        if number is None or not number >= minimum:
            raise argparse.ArgumentTypeError(
                '"%s" is invalid, expected %s >= %s.' % (value, cast.__name__, minimum)
            )
        return number

    return _check


def get_args():
    """Retrieve command line arguments.

    Returns:
        The argparse namespace with the attributes char, prefix, suffix,
        length, multiply, init, answer, timeout, delay, host and port.

    Raises:
        SystemExit: Raised by argparse on invalid arguments and after
            printing the help or version text.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Customizable TCP fuzzing tool to test for remote buffer overflows.",
        epilog="""\
example:\n

  The following example illustrates how to use the initial communication by:
      1. Expecting the POP3 server banner
      2. Sending 'USER bob'
      3. Expecting a welcome message
  Additionally before sending the fuzzing characters, it is prepended with 'PASS ',
  so that the actual fuzzing can be done on the password:
      $ fuzza -i ':.*POP3.*,USER bob:.*welcome.*' -p 'PASS ' mail.example.tld 110
""",
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version="%(prog)s 0.2.0 by cytopia",
        help="Show version information.",
    )
    parser.add_argument(
        "-c",
        "--char",
        metavar="char",
        required=False,
        default=CHAR,
        type=str,
        help='Buffer character to send as payload. Default: "' + CHAR + '"',
    )
    parser.add_argument(
        "-p",
        "--prefix",
        metavar="str",
        required=False,
        default=PREFIX,
        type=str,
        help="Prefix string to prepend to buffer. Empty by default.",
    )
    parser.add_argument(
        "-s",
        "--suffix",
        metavar="str",
        required=False,
        default=SUFFIX,
        type=str,
        help="Suffix string to append to buffer. Empty by default.",
    )
    parser.add_argument(
        "-l",
        "--length",
        metavar="int",
        required=False,
        default=INIT_MULTIPLIER,
        # A negative buffer length has no meaning.
        type=_args_check_min(int, 0),
        help="Initial length to concat buffer string with x*char. Default: " + str(INIT_MULTIPLIER),
    )
    parser.add_argument(
        "-m",
        "--multiply",
        metavar="int",
        required=False,
        default=ROUND_MULTIPLIER,
        # With a step below 1 the buffer would never grow and the fuzzer
        # would loop forever.
        type=_args_check_min(int, 1),
        help="Round multiplier to concat buffer string with x*char every round. Default: "
        + str(ROUND_MULTIPLIER),
    )
    parser.add_argument(
        "-i",
        "--init",
        metavar="str",
        required=False,
        default=None,
        type=_args_check_init,
        help="""If specified, initializes communication in the form
'<send>:<expect>,<send>:<expect>,...'. Where <send> is the data to be sent
to the server and <expect> is the answer to be received from the server.
Regex supported for <expect> part.""",
    )
    parser.add_argument(
        "-a",
        "--answer",
        metavar="str",
        required=False,
        default=None,
        type=str,
        help="If specified, will stop if answer is not received from endpoint. Regex supported.",
    )
    parser.add_argument(
        "-t",
        "--timeout",
        metavar="float",
        required=False,
        default=TIMEOUT,
        # socket.settimeout() raises ValueError for negative values.
        type=_args_check_min(float, 0),
        help="Timeout for receiving data before declaring the endpoint as crashed. Default: "
        + str(TIMEOUT),
    )
    parser.add_argument(
        "-d",
        "--delay",
        metavar="float",
        required=False,
        default=DELAY,
        # time.sleep() raises ValueError for negative values.
        type=_args_check_min(float, 0),
        help="Delay in seconds between each round. Default: " + str(DELAY),
    )
    parser.add_argument("host", type=str, help="address to connect to.")
    parser.add_argument("port", type=_args_check_port, help="port to connect to.")
    return parser.parse_args()


# -------------------------------------------------------------------------------------------------
# MAIN ENTRYPOINT
# -------------------------------------------------------------------------------------------------


def _exit_on_failure(first_round, err, prefix, suffix, char, buff):
    """Terminate the program after a socket level failure.

    A failure during the first round is reported as an error on stderr (exit
    code 1). In any later round it is reported as a crash of the remote
    service (exit code 0).
    """
    if first_round:
        print(err, file=sys.stderr)
        sys.exit(1)
    print_crashlog(prefix, suffix, char, buff)
    sys.exit(0)


def _send_line(s, text):
    """Send text followed by CRLF and make sure all of it is transmitted."""
    # The line ending has to be bytes as well: bytes + str raises a TypeError
    # on Python 3. sendall() is used because send() may transmit only a part
    # of a large payload.
    s.sendall(text.encode() + b"\r\n")


def main():
    """Start the program.

    Runs fuzzing rounds until the endpoint is considered crashed. Every round
    opens a new TCP connection, optionally performs the init communication,
    sends prefix + buffer + suffix terminated by CRLF, optionally checks the
    answer, closes the connection, sleeps for the configured delay and then
    enlarges the buffer.

    The function never returns normally. It ends via sys.exit() with code 1
    if the host cannot be resolved or if a socket operation fails during the
    first round, and with code 0 after the crash report has been printed.
    """
    args = get_args()

    char = args.char
    imulti = args.length
    rmulti = args.multiply
    prefix = args.prefix
    suffix = args.suffix
    timeout = args.timeout
    delay = args.delay

    multiplier = imulti
    buff = char * multiplier

    while True:
        # Failures in the first round are errors, later ones count as crash.
        first_round = multiplier == imulti

        print("------------------------------------------------------------")
        print("%s * %s" % (char, str(multiplier)))
        print("------------------------------------------------------------")
        # Create socket
        try:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        except socket.error as msg:
            _exit_on_failure(first_round, msg, prefix, suffix, char, buff)

        # The finally clause closes the socket on every way out of the round,
        # including the sys.exit() calls below.
        try:
            # Get remote IP
            try:
                addr = socket.gethostbyname(args.host)
            except socket.gaierror as msg:
                print(msg, file=sys.stderr)
                sys.exit(1)

            # Connect. Until the payload of this round is sent, buff still
            # holds the payload of the previous round, which is the one to
            # report if the service is gone.
            try:
                s.connect((addr, args.port))
            except socket.error as msg:
                _exit_on_failure(first_round, msg, prefix, suffix, char, buff)

            if args.init is not None:
                for comm in args.init.split(","):
                    # Split at the first colon only, so that the expected
                    # answer (a regex) may contain colons itself.
                    send, expect = comm.split(":", 1)

                    if send:
                        print("Init Sending:  %s" % (send))
                        try:
                            _send_line(s, send)
                        except socket.error as msg:
                            _exit_on_failure(first_round, msg, prefix, suffix, char, buff)
                    if expect:
                        print("Init Awaiting: %s" % (expect))
                        succ, data = receive(s, timeout, 1024)
                        if not succ:
                            _exit_on_failure(first_round, data, prefix, suffix, char, buff)
                        print("Init Received: %s" % (data))
                        if data != expect and not re.search(expect, data):
                            print_crashlog(prefix, suffix, char, buff)
                            sys.exit(0)

            print('Sending "%s" + "%s"*%s + "%s"' % (prefix, char, multiplier, suffix))
            buff = char * multiplier
            try:
                _send_line(s, prefix + buff + suffix)
            except socket.error as msg:
                _exit_on_failure(first_round, msg, prefix, suffix, char, buff)

            if args.answer is not None:
                print('String expected after sending: "%s"' % (args.answer))
                succ, data = receive(s, timeout, 1024)
                if not succ:
                    print(data, file=sys.stderr)
                    print_crashlog(prefix, suffix, char, buff)
                    sys.exit(0)
                print('String received after sending: "%s"' % (data))
                if data != args.answer and not re.search(args.answer, data):
                    print_crashlog(prefix, suffix, char, buff)
                    sys.exit(0)
        finally:
            s.close()

        sleep(delay)
        multiplier = multiplier + rmulti


# Run the fuzzer only when this file is executed as a script.
if __name__ == "__main__":
    # Catch Ctrl+c and exit without error message: only an empty line is
    # printed and the exit code is 1.
    try:
        main()
    except KeyboardInterrupt:
        print()
        sys.exit(1)
