#!/usr/bin/env python2

###
#    Copyright (c) 2015 Timothy Savannah <kata198@gmail.com> LGPL v2.1
#    See LICENSE for more information, or https://gnu.org/licenses/old-licenses/lgpl-2.1.txt
###

import math
import multiprocessing
import os
import sys
import signal
import threading
import time

import ArgumentParser

from socket_gatekeeper.Listener import Listener
from socket_gatekeeper.Handler import DEFAULT_CLIENT_BUFFER_LEN, DEFAULT_ENDPOINT_BUFFER_LEN, handlerFilterQuit

from socket_gatekeeper.MappingsParser import MappingsFileParser, ParseMappingException


def printUsage():
    '''
        printUsage - Write the full help text to stderr.

        The text covers the required arguments, the optional arguments, a
          description of the application, and the mappings file format.

        The program name shown is sys.argv[0]; the buffer defaults shown are
          DEFAULT_CLIENT_BUFFER_LEN and DEFAULT_ENDPOINT_BUFFER_LEN.
    '''
    # The long option names shown here are the ones registered with
    #  ArgumentParser in the __main__ block ('client-buffer', 'endpoint-buffer').
    # The example hash is the SHA-256 digest of the three bytes "abc" (no
    #  trailing newline), which is what the "echo -n" pipeline shown produces.
    usageText = '''Usage: %s [arguments]
Starts a socket gatekeeper instance. Description of application at bottom.

 Required:

     --mappings=/path/to/file
        or                           
     -m /path/to/file                 Path to the mappings file. See MAPPING FORMAT below for details.

     --bind=addr:port
        or
     -b addr:port                     Listen on this interface and port.
                                       Example, listen on all interfaces port 51000:
                                       --bind=0.0.0.0:51000


 Options:

      --help                          Show this message and quit
      --client-buffer=X               Use X bytes max buffer in data to/from the client. Defaults to %d
      --endpoint-buffer=X             Use X bytes max buffer in data to/from the endpoint. Defaults to %d
      --enable-quit                   Enable intercepting the messages "quit" and "exit" to terminate connection.
      


DESCRIPTION
-----------

Socket gatekeeper instance listens on a local interface/port. Any connections are silently
prompted for a password. That password is read, hashed, and checked against a provided database of
one-way encrypted passwords. Each password specifies an endpoint address and port.

You can use this to open a single port on a machine which routes privileged users to their
appropriate services. You can also use this as a "front door" for services that don't
support authentication out-of-the-box to add a layer of security.


MAPPING FORMAT
--------------

The mapping file is defined like this:

sha256sum_password=ADDR:port

You can use a procedure like the following to generate a sha256sum of your password:


Example, for password "abc" to resolve to localhost port 80:

echo -n "abc" | sha256sum | awk {'print $1'}

Returns ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad

So the mapping file would contain

ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad = 127.0.0.1:80

You can have several mappings in the same file.
You can have duplicates of the endpoints, but you can not have duplicate passwords.

''' %(sys.argv[0], DEFAULT_CLIENT_BUFFER_LEN, DEFAULT_ENDPOINT_BUFFER_LEN)

    sys.stderr.write(usageText)

def errorUsageAndExit(msg):
    '''
        errorUsageAndExit - Report an error on stderr, then print the usage
          text and terminate with exit status 1.

        @param msg <str> - Error text, written to stderr followed by a blank line

        Does not return; always raises SystemExit(1).
    '''
    for chunk in (msg, '\n\n'):
        sys.stderr.write(chunk)

    # Pause between the error and the usage text
    #  (presumably so the error can be read before the long help output follows).
    time.sleep(1.5)

    printUsage()
    raise SystemExit(1)

def parseBindValue(bindValue):
    '''
        parseBindValue - Split an "addr:port" bind argument into its parts.

        @param bindValue <str> - Value given to --bind / -b, e.g. "0.0.0.0:51000"

        @return <tuple<str, int>> - (address, port)

        @raises ValueError - If bindValue is not exactly "addr:port" with a
          decimal port no greater than 65535. The exception message is the
          complete user-facing error text and includes the offending value.
    '''
    bindValueSplit = bindValue.split(':')
    if len(bindValueSplit) != 2 or not bindValueSplit[1].isdigit():
        raise ValueError('Invalid bind param: "%s".\nMust be in format Addr:port like 0.0.0.0:51000\n\n' %(bindValue, ))

    bindPort = int(bindValueSplit[1])
    if bindPort > 65535:
        # isdigit() already excluded negatives; only the upper bound is left to check.
        raise ValueError('Invalid bind param: "%s".\nPort must be between 0 and 65535.\n\n' %(bindValue, ))

    return (bindValueSplit[0], bindPort)

if __name__ == '__main__':

    # The four sequences line up by position: result key, short option, long option.
    #  '--help' and '--enable-quit' are looked up below as boolean flags.
    # NOTE(review): the meaning of the trailing {} and False is defined by the
    #  ArgumentParser module, which is not visible here - confirm against its docs.
    parser = ArgumentParser.ArgumentParser(
        ('mappingsFilename', 'clientBufferLen', 'endpointBufferLen', 'bind' ),
        ('m', None, None, 'b' ),
        ('mappings', 'client-buffer', 'endpoint-buffer', 'bind' ),
        ['--help', '--enable-quit'],
        {},
        False
    )
    parseResults = parser.parse(sys.argv[1:])
    args = parseResults['result']

    # --help wins over any parse errors.
    if args['--help'] is True:
        printUsage()
        sys.exit(0)

    if parseResults['errors']:
        errorUsageAndExit('Error in arguments:\n%s' %('\n'.join(parseResults['errors']), ))

    # Optional buffer sizes: None means "not overridden"; the value is handed to Listener as-is.
    if 'clientBufferLen' in args:
        try:
            overrideClientBufferLen = int(args['clientBufferLen'])
            if overrideClientBufferLen <= 0:
                raise ValueError
        except ValueError:
            errorUsageAndExit('Client buffer length must be an integer > 0.')
    else:
        overrideClientBufferLen = None

    if 'endpointBufferLen' in args:
        try:
            overrideEndpointBufferLen = int(args['endpointBufferLen'])
            if overrideEndpointBufferLen <= 0:
                raise ValueError
        except ValueError:
            errorUsageAndExit('Endpoint buffer length must be an integer > 0.')
    else:
        overrideEndpointBufferLen = None

    # Required: address/port to listen on.
    if 'bind' not in args:
        errorUsageAndExit('A bind param is required. Specify one with --bind=addr:port or -b addr:port.\nEx: 0.0.0.0:51000')

    try:
        (bindAddr, bindPort) = parseBindValue(args['bind'])
    except ValueError as e:
        # The exception text is already the full message, including the bad value.
        sys.stderr.write(str(e))
        sys.exit(1)

    # Required: mappings file.
    if 'mappingsFilename' not in args:
        errorUsageAndExit('A mappings file is required. Specify one with --mappings=/path/to/file or -m /path/to/file')

    mappingsFilename = args['mappingsFilename']

    if not mappingsFilename or not os.path.isfile(mappingsFilename):
        sys.stderr.write('Cannot find specified mappings file: "%s" or is not a file.\n Check the path and try again.\n\n' %(mappingsFilename, ))
        sys.exit(1)

    mappingsParser = MappingsFileParser(mappingsFilename)

    try:
        mappings = mappingsParser.getMappings()
    except ParseMappingException as e:
        sys.stderr.write('Error parsing mappings file. See --help for format information: \n%s\n\n' %(str(e),))
        sys.exit(1)

    listener = Listener(bindAddr, bindPort, mappings, overrideClientBufferLen, overrideEndpointBufferLen)

    if args['--enable-quit'] is True:
        # Attach the quit filter as an incoming filter on each handler the listener passes to this callback.
        listener.setApplyFiltersToHandlerFunction( lambda handler, sha256, mapping : handler.addIncomingFilter(handlerFilterQuit) )

    listener.start()

    # Set to True by the signal handler once shutdown has begun, so a second signal is ignored.
    globalIsTerminating = False

    def handleSigTerm(*args):
        '''
            handleSigTerm - Signal handler: shut down the listener, clean up
              remaining threads / child processes, and exit.

            Steps:
              1. Send SIGTERM to the listener process and join it (SIGKILL if
                 it is still alive after the join times out).
              2. If any of the 6-second allotment remains, use it to join other
                 threads and multiprocessing children, polling once a second.
              3. Restore the default SIGTERM / SIGINT handlers and call sys.exit(0).

            If a shutdown is already in progress, returns immediately.

            @param *args - Arguments passed by the signal machinery (unused)
        '''
        global listener
        global globalIsTerminating
        if globalIsTerminating is True:
            return # Already terminating
        globalIsTerminating = True

        # Seconds allowed for the listener join, and the total the cleanup below is measured against.
        shutdownAllotment = 6

        sys.stderr.write('Caught signal, shutting down listeners...\n')
        try:
            os.kill(listener.pid, signal.SIGTERM)
        except Exception:
            # Best effort - e.g. the process may have already exited.
            pass
        sys.stderr.write('Sent signal to children, waiting up to 6 seconds...\n')
        sys.stderr.flush()

        startTime = time.time()

        time.sleep(1)
        listener.join(shutdownAllotment)

        if listener.is_alive():
            # Did not stop on SIGTERM within the join timeout; force it.
            try:
                os.kill(listener.pid, signal.SIGKILL)
            except Exception:
                pass
            time.sleep(.1)
            listener.join()

        delta = time.time() - startTime
        remainingSleep = int(shutdownAllotment - math.floor(delta))
        if remainingSleep > 0:
            anyAlive = False
            # If we still have time left, see if we are just done or if there are children to clean up using remaining time allotment
            if threading.active_count() > 1 or len(multiprocessing.active_children()) > 0:
                sys.stderr.write('Listener closed in %1.2f seconds. Waiting up to %d seconds before terminating.\n' %(delta, remainingSleep))
                sys.stderr.flush()
                thisThread = threading.current_thread()
                for _ in range(remainingSleep):
                    anyAlive = False

                    for thread in threading.enumerate():
                        # Never try to join ourselves or the main thread.
                        if thread is thisThread or thread.name == 'MainThread':
                            continue
                        thread.join(.05)
                        if thread.is_alive():
                            anyAlive = True

                    for child in multiprocessing.active_children():
                        child.join(.05)
                        if child.is_alive():
                            anyAlive = True

                    if anyAlive is False:
                        break
                    time.sleep(1)

            if anyAlive is True:
                sys.stderr.write('Could not kill in time.\n')
            else:
                sys.stderr.write('Shutdown successful after %1.2f seconds.\n' %( time.time() - startTime))
            sys.stderr.flush()

        else:
            sys.stderr.write('Listener timed out in closing, exiting uncleanly.\n')
            sys.stderr.flush()
            time.sleep(.05) # Why not? :P

        # Restore default handling so any further SIGTERM / SIGINT acts normally.
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        signal.signal(signal.SIGINT, signal.SIG_DFL)
        # sys.exit raises SystemExit, so nothing after this line can run.
        sys.exit(0)
    # END handleSigTerm
        

    # Route both termination signals through the shutdown handler.
    for sigNum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sigNum, handleSigTerm)

    # Nothing more to do here once the listener has been started; idle until a signal arrives.
    while True:
        try:
            time.sleep(5)
        except:
            # Anything raised while idling is turned into a SIGTERM to ourselves.
            # NOTE(review): this bare except also catches the SystemExit raised by
            #  handleSigTerm's sys.exit(0). By then the handler has restored SIG_DFL,
            #  so this SIGTERM is what actually ends the process - confirm that is intended.
            os.kill(os.getpid(), signal.SIGTERM)

# vim: ts=4 sw=4 expandtab
