#!/usr/bin/env python
"""Customizable TCP fuzzer."""

from __future__ import print_function

import re
import sys
import argparse
import socket
from time import sleep


# -------------------------------------------------------------------------------------------------
# GLOBALS
# -------------------------------------------------------------------------------------------------

# Payload layout of every round: PREFIX + CHAR * <multiplier> + SUFFIX
# (defaults for the -c, -p and -s options).
CHAR = "A"
PREFIX = SUFFIX = ""
# Length of the first buffer (-l) and how much it grows per round (-m).
INIT_MULTIPLIER = 100
ROUND_MULTIPLIER = 100
# Seconds: receive timeout (-t) and pause between rounds (-d).
TIMEOUT = 30.0
DELAY = 1.0


# -------------------------------------------------------------------------------------------------
# FUNCTIONS
# -------------------------------------------------------------------------------------------------


def b2str(data):
    """Convert bytes into string type.

    The data is decoded as UTF-8. If it is not valid UTF-8, it is decoded as
    Latin-1, which maps every byte value to a character, so decoding a bytes
    object never fails. A leading UTF-8 BOM is kept (as U+FEFF).

    :param data: bytes to decode.
    :returns: the decoded text.
    """
    try:
        return data.decode("utf-8")
    except UnicodeDecodeError:
        # "utf-8-sig" and "ascii" reject everything that "utf-8" rejects, so
        # trying them here could never succeed; Latin-1 is the only fallback
        # that can.
        return data.decode("latin-1")


def connect(host, port):
    """Open an IPv4 TCP connection to ``host``:``port``.

    :param host: hostname or IPv4 address, resolved via gethostbyname().
    :param port: TCP port number.
    :returns: ``(socket, None)`` on success, ``(None, error)`` otherwise; on a
        resolve/connect failure the half-created socket is closed first.
    """
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    except socket.error as exc:
        return (None, exc)

    failure = None
    try:
        ip_addr = socket.gethostbyname(host)
    except socket.gaierror as exc:
        failure = exc
    else:
        try:
            sock.connect((ip_addr, port))
        except socket.error as exc:
            failure = exc

    if failure is not None:
        sock.close()
        return (None, failure)
    return (sock, None)


def print_crashlog(prefix, suffix, char, buff):
    """Print the crash report: the buffer length, its character and the full payload."""
    size = len(buff)
    print('\nRemote service (most likely) crashed at %s bytes of "%s"' % (size, char))
    print("Payload sent:\n" + prefix + buff + suffix)


def send(s, data):
    """Send data to socket, terminated by CRLF.

    :param s: connected socket.
    :param data: text (encoded as UTF-8) or bytes (sent as-is) to transmit;
        "\\r\\n" is appended.
    :returns: ``(True, None)`` on success, or ``(False, error)`` after closing
        the socket when the send fails.
    """
    # Only encode text; bytes (including Python 2 ``str``) go out untouched.
    if not isinstance(data, bytes):
        data = data.encode("utf-8")
    try:
        # The terminator must be bytes as well: bytes + str raises TypeError on
        # Python 3. sendall() keeps writing until the whole (possibly large)
        # payload is out, whereas send() may transmit only part of it.
        s.sendall(data + b"\r\n")
    except socket.error as msg:
        s.close()
        return (False, msg)

    return (True, None)


def receive(s, timeout, bufsize=1024):
    """Read one newline terminated line from a connected socket.

    Chunks of up to ``bufsize`` bytes are read until the data ends in a line
    feed or carriage return, the peer closes the connection, or the timeout
    expires after at least some data has already arrived (a reply that lacks
    a trailing newline).

    :param s: connected socket; its timeout is set to ``timeout``.
    :param timeout: seconds to wait for each chunk.
    :param bufsize: maximum number of bytes per ``recv()`` call.
    :returns: ``(True, text)`` with trailing CR/LF characters removed, or
        ``(False, error)`` when nothing could be read.
    """
    raw = b""
    s.settimeout(timeout)

    while True:
        try:
            chunk = s.recv(bufsize)
        except socket.timeout as err:
            # Sometimes a newline is missing at the end: such a reply ends in
            # silence, so if data already arrived treat it as complete.
            if raw:
                break
            return (False, err)
        except socket.error as err:
            return (False, err)
        if not chunk:
            # recv() returns no data once the peer has closed the connection.
            if not raw:
                return (False, "upstream connection is gone while receiving")
            break
        raw += chunk
        # Newline (or carriage return) terminates the read request
        if raw.endswith((b"\n", b"\r")):
            break

    # Decode once at the end so a multi-byte UTF-8 character split across two
    # recv() calls is not mis-decoded chunk by chunk.
    return (True, b2str(raw).rstrip("\r\n"))


# -------------------------------------------------------------------------------------------------
# ARGS
# -------------------------------------------------------------------------------------------------


def _args_check_init(value):
    """Check argument for valid init value."""
    for comm in value.split(","):
        if comm.find(":") == -1:
            raise argparse.ArgumentTypeError('"%s" is an invalid init value.' % value)
    return value


def _args_check_port(value):
    """Check argument for valid port number."""
    min_port = 1
    max_port = 65535

    try:
        intvalue = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError('"%s" is an invalid port number.' % value)

    if intvalue < min_port or intvalue > max_port:
        raise argparse.ArgumentTypeError('"%s" is an invalid port number.' % value)
    return intvalue


def get_args():
    """Retrieve command line arguments.

    :returns: the ``argparse.Namespace`` parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Customizable TCP fuzzing tool to test for remote buffer overflows.",
        epilog="""\
example:

  The following example illustrates how to use the initial communication by:
      1. Expecting the POP3 server banner
      2. Sending 'USER bob'
      3. Expecting a welcome message
  Additionally, the payload is prefixed with 'PASS ' before it is sent,
  so that the actual fuzzing is done on the password:
     1. Prefix payload with 'PASS '
     2. Send payload
  Lastly, in order to also close the connection, the '-e' option is used
  (which works exactly like '-i') to send data after the payload:
     1. Expect any response from password payload
     2. Terminate the connection via QUIT
     3. Do not expect a follow up response

     $ fuzza -i ':.*POP3.*,USER bob:.*welcome.*' -e ':.*,QUIT:' -p 'PASS ' 10.0.0.1 110
""",
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version="%(prog)s 0.3.0 by cytopia",
        help="Show version information.",
    )
    parser.add_argument(
        "-c",
        "--char",
        metavar="char",
        required=False,
        default=CHAR,
        type=str,
        help='Buffer character to send as payload. Default: "' + CHAR + '"',
    )
    parser.add_argument(
        "-p",
        "--prefix",
        metavar="str",
        required=False,
        default=PREFIX,
        type=str,
        help="Prefix string to prepend to buffer. Empty by default.",
    )
    parser.add_argument(
        "-s",
        "--suffix",
        metavar="str",
        required=False,
        default=SUFFIX,
        type=str,
        help="Suffix string to append to buffer. Empty by default.",
    )
    parser.add_argument(
        "-l",
        "--length",
        metavar="int",
        required=False,
        default=INIT_MULTIPLIER,
        type=int,
        help="Initial length to concat buffer string with x*char. Default: " + str(INIT_MULTIPLIER),
    )
    parser.add_argument(
        "-m",
        "--multiply",
        metavar="int",
        required=False,
        default=ROUND_MULTIPLIER,
        type=int,
        help="Round multiplier to concat buffer string with x*char every round. Default: "
        + str(ROUND_MULTIPLIER),
    )
    parser.add_argument(
        "-i",
        "--init",
        metavar="str",
        required=False,
        default=None,
        type=_args_check_init,
        help="""If specified, initializes communication before sending the payload in the form
'<send>:<expect>,<send>:<expect>,...'. Where <send> is the data to be sent
to the server and <expect> is the answer to be received from the server.
Either one of <send> or <expect> can be omitted if you expect something without
having sent data yet or need to send something for which there will not be an
answer. Multiple <send>:<expect> are supported and must be separated by a comma.
Regex supported for <expect> part.""",
    )
    parser.add_argument(
        "-e",
        "--exit",
        metavar="str",
        required=False,
        default=None,
        type=_args_check_init,
        help="""If specified, finalizes communication after sending the payload in the form
'<send>:<expect>,<send>:<expect>,...'. Where <send> is the data to be sent
to the server and <expect> is the answer to be received from the server.
Either one of <send> or <expect> can be omitted if you expect something without
having sent data yet or need to send something for which there will not be an
answer. Multiple <send>:<expect> are supported and must be separated by a comma.
Regex supported for <expect> part.""",
    )
    parser.add_argument(
        "-t",
        "--timeout",
        metavar="float",
        required=False,
        default=TIMEOUT,
        type=float,
        help="Timeout for receiving data before declaring the endpoint as crashed. Default: "
        + str(TIMEOUT),
    )
    parser.add_argument(
        "-d",
        "--delay",
        metavar="float",
        required=False,
        default=DELAY,
        type=float,
        help="Delay in seconds between each round. Default: " + str(DELAY),
    )
    parser.add_argument("host", type=str, help="address to connect to.")
    parser.add_argument("port", type=_args_check_port, help="port to connect to.")
    return parser.parse_args()


# -------------------------------------------------------------------------------------------------
# MAIN ENTRYPOINT
# -------------------------------------------------------------------------------------------------


def _exchange(s, spec, label, timeout, fail, crashed):
    """Run a '<send>:<expect>,...' dialogue (from -i or -e) on a connected socket.

    :param s: connected socket.
    :param spec: comma-separated '<send>:<expect>' items (validated by argparse).
    :param label: "Init" or "Exit"; prefixes the progress output.
    :param timeout: receive timeout in seconds.
    :param fail: called with the error when a send/receive fails (exits the program).
    :param crashed: called when a reply does not match <expect> (exits the program).
    """
    for comm in spec.split(","):
        # Split on the first ':' only, so an <expect> regex may itself contain
        # ':' (a plain split() raised ValueError on such items). As a
        # consequence, <send> cannot contain ':'.
        d_send, d_expect = comm.split(":", 1)
        # Send data?
        if d_send:
            print("%s Sending:  %s" % (label, d_send))
            succ, err = send(s, d_send)
            if not succ:
                fail(err)
        # Expect data?
        if d_expect:
            print("%s Awaiting: %s" % (label, d_expect))
            succ, d_recv = receive(s, timeout, 1024)
            if not succ:
                fail(d_recv)
            print("%s Received: %s" % (label, d_recv))
            # Accept either an exact match or a regex match anywhere in the reply.
            if d_expect != d_recv and not re.search(d_expect, d_recv):
                crashed()


def main():
    """Start the program.

    Each round opens a new connection, runs the optional init dialogue, sends
    prefix + char * multiplier + suffix, runs the optional exit dialogue and
    closes the socket; the multiplier then grows by the round multiplier.
    A communication failure while the multiplier is still at its initial value
    is reported as an error (exit 1); later it is treated as a crash of the
    target (crash log, exit 0). A reply that does not match <expect> is always
    treated as a crash.
    """
    args = get_args()

    char = args.char
    imulti = args.length
    rmulti = args.multiply
    prefix = args.prefix
    suffix = args.suffix
    timeout = args.timeout
    delay = args.delay

    multiplier = imulti
    # Buffer of the most recently sent payload. If the following round cannot
    # connect or initialize, this is the buffer reported in the crash log.
    buff = char * multiplier

    # Both handlers read `multiplier` and `buff` when called (closures), so
    # they always reflect the current round. Neither returns: both exit.
    def crashed():
        """Print the crash log for the current buffer and exit with status 0."""
        print_crashlog(prefix, suffix, char, buff)
        sys.exit(0)

    def fail(err):
        """Handle a communication error for the current round."""
        # Failing before the buffer ever grew means the target was never
        # reachable or the arguments are wrong: an error, not a crash.
        if multiplier == imulti:
            print(err, file=sys.stderr)
            sys.exit(1)
        crashed()

    while True:
        print("------------------------------------------------------------")
        print("%s * %s" % (char, str(multiplier)))
        print("------------------------------------------------------------")

        # Connect
        s, err = connect(args.host, args.port)
        if s is None:
            fail(err)

        # Initial communication
        if args.init is not None:
            _exchange(s, args.init, "Init", timeout, fail, crashed)

        # Send payload
        print('Sending "%s" + "%s"*%s + "%s"' % (prefix, char, multiplier, suffix))
        buff = char * multiplier
        succ, err = send(s, prefix + buff + suffix)
        if not succ:
            fail(err)

        # Exit communication
        if args.exit is not None:
            _exchange(s, args.exit, "Exit", timeout, fail, crashed)

        s.close()
        sleep(delay)
        multiplier += rmulti


if __name__ == "__main__":
    # Ctrl+C stops the fuzzer without a traceback: emit a newline so the
    # shell prompt starts cleanly, then exit with status 1.
    interrupted = False
    try:
        main()
    except KeyboardInterrupt:
        interrupted = True
    if interrupted:
        print("")
        sys.exit(1)
