#!/usr/bin/python
# coding: utf-8

"""
Proxy which uses simple time-based XOR scheme to make heuristic detection
of underlying transmission impossible.

Based on twisted example 'tcp proxy' with an original link:
http://musta.sh/2012-03-04/twisted-tcp-proxy.html

MIT license
"""

import sys

from twisted.internet import defer
from twisted.internet import protocol
from twisted.internet import reactor
from twisted.python import log

from optparse import OptionParser


from barylib import Mixer

# Using -k makes the passphrase visible in 'ps aux' output. On multiuser systems
# you might want to change the key here instead.
DEFAULT_KEY='You should set up your own key, although this one is usually fine.'

# Process-wide settings shared by main() and the protocol classes. Every entry
# starts out as None and is assigned by main() from the parsed command-line
# options before the reactor is started.
cfg = {
    # Host the proxy connects out to (taken from --connect-ip).
    'server_addr': None,
    # Port the proxy connects out to (taken from --connect-port).
    'server_port': None,
    # Local port the proxy accepts connections on (taken from --listen-port).
    'listen_port': None,
    # Value returned by Mixer.create_key_from_passphrase(); used to build
    # the Mixer instances.
    'base_key': None,
}

class ProxyClientProtocol(protocol.Protocol):
    """Outgoing half of one proxied connection.

    Chunks taken from the factory's ``cli_queue`` are passed through this
    protocol's own Mixer and written to the remote peer. Chunks received
    from the remote peer are put, unmodified, on the factory's
    ``srv_queue``. A chunk that is ``False`` on ``cli_queue`` is the signal
    to stop reconnecting and close the connection.
    """

    def connectionMade(self):
        # Both attributes get a value first; presumably so that
        # connectionLost() can test them even if the setup below fails.
        self.mixer = None
        self.cli_queue = None
        log.msg("Client: connected to peer")
        self.mixer = Mixer(cfg['base_key'] + '_cli')
        self.cli_queue = self.factory.cli_queue
        self._request_next_chunk()

    def _request_next_chunk(self):
        # Ask the queue for one more chunk; serverDataReceived() re-arms
        # this after every chunk it forwards.
        pending = self.cli_queue.get()
        pending.addCallback(self.serverDataReceived)

    def serverDataReceived(self, chunk):
        # Shutdown marker: the other side of the proxy has gone away.
        if chunk is False:
            self.cli_queue = None
            log.msg("Client: disconnecting from peer")
            self.factory.continueTrying = False
            self.transport.loseConnection()
            return

        # This protocol no longer owns the queue (its connection was lost),
        # so hand the chunk back to the shared queue instead of writing it.
        if not self.cli_queue:
            self.factory.cli_queue.put(chunk)
            return

        mixed = self.mixer.mix(chunk)
        self.transport.write(mixed)
        self._request_next_chunk()

    def dataReceived(self, chunk):
        # Data from the remote peer goes to the server side as-is.
        outbound = self.factory.srv_queue
        outbound.put(chunk)

    def connectionLost(self, why):
        # cli_queue is still set only when the loss was not initiated by
        # serverDataReceived(), i.e. no shutdown marker was seen.
        if self.cli_queue:
            self.cli_queue = None
            log.msg("Client: peer disconnected unexpectedly")
        if self.mixer:
            del self.mixer


class ProxyClientFactory(protocol.ReconnectingClientFactory):
    """Reconnecting factory for the outgoing side of one proxied connection.

    Stores the two queues it is constructed with; ProxyClientProtocol
    reaches them through ``self.factory``.
    """

    # Protocol class used for the outgoing connection.
    protocol = ProxyClientProtocol
    # Set to False by ProxyClientProtocol when it receives the shutdown
    # marker, so that no further reconnection is attempted.
    continueTrying = True
    # Cap on the reconnection delay (ReconnectingClientFactory setting).
    maxDelay = 10

    def __init__(self, srv_queue, cli_queue):
        self.srv_queue, self.cli_queue = srv_queue, cli_queue


class ProxyServer(protocol.Protocol):
    """Accepted (listening-side) half of one proxied connection.

    Every accepted connection gets its own Mixer, its own pair of queues
    and its own outgoing connection to the address stored in ``cfg``.
    Chunks received here are put, unmodified, on ``cli_queue``; chunks
    arriving on ``srv_queue`` are mixed and written back to this
    connection.
    """

    def connectionMade(self):
        # All three attributes get a value first; presumably so that
        # connectionLost() can test them even if the setup below fails.
        self.mixer = None
        self.srv_queue = None
        self.cli_queue = None

        self.mixer = Mixer(cfg['base_key'])
        self.srv_queue = defer.DeferredQueue()
        self.cli_queue = defer.DeferredQueue()
        self._request_next_chunk()

        # Open the matching outgoing connection, sharing both queues.
        outgoing = ProxyClientFactory(self.srv_queue, self.cli_queue)
        reactor.connectTCP(cfg['server_addr'], cfg['server_port'], outgoing)

    def _request_next_chunk(self):
        # Ask srv_queue for one more chunk; clientDataReceived() re-arms
        # this after every chunk it forwards.
        pending = self.srv_queue.get()
        pending.addCallback(self.clientDataReceived)

    def clientDataReceived(self, chunk):
        mixed = self.mixer.mix(chunk)
        self.transport.write(mixed)
        self._request_next_chunk()

    def dataReceived(self, chunk):
        # Data from the accepted connection goes to the client side as-is.
        self.cli_queue.put(chunk)

    def connectionLost(self, why):
        # False on cli_queue tells ProxyClientProtocol to disconnect.
        if self.cli_queue:
            self.cli_queue.put(False)
        if self.mixer:
            del self.mixer


def test():
    """Small built-in self-test and benchmark of the Mixer.

    Prints four things to stdout:

    * the average time of one key derivation,
    * whether two Mixers built from the same key produce the same output
      for the same random input,
    * how often each output bit is 0 and 1 when a constant byte is mixed
      repeatedly,
    * the measured mixing throughput.

    Nothing is returned; the results are meant to be read by a human.
    """
    import random
    import timeit

    # Every print below takes a single parenthesised argument, which behaves
    # the same as the print statement this file otherwise relies on.
    print("Measuring key extraction algorithm speed...")
    runs = 5
    timer = timeit.Timer(
        lambda: Mixer.create_key_from_passphrase('This is a test passphrase'))
    result = timer.timeit(number=runs) / float(runs)
    print("  Key extraction takes {:.2f}s. Keep it around 0.5 - 3s".format(result))

    print("Checking mixer correctness / reproductibility")
    # randrange(256) covers every byte value, 0 through 255 inclusive.
    input_string = "".join(chr(random.randrange(256)) for i in range(8096))
    mixer_1 = Mixer('Shared KEY')
    mixer_2 = Mixer('Shared KEY')
    success = True
    for ch_in in input_string:
        # Both mixers are advanced on every step so they stay in lockstep.
        if mixer_1.mix(ch_in) != mixer_2.mix(ch_in):
            success = False
            break

    if success:
        print("  Correctness ok")
    else:
        print("  Failed - bytes differ")

    # Count, per bit position, how many mixed bytes have that bit set.
    samples = 65536
    ones = [0] * 8
    for i in range(samples):
        ch = ord(mixer_1.mix('A'))
        for bit in range(8):
            if ch & (0x01 << bit):
                ones[bit] += 1

    print("Bit statistics should be unbiased:")
    for bit in range(8):
        val_1 = ones[bit]
        val_0 = samples - val_1
        difference = abs(val_0 - val_1)
        print("  bit {}: zeroes={} ones={} difference={} ({:.3f}%)".format(
            bit, val_0, val_1,
            difference,
            100.0 * difference / samples
        ))

    print("Measuring mixing throughput")
    s = 'this sentence has 26 bytes' * 40
    timer = timeit.Timer(lambda: mixer_1.mix(s))
    cnt = 1024 * 3
    result = timer.timeit(number=cnt)
    length = cnt * len(s)
    print("  {} B / {:.2f}s = {:.3f}MB/s".format(
        length, result, length / result / 1024.0 / 1024.0))

def main():
    """Parse the command line, then run the self-test or start the proxy.

    With --test the built-in test is run and the process exits with
    status 0. Otherwise the passphrase is taken from --key-file, from
    --passphrase, or from DEFAULT_KEY (in that order of preference), the
    module-level ``cfg`` is filled in, and the reactor is started; the
    call then blocks in ``reactor.run()``.

    Invalid or missing options are reported through ``parser.error``,
    which prints a usage message and exits the process.
    """
    # Parse options
    parser = OptionParser()

    parser.add_option("-c", "--connect-ip", action="store",
                      help="Connect to this specific ip address")
    parser.add_option("-p", "--connect-port", action="store", default=22, type=int,
                      help="Connect to this specific port")
    parser.add_option("-l", "--listen-port", action="store", type=int,
                      help="Listen on this port")

    parser.add_option("-t", "--test", action="store_true",
                      help="Run internal test and exit")
    parser.add_option("-k", "--passphrase", action="store",
                      help="Passphare for masking",
                      default=None)
    parser.add_option("--key-file", action="store",
                      help="Instead of passphrase - use a keyfile",
                      default=None)

    options, args = parser.parse_args()

    if options.test:
        test()
        sys.exit(0)

    if options.key_file is not None and options.passphrase is not None:
        parser.error("You can't use both passphrase and key_file!")

    if options.key_file is not None:
        # The path is an attribute of the parsed options; 'args' is only the
        # list of leftover positional arguments.
        with open(options.key_file, 'r') as f:
            passphrase = f.read()
    elif options.passphrase is not None:
        passphrase = options.passphrase
    else:
        # Neither -k nor --key-file given: use the key configured in this
        # file instead of handing None to the key derivation.
        passphrase = DEFAULT_KEY

    if not options.listen_port:
        parser.error('You need to specify --listen-port option')

    if not options.connect_ip:
        # Without a target address every outgoing connection would fail.
        parser.error('You need to specify --connect-ip option')

    log.startLogging(sys.stdout)

    cfg['server_addr'] = options.connect_ip
    cfg['server_port'] = options.connect_port
    cfg['listen_port'] = options.listen_port
    cfg['base_key'] = Mixer.create_key_from_passphrase(passphrase)

    # Start proxy
    factory = protocol.Factory()
    factory.protocol = ProxyServer
    reactor.listenTCP(cfg['listen_port'], factory, interface="0.0.0.0")
    reactor.run()


# Start the proxy only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
