#!/usr/bin/python
""" Sends a text or file to any listening host on the LAN

Started from py-multicast code (available on PyPi)
"""
# Program version: the sender appends it to the header datagram and the
# receiver warns when the value it gets differs from its own.
__version__ = "1.3.1"

import sys
import socket
import time
from hashlib import sha256

# Bytes per datagram: chunk size read from the input when sending, and
# default buffer size passed to recvfrom() when receiving.
BUFFER_SZ = 65000
# Multicast group the sender targets and the receiver joins.
ADDR = '224.0.0.100'
# UDP port used on both ends (the sender also binds its own socket to it).
PORT = 9096

# Exact payload of the datagram that marks the end of the transferred data.
MAGIC_EOF = b'#E-O-F#close now'
# Start of the header datagram; the sender appends a space and the version.
MAGIC_START = b'MCAST FILE TRANSFER VERSION'

def dprint(txt):
    """Debug trace hook: would print `txt` on stderr, but is switched off.

    Tracing is disabled, so every call returns None without writing
    anything.  Flip `tracing` to True locally to get the traces back.
    """
    tracing = False
    if tracing:
        sys.stderr.write('%s\n' % txt)


def stdout_writer(data):
    """Write `data` on stdout without altering its bytes.

    `data` may be text (`str`) or bytes.  Text goes through `sys.stdout`
    as is.  Bytes are written to the binary layer `sys.stdout.buffer` when
    stdout has one, so that a received file redirected to disk is
    byte-identical to what was sent; decoding them to text first would let
    the text layer re-encode every byte >= 0x80 with the console encoding.
    When stdout has no binary layer, bytes are decoded as latin-1 and
    written as text.

    Returns None.
    """
    if isinstance(data, str):
        # Text (this is also the bytes case on Python 2, where bytes is str).
        sys.stdout.write(data)
        return
    binary_out = getattr(sys.stdout, 'buffer', None)
    if binary_out is not None:
        binary_out.write(data)
    else:
        # Replaced stdout without a binary layer: best effort, one character
        # per byte.
        sys.stdout.write(data.decode('latin1'))


def main():
    """Command-line entry point.

    `-v` / `--version` and `-h` / `--help` print and exit with status 0.
    With at least one argument, the first one names the file to send
    ('-' means standard input).  Without argument, the data received from
    the network is written on stdout.
    """
    if '-v' in sys.argv or '--version' in sys.argv:
        print("Version %s"%__version__)
        sys.exit(0)
    if '-h' in sys.argv or '--help' in sys.argv:
        print("""Syntax: %s [filename]

With a filename:
    sends the file on the network
Without a filename:
    get the data diffused on the network
    (default prints on stdout, redirect in case of file!)
              """%sys.argv[0])
        sys.exit(0)
    if len(sys.argv) > 1: # file given, this is the sender
        _send_file(sys.argv[1])
    else: # read data
        _receive_to_stdout()


def _send_file(path):
    """Send the file at `path` ('-' for stdin) and report the rate on stderr."""
    if path == '-':
        # Hashing and sendto() need bytes: use the binary layer of stdin when
        # there is one (Python 3); otherwise stdin is used as it is.
        stream = getattr(sys.stdin, 'buffer', sys.stdin)
    else:
        # Opened before anything is sent, so that an unreadable file does not
        # leave receivers holding a header that no data will follow.
        stream = open(path, mode='rb')
    try:
        dg = DatagramEmitter(None, PORT, ADDR, PORT)
        try:
            dg.send_header()
            t0 = time.time()
            kb = dg.send(stream) / 1024.0
            # Floor the elapsed time: a tiny input can be timed at 0 (or less,
            # if the wall clock steps back), which must not abort the
            # transfer before the footer is sent.
            duration = max(time.time() - t0, 1e-6)
            sys.stderr.write("%.0f kbps (%.2f Mbps).\n" % (kb/duration, kb/duration/1024))
            dg.send_footer()
        finally:
            dg.close()
    finally:
        if path != '-':
            stream.close()


def _receive_to_stdout():
    """Receive one transfer and write its data on stdout."""
    dg = Multicast(None, ADDR, PORT, ttl=32)
    try:
        dg.read_header()
        dg.retrieve(stdout_writer)
        dg.read_footer()
    finally:
        # Also runs on Ctrl-C or when read_footer() exits on corruption.
        dg.close()


class Datagram(object):
    """A UDP datagram socket wrapper.

    The socket is created and bound by the constructor.  Every datagram
    read through `pipeone` with `check_error` on is fed to a running
    SHA-256, which `read_footer` compares with the digest announced by the
    sender.
    """

    def __init__(self, address, port):
        """Create the UDP socket and bind it.

        `address` is the local IPv4 address as a string, or a false value
        to use '0.0.0.0'.  `port` is the local port used by `bind`.

        Raises the socket error coming from `inet_aton` when `address` is
        rejected, or from the socket setup / bind when that fails.
        """
        if address:
            socket.inet_aton(address)  # validation only: raises on a bad address
            self.local_address = address
        else:
            self.local_address = '0.0.0.0'

        self.address = address
        self.port = port

        # error detection
        self.corruption_detector = sha256()

        self._socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.bind()
        except Exception:
            # Do not leak the descriptor when the setup fails half-way.
            self._socket.close()
            raise

    def get_hash(self):
        """Return the current SHA-256 hex digest as ASCII bytes."""
        return self.corruption_detector.hexdigest().encode('ascii')

    def bind(self):
        " Bind socket "
        self._socket.bind((self.local_address, self.port))

    def read_header(self):
        """Consume one datagram as the transfer header.

        The last whitespace-separated token is taken as the sender's
        version; a warning is written on stderr when it differs from
        `__version__`.  The header is not added to the running hash.
        """
        tokens = self.pipeone(None, check_error=False).split()
        version = tokens[-1] if tokens else b''  # empty datagram: no version
        if not isinstance(version, str):
            # Compare and report as text rather than as a bytes repr.
            version = version.decode('latin1')
        if version != __version__:
            sys.stderr.write('WARNING: version mismatch: %s\n!= %s\n'%(version, __version__))

    def read_footer(self):
        """Consume one datagram as the footer and verify the digest.

        The second whitespace-separated token is the sender's digest.  When
        it is missing or differs from the local digest, a message is written
        on stderr and the process exits with status -1.
        """
        tokens = self.pipeone(None, check_error=False).split()
        announced = tokens[1] if len(tokens) > 1 else b''  # malformed footer
        digest = self.get_hash()
        if announced != digest:
            sys.stderr.write('Data corruption detected !! %r\n!= %r\n'%(digest, announced))
            sys.exit(-1)

    def retrieve(self, ostream, size=BUFFER_SZ):
        " Receive all data into `ostream` until the end-of-transfer marker "
        while True:
            try:
                self.pipeone(ostream, size=size)
            except EOFError:
                break

    def pipeone(self, ostream, size=BUFFER_SZ, check_error=True):
        """Receive one datagram and optionally hand it to `ostream`.

        `ostream` is None, a callable taking the data, or an object with a
        `write` method.  `size` is the maximum number of bytes to receive.
        When `check_error` is true, non-empty data is added to the running
        hash.

        Returns the received data when `ostream` is None, otherwise whatever
        the callable / `write` returned.

        Raises EOFError when the datagram is exactly `MAGIC_EOF`; the marker
        is never handed to `ostream`.
        """
        data, _sender = self._socket.recvfrom(size)
        if data and check_error:
            dprint("up")
            # The end marker is hashed as well: the sender includes it in
            # its own digest.
            self.corruption_detector.update(data)
        # NOTE: this message is formatted for every datagram, whatever
        # dprint then does with it.
        dprint("READ[%s] %s\n"%(check_error,repr(data)))
        if data == MAGIC_EOF:
            raise EOFError("End of transfer")

        if ostream is None:
            return data
        target = ostream if callable(ostream) else ostream.write
        return target(data)

    def read(self, size=BUFFER_SZ):
        " Get one datagram & return its data "
        return self.pipeone(None, size)

    def cleanup(self):
        " Hook called by `close` before the socket is closed; no-op here "
        pass

    def close(self):
        " Run `cleanup`, then close the socket even if the cleanup failed "
        try:
            self.cleanup()
        finally:
            self._socket.close()

    def __unicast__(self):
        " Return a one-line description of this socket "
        return "DatagramReceiver [{0}] {1}:{2}".format(self._socket, self.local_address, self.port)

    def __str__(self):
        # `__unicast__` is not a name Python looks up on its own; routing
        # str() through it makes the description reachable, including the
        # versions overridden by subclasses.
        return self.__unicast__()


class Multicast(Datagram):
    """Datagram socket that joins a multicast group.

    The constructor sets the multicast TTL, loopback flag and outgoing
    interface, then adds the socket to `multicast_address` on the local
    address; `cleanup` drops that membership again.
    """

    def __init__(self, bind_address_or_interface, multicast_address, port, ttl=32, loop=1):
        super(Multicast, self).__init__(bind_address_or_interface, port)

        self.loop = loop
        self.multicast_address = multicast_address

        sock = self._socket
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, self.loop)

        local_packed = socket.inet_aton(self.local_address)
        sock.setsockopt(socket.SOL_IP, socket.IP_MULTICAST_IF, local_packed)
        sock.setsockopt(socket.SOL_IP, socket.IP_ADD_MEMBERSHIP, self._membership())

    def _membership(self):
        " Packed group address followed by the packed local address "
        group_packed = socket.inet_aton(self.multicast_address)
        return group_packed + socket.inet_aton(self.local_address)

    def bind(self):
        " Announce readiness on stderr, then bind to every interface on `port` "
        # Runs from the base constructor, i.e. before `multicast_address`
        # is set: only `port` may be used here.
        sys.stderr.write("ready\n")
        everywhere = ('', self.port)
        self._socket.bind(everywhere)

    def cleanup(self):
        " Leave the multicast group joined by the constructor "
        self._socket.setsockopt(socket.SOL_IP, socket.IP_DROP_MEMBERSHIP, self._membership())

    def __unicast__(self):
        " Return a one-line description of this socket "
        fields = (self._socket, self.multicast_address, self.port, self.local_address)
        return "MulticastReceiver [{0}] {1}:{2} @ {3}".format(*fields)


class DatagramEmitter(Datagram):
    """Datagram socket that sends to one fixed target address and port.

    Data sent with `check_error` on is fed to the running SHA-256 inherited
    from `Datagram`; `send_footer` transmits the resulting digest.
    """

    def __init__(self, source_address, source_port, target_address, target_port, ttl=32, loop=1):
        """Bind to the source address/port and remember the target.

        `ttl` and `loop` are applied as the IP_MULTICAST_TTL and
        IP_MULTICAST_LOOP socket options.
        """
        Datagram.__init__(self, source_address, source_port)

        self.loop = loop
        self.target_address = target_address
        self.target_port = target_port

        # multicast only:
        self._socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl)
        self._socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, self.loop)

    def send(self, data, check_error=True):
        """Send `data` to the target.

        `data` is either bytes, sent as a single datagram, or an object
        with a `read` method, drained in chunks of `BUFFER_SZ` bytes with
        one datagram per chunk.  When `check_error` is true, everything
        sent is also added to the running hash.

        Returns the number of bytes taken from `data`.
        """
        dprint("WRITE [%s] %r"%(check_error,data))
        destination = (self.target_address, self.target_port)

        if not hasattr(data, 'read'):
            if check_error:
                dprint('up!')
                self.corruption_detector.update(data)
            self._socket.sendto(data, destination)
            return len(data)

        written = 0
        while True:
            d = data.read(BUFFER_SZ)
            if not d:
                break
            written += len(d)
            if check_error:
                dprint('up!')
                self.corruption_detector.update(d)
            if self._socket.sendto(d, destination) != len(d):
                sys.stderr.write("Packet loss ??\n")
        return written

    def send_header(self):
        " Send the header datagram: `MAGIC_START`, a space, the version "
        self.send(MAGIC_START + b' ' + __version__.encode('ascii'), check_error=False)

    def send_footer(self):
        " Send the end-of-transfer marker, then the digest of all hashed data "
        # The marker is hashed on purpose: the receiver hashes it as well
        # before recognising it.
        self.send(MAGIC_EOF, check_error=True)
        # Plain concatenation rather than bytes %-formatting, which only
        # exists from Python 3.5 on (and not on Python 2).
        self.send(b'HASH: ' + self.get_hash(), check_error=False)

if __name__ == '__main__':
    # Script entry point: run the command-line interface, and report a
    # Ctrl-C with a short notice on stderr instead of a traceback.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write('\nInterrupted !\n')
