#!python

import argparse
import badbyte
from badbyte.utils.colors import RED, GREEN, RST
from badbyte.utils.functions import unhexify
from badbyte.utils.functions import analyze, generate_characters, get_cyclic_alphabet
from argparse import RawTextHelpFormatter
from pwnlib.util import cyclic
import sys
import string


# Startup banner. A raw f-string keeps the ASCII-art backslashes literal:
# in a normal string literal sequences such as "\ ", "\_" and "\|" are
# invalid escapes and trigger a warning on recent Python versions. The
# replacement fields ({RED}, {GREEN}, {RST}, version) are still interpolated.
print(rf"""

{RED}  _               _{GREEN} _           _       {RST}
{RED} | |__   __ _  __| {GREEN}| |__  _   _| |_ ___ {RST}
{RED} | '_ \ / _` |/ _` {GREEN}| '_ \| | | | __/ _ \{RST}
{RED} | |_) | (_| | (_| {GREEN}| |_) | |_| | ||  __/{RST}
{RED} |_.__/ \__,_|\__,_{GREEN}|_.__/ \__, |\__\___|{RST}
                        {GREEN}   |___/             {RST}
                        by C3l1n v{badbyte.__version__}
""")

# Command-line interface. RawTextHelpFormatter preserves the manual line
# breaks and indentation of the multi-line description below.
parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description=f"""
Command:
  {GREEN}generate{RST} - generate pattern with all characters to search for bad/encoded chars. Optional args: pre, post, bad.
  {GREEN}parse{RST} - parse memory dumped after vuln trigger and analyze it. Optional args: pre, post, custom. Use custom when
          you provided bad characters in generation.
  {GREEN}cycle-gen{RST} - generate pattern to search for offset (i.e. useful when searching offset from where EIP is popped).
          Optional args: bad, length, uniqlength.
  {GREEN}offset{RST} - Search for pattern offset in generated pattern. Please supply same optional args: bad, length, uniqlength
          as you supplied in cycle-gen. Additional required argument: pattern.
""")
# Only the first character of the mode is inspected by the dispatcher, so any
# prefix/abbreviation starting with g, p, c or o selects the matching command.
parser.add_argument("mode", help="g[enerate] | p[arse] | c[ycle-gen] | o[ffset]")
parser.add_argument("-c", "--custom", action='store_true',
                    help="use custom payload in parse (useful when you try to search for subsequent bad chars).")
parser.add_argument("--pre", type=str, default="BAD_START", help="payload prefix - default BAD_START")
parser.add_argument("--post", type=str, default="BAD_STOP", help="payload postfix - default BAD_STOP")
parser.add_argument("--bad", type=str, default="", help="Banned characters as hexstring i.e. '3D 0D 0A'")
parser.add_argument("--length", "-l", type=int, help="Cyclic pattern length.")
parser.add_argument("--pattern", "-p", type=str, help="Part of pattern to search offset for.")
parser.add_argument("--uniqlength", "-u", default=4, type=int,
                    help="Minimal number of characters from the cyclic pattern needed to determine offset.")

# Parsed at module level because validate_args_cyclic() and the __main__
# block both read the resulting namespace.
args = parser.parse_args()

def validate_args_cyclic(search_offset=False):
    """Abort with usage help when cyclic-pattern arguments are missing.

    Reads the module-level ``args`` namespace and ``parser``.

    Args:
        search_offset: When true, ``args.pattern`` is required in addition
            to ``args.length`` and ``args.uniqlength``.

    Raises:
        SystemExit: With a non-zero status if a required argument is
            missing, after printing the help text and an error message.
    """
    missing_pattern = search_offset and args.pattern is None
    if args.length is None or args.uniqlength is None or missing_pattern:
        parser.print_help(sys.stdout)
        print("\n\nError - you must provide -u and -l to generate pattern or search offset" +
              (" and -p to search for offset" if search_offset else ""))
        # Missing arguments are a usage error, so signal failure to the caller.
        sys.exit(1)


if __name__ == "__main__":
    # Byte values the user already knows to be bad, decoded from --bad.
    bad = list(unhexify(args.bad)) if args.bad != "" else []
    # latin1 maps code points 0-255 one-to-one onto single bytes.
    PREFIX = args.pre.encode("latin1")
    POSTFIX = args.post.encode("latin1")
    payload = generate_characters(PREFIX, POSTFIX, bad)

    # Only the first character selects the command; slicing (rather than
    # indexing) lets an empty mode string fall through to the help output.
    mode = args.mode[:1]

    if mode == "g":
        print(f"-> Known bad characters: {RED}{bytes(bad)} {RST}")
        print(payload)
    elif mode == "c":
        validate_args_cyclic()
        alphabet = get_cyclic_alphabet(args.bad)
        g = cyclic.cyclic_gen(alphabet=alphabet, n=args.uniqlength)
        print(f"{GREEN}Pattern:{RST} {g.get(args.length)}")
    elif mode == "o":
        validate_args_cyclic(search_offset=True)
        alphabet = get_cyclic_alphabet(args.bad)
        g = cyclic.cyclic_gen(alphabet=alphabet, n=args.uniqlength)
        # The pattern is regenerated with the same parameters before the
        # lookup; presumably find() only searches what get() has produced
        # so far - confirm against the pwnlib documentation.
        print(f"{GREEN}Pattern:{RST} {g.get(args.length)}")
        # -1 is treated as "not found"; otherwise element 0 of the result
        # is reported as the offset.
        found = g.find(args.pattern)
        if found == -1:
            print(f"{GREEN}Offset:{RST} {RED}NOT FOUND{RST}")
        else:
            print(f"{GREEN}Offset:{RST} {found[0]}")
    elif mode == "p":
        print(f"{GREEN}Please paste hexdump (every non whitespace is treated as hexstring). "
              f"Press CTRL+D to end.{RST}")
        # Collect input lines until EOF, then join once; input() strips the
        # trailing newline, so the lines are concatenated back to back.
        dump_lines = []
        try:
            while True:
                dump_lines.append(input())
        except EOFError:
            pass
        hexdump = "".join(dump_lines)
        if args.custom:
            payload = input(f"{GREEN}Enter payload as python bytestring {RST}: ")
            # Accept a pasted bytes literal: drop the b'...' / b"..." wrapper.
            if payload.startswith(("b'", 'b"')):
                payload = payload[2:-1]
            # Interpret escapes such as \x41, then convert to raw bytes.
            payload = payload.encode('utf-8').decode("unicode_escape").encode("latin1")
        analyze(hexdump, PREFIX, POSTFIX, payload)
    else:
        parser.print_help()
