#!/usr/bin/env python
"""
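Interactive full screen command line for a PolyPacket service.

Commands typed into the input buffer are parsed and either dispatched to a
built-in or agent command or sent as a packet; output is shown in the pane
above the input. The input is a BufferControl in a full screen layout that
offers auto completion.
Important is to make sure that there is a `CompletionsMenu` in the layout,
otherwise the completions won't be visible.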
"""
from prompt_toolkit.application import Application
from prompt_toolkit.buffer import Buffer
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.document import Document
from prompt_toolkit.widgets import SearchToolbar, TextArea
from prompt_toolkit.history import FileHistory
from prompt_toolkit.layout.containers import (
    Float,
    FloatContainer,
    HSplit,
    Window,
)
from prompt_toolkit.layout.controls import BufferControl, FormattedTextControl
from prompt_toolkit.layout.layout import Layout
from prompt_toolkit.styles import Style
from prompt_toolkit.layout.menus import CompletionsMenu

from polypacket.protocol import *
from polypacket.polyservice import *
import binascii
import traceback

import sys
import os
import argparse
import time
import re
import builtins


# Parsed command-line arguments (set by run()) and the parser that produced
# them (set by init_args()).
args = None
parser = None
# Placeholders; run() replaces them with the loaded protocol and its service.
PROTOCOL = ''
SERVICE = ''

# Completion word lists most recently built by onInputChange().
compFields = [] 
compVals = []


# Completer offered after the 'connect' command (connection string prefixes).
connect_completer = WordCompleter(['serial:','udp:','tcp:'],ignore_case=True)

# Placeholder; run() replaces it with a completer over the protocol's agent names.
agent_completer = ''

# Placeholder; run() and load_agent() replace it with the command-name completer.
command_completer = ''


# Output pane; printToConsole() appends to its buffer.
outputField = TextArea(style='class:output-win',multiline=True)
outputField.text = "Ctrl-C to quit"

# Shared dictionary reachable from agent code: the handler source generated by
# load_agent() declares `global DataStore`.
DataStore ={}

# Filled by load_agent(): agent command objects and their help strings.
agentCommands = {}
agentCommandHelp = {}

# NOTE(review): not referenced anywhere else in the visible part of this file.
namespace = {}




# Style.
# Colours for the style classes referenced by the widgets and windows below.
style = Style([
    ('output-win', '#00ff00'),
    ('input-win', '#ffffff'),
    ('line',        '#004400'),
    ('status', '#ffffff')
])

def onInputChange(buff):
    """Pick the completer for the input buffer from what has been typed so far.

    Registered as the input buffer's ``on_text_changed`` callback. Only the
    text after the last ';' (the command currently being typed) is looked at:

    * nothing typed yet -> the command-name completer
    * text ends with ',' -> the field names collected for the current packet
    * text ends with a space -> a completer chosen from the first / last word
      (connection prefixes, agent names, agent command arguments, the values
      of a field, or the field names of the packet named by the first word)
    * anything else (in the middle of a word) -> completer left unchanged

    Args:
        buff: buffer whose ``text`` is inspected and whose ``completer`` is
            replaced.
    """
    global compFields
    global compVals

    current = buff.text.split(';')[-1]  # only worry about the latest command

    if not current.strip():
        # Also covers text ending in ';': the segment after it is empty.
        buff.completer = command_completer
        return

    if current.endswith(','):
        buff.completer = WordCompleter(compFields, ignore_case=True)
        return

    if not current.endswith(' '):
        return

    words = current.split()
    firstWord = words[0]
    lastWord = words[-1]
    command = firstWord.lower()
    fieldName = lastWord.replace(':', '')

    if command == 'connect':
        buff.completer = connect_completer

    elif command == 'loadagent':
        buff.completer = agent_completer

    elif command in agentCommands:
        compVals = ['help'] + [arg.name + ':' for arg in agentCommands[command].args]
        buff.completer = WordCompleter(compVals, ignore_case=True)

    elif fieldName in PROTOCOL.fieldIdx:
        # A field name was just typed: offer the values declared for it.
        fieldId = PROTOCOL.fieldIdx[fieldName]
        compVals = [val.name for val in PROTOCOL.fields[fieldId].vals]
        buff.completer = WordCompleter(compVals, ignore_case=True)

    elif '|' in lastWord:
        # Deliberately keep the current completer.
        # NOTE(review): presumably a '|'-joined list of values - confirm.
        pass

    else:
        packetKey = PROTOCOL.getExtName(firstWord)
        if packetKey is not None:
            packetId = PROTOCOL.packetIdx[packetKey]
            packet = PROTOCOL.descFromId(packetId)
            compFields = [field.name + ":" for field in packet.fields]
            buff.completer = WordCompleter(compFields, ignore_case=True)
# The layout
# Command history is kept in a file in the user's home directory.
home_path = os.path.expanduser('~')
our_history = FileHistory(home_path+'/.poly-history')

# Single-line input buffer; onInputChange() swaps its completer as the text changes.
inputBuffer = Buffer( complete_while_typing=False, multiline=False, on_text_changed=onInputChange, history=our_history)
#inputBuffer.text = 'SendCmd src: 45, dst: 32, cmd: 4'

#outputWin = Window(BufferControl(buffer=outputField,focusable=False), height=20, style='class:output-win')
inputWin = Window(BufferControl(buffer=inputBuffer), height=4, style='class:input-win')
# Output pane on top, a one-row separator, then the input window. The
# completions menu is a float positioned at the cursor.
body = FloatContainer(
    content=HSplit([
        #outputWin,
        outputField,
        Window(height=1, char='-', style='class:line'),
        inputWin,
    ]),
    floats=[
        Float(xcursor=True,
              ycursor=True,
              content=CompletionsMenu(max_height=16, scroll_offset=1))
    ]
)


# Key bindings
kb = KeyBindings()


@kb.add('c-c')
def _(event):
    """Ctrl-C: announce the quit, close the service, then exit the application."""
    printToConsole(" Quit application. ")
    SERVICE.close()
    app = event.app
    app.exit()

# @kb.add('enter')
# def _(event):
#     global command_completer
#     new_text =outputField.text + "\n"+ inputBuffer.text
#     outputField.text = new_text
#     inputBuffer.text = '>>> '
#     inputBuffer.completer = command_completer
#     testService(SERVICE)

def accept(buff):
    """Accept handler of the input buffer: run every ';'-separated command.

    The text is read from the module-level ``inputBuffer``. Each command is
    handed to ``parseCommand`` with a 50 ms pause after it, and afterwards the
    command-name completer is restored for the next line.

    Args:
        buff: the buffer that was accepted (not used directly).
    """
    for command in inputBuffer.text.split(';'):
        parseCommand(command)
        time.sleep(0.05)

    inputBuffer.completer = command_completer


#buff.on_text_changed = onInputChange;
# Run accept() when the input line is submitted.
inputBuffer.accept_handler = accept


# The `Application`
# Full-screen application using the layout, key bindings and style defined
# above; the input buffer has focus on start-up.
application = Application(
    layout=Layout(body, focused_element=inputBuffer),
    key_bindings=kb,
    style=style,
    mouse_support=True,
    full_screen=True)

# Initialize the argument parser
def init_args():
    """Build the command-line parser and store it in the global ``parser``.

    Returns:
        The configured ``argparse.ArgumentParser`` (the same object that is
        stored in ``parser``).
    """
    global parser
    # The first positional parameter of ArgumentParser is ``prog``; the text
    # is a description, so pass it by keyword to keep the usage line sane.
    parser = argparse.ArgumentParser(
        description="Tool to generate code and documentation for PolyPacket protocol")
    parser.add_argument('-i', '--input', type=str, help='input file to parse', default="")
    parser.add_argument('-x', '--execute', nargs='+', help='Commands to execute on start, seperate with ;', default="")
    parser.add_argument('-c', '--connection', type=str, help='Connection string ex. tcp:localhost:8020', default="")
    parser.add_argument('-a', '--agent', type=str, help='Specify an agent profile to use its behavior, or "none" to not use any', default="default")
    parser.add_argument('-b', '--bytes', action='store_true', help='shows packet bytes', default=False)
    parser.add_argument('-m', '--meta', action='store_true', help='shows packet meta data', default=False)
    return parser
    


def saveBufferToFile(fileName):
    """Write the current contents of the output pane to *fileName*.

    The file is opened in text write mode, so an existing file is overwritten.
    A confirmation line is printed to the output pane afterwards.

    Args:
        fileName: path of the log file to write.
    """
    # ``with`` closes the file even when the write fails.
    with open(fileName, "w") as file:
        file.write(outputField.text)
    printToConsole(" Log saved as: " + fileName)

def printToConsole(text, newLine=True, replace=False):
    """Append *text* to the output pane.

    Args:
        text: string to add.
        newLine: when true, a line break is inserted before *text*.
        replace: when true, the last line currently in the pane is removed
            first, so *text* takes its place.
    """
    new_text = outputField.text

    if replace:
        lastBreak = new_text.rfind('\n')
        # Without any line break the whole pane is the "last line"; slicing
        # with rfind()'s -1 would only chop off the final character.
        new_text = new_text[:lastBreak] if lastBreak != -1 else ''

    if newLine:
        new_text += "\n"

    new_text += text

    # Replace the buffer's document with the cursor placed at the end of the
    # new text.
    outputField.buffer.document = Document(
        text=new_text, cursor_position=len(new_text))

def cmdConnect(SERVICE, tokens):
    """Handle ``connect <connection string>``.

    Args:
        SERVICE: service whose ``connect`` is called.
        tokens: whitespace-split command line; ``tokens[1]`` is the
            connection string.
    """
    if len(tokens) < 2:
        # A bare 'connect' used to raise IndexError out of the accept handler.
        printToConsole(" connect requires a connection string, ex. tcp:localhost:8020")
        return
    SERVICE.connect(tokens[1])

def cmdSaveLog(SERVICE, tokens):
    """Handle ``savelog <file name>``: save the output pane to a file.

    Args:
        SERVICE: unused; present so all command handlers share one signature.
        tokens: whitespace-split command line; ``tokens[1]`` is the file name.
    """
    if len(tokens) < 2:
        # A bare 'savelog' used to raise IndexError out of the accept handler.
        printToConsole(" savelog requires a file name")
        return
    saveBufferToFile(tokens[1])

def cmdAck(SERVICE, tokens):
    """Handle ``ack``: call the service's ``toggleAck``.

    Args:
        SERVICE: service whose ``toggleAck`` is called.
        tokens: whitespace-split command line (not used).
    """
    toggle = SERVICE.toggleAck
    toggle()

def cmdSilence(SERVICE, tokens):
    """Handle ``silence <argument>``: pass the argument to ``toggleSilence``.

    Args:
        SERVICE: service whose ``toggleSilence`` is called.
        tokens: whitespace-split command line; ``tokens[1]`` is forwarded
            unchanged.
    """
    if len(tokens) < 2:
        # A bare 'silence' used to raise IndexError out of the accept handler.
        printToConsole(" silence requires an argument")
        return
    SERVICE.toggleSilence(tokens[1])

def cmdLoadagent(SERVICE, tokens):
    """Handle ``loadagent <profile>``: load an agent profile of the protocol.

    Args:
        SERVICE: service passed on to ``load_agent``.
        tokens: whitespace-split command line; ``tokens[1]`` is the profile
            name looked up in ``PROTOCOL.agents``.
    """
    if len(tokens) < 2:
        # A bare 'loadagent' used to raise IndexError out of the accept handler.
        printToConsole(" loadagent requires an agent profile name")
        return

    profile = tokens[1]
    if profile in PROTOCOL.agents:
        load_agent(SERVICE, PROTOCOL.agents[profile])
    else:
        printToConsole(" agent Profile '"+ profile+ "' not found..")

# Built-in commands: lower-cased first word of the input -> handler.
# parseCommand() calls each handler as handler(SERVICE, words).
commandMap = {
    'connect': cmdConnect,
    'savelog': cmdSaveLog,
    'ack': cmdAck,
    'silence': cmdSilence,
    'loadagent': cmdLoadagent
}




# Splits on commas, ignoring commas inside square brackets.
_ARG_SPLITTER = re.compile(r',(?!(?:[^[]*\[[^]]*\])*[^[\]]*\])\s*')


def _splitArgs(text):
    """Return the non-empty, stripped comma-separated items of *text*."""
    return [item.strip() for item in _ARG_SPLITTER.split(text) if item.strip()]


def _parseKeyValue(item):
    """Split a ``name: value`` item into a stripped ``(name, value)`` pair.

    Only the first ':' separates the name from the value, so the value itself
    may contain ':'. Raises ValueError when *item* contains no ':'.
    """
    name, separator, value = item.partition(':')
    if not separator:
        raise ValueError("expected 'name: value', got " + repr(item))
    return name.strip(), value.strip()


def parseCommand(text):
    """Parse one command line and execute it.

    The first word selects what happens:

    * a key of ``commandMap`` -> the built-in handler is called with the
      service and the whitespace-split words;
    * a key of ``agentCommands`` -> ``<cmd> help`` prints the help string,
      otherwise the ``name: value`` arguments (comma separated, on top of the
      declared defaults) are echoed and passed to the command's handler;
    * anything else is treated as a packet type: a new packet is created, its
      ``name: value`` fields are set and it is sent on the service's first
      interface.

    Problems in the input (argument or field without ':', unknown packet
    type, no interface to send on) are reported in the output pane instead of
    raising. A packet with a malformed field is not sent.

    Args:
        text: a single command (already split on ';' by the caller).

    Returns:
        0 after a built-in or agent command, otherwise None.
    """
    text = text.strip()
    words = text.split()

    if not words:
        printToConsole("")
        return

    keyword = words[0].lower()

    if keyword in commandMap:
        commandMap[keyword](SERVICE, words)
        return 0

    if keyword in agentCommands:
        command = agentCommands[keyword]

        if len(words) == 2 and words[1].lower() == "help":
            printToConsole(command.getHelpString())
            return 0

        # Start from the declared defaults, then apply what the user typed.
        cmdArgs = {arg.name: arg.default
                   for arg in command.args if arg.default is not None}

        for item in _splitArgs(text[len(words[0]):]):
            try:
                name, value = _parseKeyValue(item)
            except ValueError:
                printToConsole("Error Parsing argument: " + item)
                continue
            cmdArgs[name] = value

        line = "\n>>> " + words[0]
        if cmdArgs:
            shown = ", ".join(name + ": " + str(value)
                              for name, value in cmdArgs.items())
            line += " (" + shown + ")"
        printToConsole(line)

        command.f_handler(SERVICE, cmdArgs)
        return 0

    packetType = words[0]
    extPacketType = PROTOCOL.getExtName(packetType)
    if extPacketType is None:
        # onInputChange() treats None as "no such packet"; do the same here
        # instead of handing None to newPacket().
        printToConsole(" Unknown command or packet type: " + packetType)
        return

    newPacket = SERVICE.newPacket(extPacketType)

    if len(words) > 1:
        for item in _splitArgs(text[len(packetType):]):
            try:
                name, value = _parseKeyValue(item)
            except ValueError:
                # Do not send a packet that is missing a field the user
                # tried to set.
                printToConsole("Error Parsing field: " + item)
                return
            newPacket.setField(name, value)

    if not SERVICE.interfaces:
        printToConsole(" No interface available to send packet")
        return

    SERVICE.interfaces[0].sendPacket(newPacket)

def load_agent(service, agent):
    """Install an agent profile: packet handlers, init code and commands.

    * ``service.handlers`` is cleared and refilled with one function per entry
      of ``agent.handlers`` (key ``"default"`` is used as is, every other key
      is prefixed with ``agent.namespace + ":"``).
    * ``agent.init``, when not empty, is executed in the module globals with
      ``service`` bound to the global ``SERVICE``.
    * Each entry of ``agent.commands`` gets a generated handler stored as its
      ``f_handler`` and is registered in ``agentCommands`` under its
      lower-cased name, which is how the input parser looks commands up.
    * The command completer is rebuilt from the packet names, the commands
      that loaded successfully and the built-in commands.

    Failures while loading a single handler, the init code or a command are
    reported in the output pane and do not stop the rest from loading.

    NOTE(security): handler, init and command bodies come from the protocol
    description and are run with ``exec``; only load protocol files you trust.

    Args:
        service: service whose ``handlers`` are replaced.
        agent: agent profile providing ``handlers``, ``init``, ``commands``,
            ``namespace`` and ``name``.
    """
    global command_completer

    service.handlers.clear()  # clear current handlers

    for key, value in agent.handlers.items():
        name = key if key == "default" else agent.namespace + ":" + key

        try:
            fx = "def "+ key +"_handler(service,req, resp):\n\tglobal DataStore\n\t" + value.replace('\n','\n\t')+"\n"
            # Use an explicit namespace: since Python 3.13 names created by
            # exec() in a function's implicit locals are not visible to later
            # code, so the generated function could not be found again.
            scope = {}
            exec(fx, globals(), scope)
            service.handlers[name] = scope[key + "_handler"]
        except Exception as err:
            printToConsole("Error loading handler: " + key + " (" + str(err) + ")")

    if agent.init != "":
        try:
            exec("service = SERVICE\n" + agent.init, globals())
        except Exception as err:
            printToConsole("Error loading init function (" + str(err) + ")")

    commandNames = [packet.name for packet in PROTOCOL.packets]

    for command in agent.commands:
        try:
            fx = "def "+ command.name +"_cmd_handler(service,args):\n\tglobal DataStore\n"
            for arg in command.args:
                fx += "\t"+arg.name + "= args.get('" + arg.name +"', None)\n"
            fx += "\n\t"+ command.handler.replace('\n','\n\t')+"\n"

            scope = {}
            exec(fx, globals(), scope)
            handler = scope[command.name + "_cmd_handler"]
            helpString = command.getHelpString()

            # Register only once everything above succeeded.
            command.f_handler = handler
            # Lookups use the lower-cased first word of the input, so a
            # mixed-case name stored as is could never be matched.
            agentCommands[command.name.lower()] = command
            agentCommandHelp[command.name] = helpString
            commandNames.append(command.name)
        except Exception as err:
            printToConsole("Error loading command: " + command.name + " (" + str(err) + ")")

    commandNames.extend(['connect', 'silence', 'saveLog', 'loadagent'])
    command_completer = WordCompleter(commandNames,ignore_case=True)
    inputBuffer.completer = command_completer
    printToConsole(" agent Profile '" + agent.name+ "' Loaded ")


def run():
    """Entry point: parse the command line, set up the service, run the UI.

    Steps:

    1. Parse the arguments and redirect ``print`` to the output pane.
    2. Obtain the protocol: requested over a ``ws...`` connection when no
       input file is given, otherwise built from the input (a URL starting
       with ``http`` or a local file).
    3. Create the command and agent completers and show the banner.
    4. Connect (unless already connected in step 2), load the requested
       agent profile if the protocol has it, and execute the ``-x`` commands.
    5. Run the full-screen application until it exits.

    Fatal start-up problems are written to stderr. ``print`` is already
    redirected to the output pane at that point, and the pane is never shown
    when the application does not start. Where the process used to exit, it
    now exits with the message and a non-zero status.
    """
    global args
    global PROTOCOL
    global SERVICE
    global command_completer
    global agent_completer

    init_args()
    args = parser.parse_args()

    inputFile = args.input
    agentProfile = args.agent
    connected = False

    # From here on print() writes to the output pane.
    builtins.print = printToConsole

    if inputFile == "":
        if not args.connection.startswith("ws"):
            sys.exit("No input file specified, use poly-make -t to create a template/example file")

        # No protocol file: ask the server for the protocol.
        SERVICE = PolyService(None)
        SERVICE.print = print

        print(f"Requesting protocol from {args.connection}")
        SERVICE.connect(args.connection)

        # Wait up to 10 seconds for the protocol to arrive.
        timeout = 10
        while SERVICE.protocol is None:
            time.sleep(1)
            timeout -= 1
            if timeout == 0:
                sys.exit("Unable to connect to server")

        PROTOCOL = SERVICE.protocol
        connected = True

    elif inputFile.startswith("http"):
        PROTOCOL = PolyProtocol(inputFile)
        SERVICE = PolyService(PROTOCOL)

    elif os.path.isfile(inputFile):
        fileCrc, fileHash = crc(inputFile)

        PROTOCOL = buildProtocol(inputFile)
        SERVICE = PolyService(PROTOCOL)

        PROTOCOL.hash = fileHash
        PROTOCOL.crc = fileCrc

    else:
        # Keeps returning to the caller as before, but makes the reason visible.
        sys.stderr.write("Unable to read input file: " + inputFile + "\n")
        return

    SERVICE.print = printToConsole

    if args.meta:
        SERVICE.showMeta = True

    if args.bytes:
        SERVICE.showBytes = True

    commandNames = [packet.name for packet in PROTOCOL.packets]
    commandNames.extend(['connect', 'silence', 'saveLog', 'loadagent'])
    agentNames = list(PROTOCOL.agents.keys())

    command_completer = WordCompleter(commandNames,ignore_case=True)
    agent_completer = WordCompleter(agentNames, ignore_case=True)
    inputBuffer.completer = command_completer

    # Build the banner once instead of re-assigning the pane text per line.
    outputField.text = "".join([
        "______     _      ______          _        _   \n",
        "| ___ \\   | |     | ___ \\        | |      | |  \n",
        "| |_/ /__ | |_   _| |_/ /_ _  ___| | _____| |_ \n",
        "|  __/ _ \\| | | | |  __/ _` |/ __| |/ / _ \\ __|\n",
        "| | | (_) | | |_| | | | (_| | (__|   <  __/ |_ \n",
        "\\_|  \\___/|_|\\__, \\_|  \\__,_|\\___|_|\\_\\___|\\__|    ["+ PROTOCOL.name+ " protocol]\n",
        "              __/ |                            \n",
        "             |___/                             \n",
    ])

    if args.connection != '' and not connected:
        SERVICE.connect(args.connection)

    if agentProfile in PROTOCOL.agents:
        load_agent(SERVICE, PROTOCOL.agents[agentProfile])

    if args.execute != '':
        # -x collects several words; re-join them and split into commands.
        for com in ' '.join(args.execute).split(';'):
            parseCommand(com)

    application.run()

# Start the CLI only when this file is executed as a script.
if __name__ == '__main__':
    run()
