#! /usr/bin/env python

import argparse
import struct
import sys
import re
from itertools import chain, zip_longest, groupby

import umodbus.exceptions

# Register definition syntax: [type@]address[/pack_type]
#   group 1: optional Modbus type letter followed by '@' (c, d, h or i)
#   group 2: decimal address (may match the empty string)
#   group 3: optional '/' followed by a struct pack format
# Every group is optional, so this pattern matches any string; callers must
# check that group 2 is non-empty to know a real address was given.
REGISTER_RE = re.compile(r'([cdhi]@)?(\d*)(/.*)?')

def flatmap(f, items):
    """Apply `f` to every item and lazily concatenate the resulting iterables."""
    mapped = map(f, items)
    return chain.from_iterable(mapped)

def grouper(iterable, n, fillvalue=None):
    """Split `iterable` into consecutive tuples of length `n`.

    The last tuple is padded with `fillvalue` when the input runs out, e.g.
    grouper('ABCDEFG', 3, 'x') yields ABC, DEF, Gxx.
    """
    # Every position of the zip pulls from the SAME iterator, so each output
    # tuple consumes n consecutive items.
    shared = iter(iterable)
    return zip_longest(*((shared,) * n), fillvalue=fillvalue)

def fail(fmt, *args):
    """Print `fmt` formatted with `args` and terminate with exit status 1.

    `fmt` is a `str.format` template; each positional argument fills one
    placeholder.
    """
    # Unpack the arguments: formatting the tuple itself would render
    # "('x',)" for one argument and raise IndexError for templates with
    # more than one placeholder.
    print(fmt.format(*args))
    sys.exit(1)

def dump(xs):
    """Print the integers in `xs` as space-separated two-digit hex values."""
    hex_pairs = [format(value, '02x') for value in xs]
    print(' '.join(hex_pairs))

def read_registers(modbus_type, address, pack_types):
    """Read a run of coils/inputs/registers starting at `address`.

    `modbus_type` is one of 'c', 'd', 'h' or 'i' and selects the request
    builder on the module-level `modbus` client. `pack_types` is a list of
    struct formats describing the values to read.

    For 'c' and 'd' the response from `modbus.send_message` is returned
    unchanged. For 'h' and 'i' the returned words are serialised high byte
    first and decoded with the joined `pack_types`; the resulting tuple is
    returned.
    """
    if modbus_type in 'cd':
        # Bit-oriented types: the quantity requested is the summed packed size.
        n_registers = sum(struct.calcsize(fmt) for fmt in pack_types)
    else:
        byte_count = 0
        for fmt in pack_types:
            byte_count += struct.calcsize(fmt)
            # The running total must always land on a 16-bit boundary.
            assert byte_count % 2 == 0
        n_registers = byte_count // 2

    reader_names = {
        'c': 'read_coils',
        'd': 'read_discrete_inputs',
        'h': 'read_holding_registers',
        'i': 'read_input_registers',
    }
    build_request = getattr(modbus, reader_names[modbus_type])
    message = build_request(slave_id, address, n_registers)

    if args.verbose:
        dump(message)

    words = modbus.send_message(message, connection)

    if args.verbose:
        print("Read {} registers: {}".format(n_registers, words))

    if modbus_type in 'cd':
        return words

    # Turn each 16-bit word into (high byte, low byte) before unpacking.
    raw = bytes(flatmap(lambda word: [word >> 8, word & 0xff], words))
    return struct.unpack(' '.join(pack_types), raw)

def write_registers(modbus_type, address, values):
    """Write a single coil or a run of holding registers starting at `address`.

    `values` is an iterable of `(pack_type, value)` pairs, where `pack_type`
    is a struct format and `value` is the textual value to write. For
    `modbus_type == 'c'` exactly one value is supported; any other
    `modbus_type` is sent as a multiple-register write.

    Returns whatever `modbus.send_message` returns.

    Raises:
        NotImplementedError: when asked to write anything but exactly one coil.
        ValueError: when a pack type does not fill whole 16-bit registers,
            or when a value cannot be converted to a number.
    """
    if modbus_type == 'c':
        values = [value for _, value in values]  # drop the pack type
        if len(values) == 1:
            # TODO validate value, should be boolean
            message = modbus.write_single_coil(slave_id, address, int(values[0]))
        else:
            raise NotImplementedError("Multiple coil write not implemented yet.")
    else:
        byte_orders = ('@', '=', '<', '>', '!')
        byte_order = ''
        words = []

        for pack_type, value in values:
            # Access.append() strips the byte-order prefix from every pack
            # type but the first of a group. Each value is packed on its own
            # here, so re-apply the most recent prefix; otherwise the value
            # would be packed in native byte order.
            if pack_type[:1] in byte_orders:
                byte_order = pack_type[0]
            else:
                pack_type = byte_order + pack_type

            if struct.calcsize(pack_type) % 2 != 0:
                raise ValueError(
                    "Pack type {!r} does not fill a whole number of 16-bit registers".format(pack_type))

            # Float formats (half, single, double) take a float; everything
            # else is parsed as an integer with base auto-detection (0x.., 0o..).
            if any(code in pack_type for code in 'efd'):
                value = float(value)
            else:
                value = int(value, 0)

            packed = struct.pack(pack_type, value)
            # Pair consecutive bytes into 16-bit words, first byte is the high byte.
            words.extend(high << 8 | low for high, low in zip(packed[0::2], packed[1::2]))

        message = modbus.write_multiple_registers(slave_id, address, words)

    if args.verbose:
        dump(message)

    return modbus.send_message(message, connection)

class Access:
    """One Modbus access, or several merged into a single contiguous run.

    An instance starts out describing a single read or write at one address
    and can absorb further accesses through `append`. The per-item names,
    addresses, pack types and (for writes) values are kept in parallel lists.
    """

    def __init__(self, name, modbus_type, write, address, pack_type, value=None):
        self.modbus_type = modbus_type
        self.write = write
        self.names = [name]
        self.addresses = [address]
        self.pack_types = [pack_type]
        # Values are only tracked for write accesses.
        self.values = [value] if self.write else None

    def address(self):
        """Return the address of the first item."""
        return self.addresses[0]

    def pack_type(self):
        """Return the pack type of the first item."""
        return self.pack_types[0]

    def endianness(self):
        """Return the first character of the first pack type."""
        first_format = self.pack_type()
        return first_format[0]

    def size(self):
        """Return the summed size of all items.

        For 'h' and 'i' accesses each packed size is counted in 16-bit
        units (and asserted to be even); otherwise it is counted in bytes.
        """
        word_sized = self.modbus_type in ('h', 'i')
        count = 0
        for fmt in self.pack_types:
            item_size = struct.calcsize(fmt)
            if word_sized:
                assert item_size % 2 == 0
                item_size //= 2
            count += item_size
        return count

    def operations(self):
        """Return the pack types, zipped with their values for a write."""
        if not self.write:
            return self.pack_types
        return zip(self.pack_types, self.values)

    def append(self, other):
        """Merge the items of `other` onto the end of this access."""
        self.names.extend(other.names)
        # A struct format may carry its byte-order marker only once, at the
        # very start, so drop the leading character of each appended format.
        # Merged accesses share the same endianness anyway.
        self.pack_types.extend(fmt[1:] for fmt in other.pack_types)
        self.addresses.extend(other.addresses)
        if self.write:
            self.values.extend(other.values)

    def labels(self):
        """Yield each item's name, falling back to its address when falsy."""
        pairs = zip(self.names, self.addresses)
        return (label or addr for (label, addr) in pairs)

    def __repr__(self):
        value_suffix = "={}".format(self.values) if self.write else ""
        return "{}@{}/{}{}".format(
            self.modbus_type, self.address(), self.pack_types, value_suffix)

def by_type(access):
    """Sort/group key: (Modbus type, write flag, endianness) of `access`."""
    kind = access.modbus_type
    is_write = access.write
    byte_order = access.endianness()
    return (kind, is_write, byte_order)

def by_address(access):
    """Sort key: the first address covered by `access`."""
    start = access.address()
    return start

def group_accesses(accesses):
    """Merge accesses that are adjacent in address space.

    Accesses are bucketed by `by_type` (Modbus type, write flag,
    endianness), ordered by address within each bucket, and every access
    that starts exactly where the running one ends is appended to it.
    Returns the list of merged accesses; the input objects are mutated.
    """
    merged = []

    for _key, bucket in groupby(sorted(accesses, key=by_type), key=by_type):
        current = None
        for candidate in sorted(bucket, key=by_address):
            if current is None:
                current = candidate
            elif current.address() + current.size() == candidate.address():
                # Contiguous with the running access: fold it in.
                current.append(candidate)
            else:
                # Gap or overlap: close the running access, start a new one.
                merged.append(current)
                current = candidate
        if current is not None:
            merged.append(current)

    return merged

def parse_accesses(s, registers):
    """Parse access specifications into a list of `Access` objects.

    Each item of `s` is either ``register`` (a read) or ``register=value``
    (a write), where ``register`` is a key of `registers` or a literal
    definition of the form ``[type@]address[/pack_type]``.

    Defaults: the Modbus type is 'h'; the pack type is 'B' for 'c'/'d' and
    '!H' otherwise; a pack type without a byte-order character gets '!'
    prepended.

    Specifications without an address are reported and skipped. Writing to
    a type other than 'c' or 'h' terminates the program through `fail`.
    """
    accesses = []

    for access in s:
        # Split on the LAST '=' only. '=' is also a valid struct byte-order
        # prefix, so 'h@10/=H=5' must give register 'h@10/=H' and value '5'
        # instead of failing to unpack three parts.
        # NOTE: a read using the '=' byte order (e.g. 'h@10/=H') is still
        # ambiguous with a write and is treated as a write, as before.
        if '=' in access:
            register, value = access.rsplit('=', 1)
            write = True
        else:
            register = access
            value = None
            write = False

        if register in registers:
            name = register
            register = registers[name]
        else:
            name = None

        modbus_type, address, pack_type = REGISTER_RE.match(register).groups()

        if not address:
            print("{!r} is not a known named register nor a valid register definition. Skipping it.".format(register))
            continue

        # Group 1 still carries its trailing '@'.
        if not modbus_type:
            modbus_type = 'h'
        else:
            modbus_type = modbus_type[:-1]

        # Group 3 still carries its leading '/'. A bare '/' leaves an empty
        # pack type, which is treated like a missing one so that the
        # byte-order check below cannot index into an empty string.
        pack_type = pack_type[1:] if pack_type else ''
        if not pack_type:
            if modbus_type in 'cd':
                pack_type = 'B'
            else:
                pack_type = '!H'

        address = int(address)

        if pack_type[0] not in '@=<>!':
            pack_type = '!' + pack_type

        modbus_type = modbus_type.lower()
        if modbus_type not in 'cdhi':
            fail("Invalid Modbus type '{}'. Valid ones are 'cdhi'", modbus_type)
        if write and modbus_type not in 'ch':
            fail("Invalid Modbus type '{}'. Only coils and holding registers are writable", modbus_type)

        accesses.append(Access(name, modbus_type, write, address, pack_type, value))

    return accesses

def parse_registers(registers_filename):
    """Load named register definitions from a text file.

    Each useful line holds exactly two whitespace-separated fields,
    ``name definition``; anything after a '#' is ignored, and lines with a
    different number of fields are silently skipped. Definitions without
    an address are reported and skipped.

    Returns a dict mapping name to definition string. A falsy
    `registers_filename` yields an empty dict without touching the disk.
    """
    registers = {}

    if not registers_filename:
        return registers

    with open(registers_filename) as f:
        for line in f:
            parts = line.split('#')[0].split()
            if len(parts) != 2:
                continue

            name, definition = parts

            # REGISTER_RE can match the empty string, so a successful match
            # alone proves nothing; a definition is only usable when the
            # address group is non-empty.
            match = REGISTER_RE.match(definition)
            if match and match.group(2):
                registers[name] = definition
            else:
                print("Invalid definition {!r} for register {!r}. Skipping it".format(definition, name))

    return registers

def connect_to_device(args):
    """Open the connection described by `args.device`.

    A device starting with '/' is opened as a serial port (Modbus RTU)
    using `args.baud` and `args.stop_bits`, 8 data bits, no parity and a
    one second timeout. Anything else is parsed as ``host[:port]`` and
    opened as a TCP socket; the port defaults to 502 and an empty host
    means 'localhost'.

    Returns a `(modbus, slave_id, connection)` tuple: the matching uModbus
    client module, the slave id (`args.slave_id` when truthy, otherwise 1
    for serial and 255 for TCP) and the open connection.
    """
    # startswith() instead of device[0] so an empty device string does not
    # raise IndexError; it falls through to TCP on localhost.
    if args.device.startswith('/'):
        from serial import Serial, PARITY_NONE
        import umodbus.client.serial.rtu as modbus

        connection = Serial(port=args.device, baudrate=args.baud, parity=PARITY_NONE,
                            stopbits=args.stop_bits, bytesize=8, timeout=1)

        # I had to patch rtu to read(1000) instead of read(serial_device.in_waiting) (in_waiting was always 0)

        slave_id = args.slave_id
        if not slave_id:
            slave_id = 1
    else:
        import socket
        import umodbus.client.tcp as modbus

        port = 502
        parts = args.device.split(':')
        if len(parts) == 2:
            host, port = parts
        elif len(parts) == 1:
            host = parts[0]
        else:
            fail("Invalid device {}", args.device)

        if host == '':
            host = 'localhost'

        port = int(port)

        connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            connection.connect((host, port))
        except OSError:
            # Do not leave a half-open socket behind when connecting fails.
            connection.close()
            raise

        if args.slave_id:
            slave_id = args.slave_id
        else:
            slave_id = 255

    return modbus, slave_id, connection

# Command line: [options] device access [access ...]
parser = argparse.ArgumentParser()
# File of "name definition" lines that give registers symbolic names.
parser.add_argument('-r', '--registers')
# When omitted, connect_to_device picks 1 for serial and 255 for TCP.
parser.add_argument('-s', '--slave-id', type=int)
# Serial line settings; only used when the device is a serial port.
parser.add_argument('-b', '--baud', type=int, default=19200)
parser.add_argument('-p', '--stop-bits', type=int, default=1)
# Dump request frames and raw read results.
parser.add_argument('-v', '--verbose', action='store_true')
# Serial port path (starts with '/') or TCP "host[:port]".
parser.add_argument('device')
# One or more "[type@]address[/pack_type][=value]" specs or register names.
parser.add_argument('access', nargs='+')
args = parser.parse_args()

# These three names are module globals read by read_registers/write_registers.
modbus, slave_id, connection = connect_to_device(args)

# Close the connection even when an access raises.
try:
    for a in group_accesses(parse_accesses(args.access, parse_registers(args.registers))):
        if a.write:
            write_registers(a.modbus_type, a.address(), a.operations())
        else:
            try:
                values = read_registers(a.modbus_type, a.address(), a.operations())
            except umodbus.exceptions.IllegalDataAddressError:
                values = ('Invalid address', )
            except umodbus.exceptions.IllegalFunctionError:
                values = ('Invalid modbus type', )

            # zip() stops at the shorter input, so an error tuple prints once,
            # against the first label of the group.
            for label, value in zip(a.labels(), values):
                print("{}: {} {}".format(label, value, hex(value) if type(value) == int else ''))
finally:
    connection.close()
