#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals

import intelhex
import argparse
import sys
import hexdump
import colorama
from colorama import Fore, Style

import datetime

import keyplus
from keyplus.layout import KeyplusLayout
from keyplus.device_info import KeyboardDeviceTarget, KeyboardFirmwareInfo

from keyplus.chip_id import get_chip_id_from_name

from keyplus.exceptions import *
from keyplus.constants import *
from keyplus.utility import list_to_map, inverse_map


# The tool reports the version of the keyplus library it is built on.
__version__ = keyplus.__version__

# Process exit codes used by the commands in this file.
EXIT_NO_ERROR = 0  # the command finished normally
EXIT_COMMAND_ERROR = 1  # unknown command, or bad/missing command arguments
EXIT_MATCH_DEVICE = 2  # no device, or more than one device, matched the filters
EXIT_UNSUPPORTED_FEATURE = 3  # the device reported the command as unsupported
EXIT_BAD_FILE = 4  # loading the layout/rf files raised a KeyplusError
EXIT_INSUFFICIENT_SPACE = 5  # layout data is larger than the space given for it

def print_error(*args):
    """Print an error message to stderr behind a red ``Error: `` label.

    The positional arguments are handed to ``print`` unchanged, so they are
    separated by single spaces and terminated by a newline.
    """
    label = "".join((Fore.RED, "Error: ", Style.RESET_ALL))
    # The label is emitted without a newline so the message shares its line.
    print(label, end='', file=sys.stderr)
    print(*args, file=sys.stderr)


def print_hid_info(hid_device):
    """Print a one-line summary of a HID device to stdout.

    The line has the form ``vid:pid | manufacturer | product | serial``,
    with the vendor and product ids written in lowercase hexadecimal.
    """
    columns = (
        "{:x}:{:x}".format(hid_device.vendor_id, hid_device.product_id),
        hid_device.manufacturer_string,
        hid_device.product_string,
        hid_device.serial_number,
    )
    print(" | ".join("{}".format(column) for column in columns))

def timestamp_to_string(tstamp):
    """Convert a POSIX timestamp into a human readable local date string.

    Args:
        tstamp: Seconds since the epoch. A value of ``0`` is treated as
            "no timestamp available".

    Returns:
        ``str(datetime)`` of the timestamp in local time, or the string
        ``"<Unavailable>"`` when the timestamp is ``0`` or cannot be
        represented as a date.
    """
    unavailable = "<Unavailable>"
    if tstamp == 0:
        return unavailable
    try:
        return str(datetime.datetime.fromtimestamp(tstamp))
    except (OverflowError, OSError, ValueError):
        # fromtimestamp() signals out-of-range values with any of these three
        # exception types depending on the platform and Python version; a
        # bogus timestamp should not abort the command that is printing it.
        return unavailable

def print_device_info(device_info, indent="  "):
    """Print the fields of a device info object, one per line.

    Every line starts with ``indent`` followed by the single space that
    ``print`` inserts between its arguments.
    """
    def show(*fields):
        # One output line: the indent, then the label and (optional) value.
        print(indent, *fields)

    show("id: ", device_info.device_id)
    show("name: '{}'".format(device_info.get_name_str()))
    show("layout last updated: ", timestamp_to_string(device_info.timestamp_raw))
    show("default_report_mode: ", device_info.get_default_report_mode_str())

    plan = device_info.scan_plan
    show("scan_mode: ", plan.mode)
    show("row_count: ", plan.rows)
    show("col_count: ", plan.cols)

def print_firmware_info(device, indent="  "):
    """Print the firmware build time and git hash reported for ``device``.

    NOTE(review): ``protocol`` is not imported in the visible part of this
    file. Unless one of the wildcard imports provides it, calling this
    function raises ``NameError`` -- confirm before relying on it.
    """
    info = protocol.get_firmware_info(device)

    build_time = timestamp_to_string(info.timestamp_raw)
    print(indent, "build time: ", build_time)

    # The git hash is shown as a zero padded, 8 digit hexadecimal number.
    hash_line = "git_hash: {:08x}".format(info.git_hash)
    print(indent, hash_line)

def print_all_info(kb_device):
    """Print the HID summary line of ``kb_device``, then its device info."""
    hid = kb_device.hid_device
    print_hid_info(hid)

    info = kb_device.device_info
    print_device_info(info)



# Generic Command
class GenericCommand(object):
    """Base class shared by all sub-commands of the command line tool.

    Each command owns an ``argparse`` parser (``self.arg_parser``) and a
    ``description``. Subclasses add their arguments to the parser and
    override ``task`` to do the actual work.
    """

    def __init__(self, description, **kwargs):
        """Create the argument parser for this command.

        ``description`` is stored and used as the parser description; any
        extra keyword arguments are passed on to ``argparse.ArgumentParser``.
        """
        self.description = description

        # Find the name this class is registered under in the CLI. When the
        # class is not registered the name stays empty.
        registered_name = ""
        for name, command_class in KeyplusCLI.COMMAND_NAME_MAP.items():
            if command_class == self.__class__:
                registered_name = name

        self.arg_parser = argparse.ArgumentParser(
            prog=" ".join([sys.argv[0], registered_name]),
            description=description,
            **kwargs
        )

    def task(self, args):
        """Run the command with the parsed arguments (no-op in the base)."""
        pass

    def run(self):
        """Parse the command's own arguments and execute ``task``."""
        # sys.argv[0] is the program and sys.argv[1] the command name, so the
        # command's arguments start at index 2.
        own_args = sys.argv[2:]
        self.task(self.arg_parser.parse_args(own_args))

    def help(self):
        """Placeholder; does nothing."""
        pass

# Commands that target a specific device, and includes arguments and functions
# to find the targeted device.
class GenericDeviceCommand(GenericCommand):
    """Base class for commands that operate on a connected device.

    Adds the device selection options (id, serial number, chip name,
    vid/pid pair and name) to the command's parser and provides
    ``find_matching_device`` to resolve them to a device.
    """

    def __init__(self, description, **kwargs):
        """Create the parser and register the device selection options."""
        super(GenericDeviceCommand, self).__init__(description, **kwargs)

        self.arg_parser.add_argument(
            '-i', '--id', dest='device_id', type=int, default=None,
            help='The id of the device to program'
        )

        self.arg_parser.add_argument(
            '-s', '--serial', dest='serial', type=str, default=None,
            help='Serial number of the USB device to use. Can be a partial match.'
        )

        self.arg_parser.add_argument(
            '-p', '--chip_name', dest='chip_name', type=str, default=None,
            help='Match devices which are using the given mcu chip name.'
        )

        self.arg_parser.add_argument(
            '-d', '--vid-pid', dest='vid_pid', type=str, default=None,
            help='Open a device based on a vid_pid pair'
        )

        self.arg_parser.add_argument(
            '-n', dest='name', type=str, default=None,
            help='Name of the USB device to use. Can be a partial match.'
        )

    def find_matching_device(self, args, multiple_matches=False):
        """Find the device(s) selected by the parsed command line ``args``.

        Args:
            args: Parsed arguments containing the device selection options
                registered in ``__init__``.
            multiple_matches: When true, return every matching device (the
                result may be empty) instead of requiring exactly one.

        Returns:
            The collection returned by ``keyplus.find_devices`` when
            ``multiple_matches`` is true, otherwise the single matching
            device.

        Raises:
            SystemExit: With ``EXIT_MATCH_DEVICE`` when exactly one device is
                required but zero or several devices matched.
        """
        matching_devices = keyplus.find_devices(
            name=args.name,
            device_id=args.device_id,
            vid_pid=args.vid_pid,
            serial_number=args.serial,
            chip_name=args.chip_name,
        )

        if multiple_matches:
            return matching_devices

        num_matches = len(matching_devices)
        if num_matches == 1:
            return matching_devices[0]

        if num_matches == 0:
            print_error("Couldn't find any matching devices.")
        else:
            # List the candidates so the user can narrow the selection.
            for kb_device in matching_devices:
                print_hid_info(kb_device.hid_device)
            # print_error() already adds the "Error: " prefix.
            print_error(
                "found {} matching devices, select a specific device or "
                "disconnect the other devices".format(num_matches)
            )
        # sys.exit() is used instead of the exit() helper because the latter
        # is injected by the `site` module and is not guaranteed to exist.
        sys.exit(EXIT_MATCH_DEVICE)

    def task(self, args):
        """Run the command (no-op here; overridden by concrete commands)."""
        pass


# Hexdump of raw debug output
class DebugCommand(GenericDeviceCommand):
    """``debug`` command: listen to the raw output of the selected device."""

    def __init__(self):
        description = 'Shows raw debug output as a hexdump'
        super(DebugCommand, self).__init__(description)

    def task(self, args):
        """Print the device's information, then call its ``listen_raw()``."""
        keyboard = self.find_matching_device(args)
        print_all_info(keyboard)

        # The device is used as a context manager around the listen call.
        with keyboard:
            keyboard.listen_raw()


# Prints matrix (row,col) pairs of pressed keys to stdout of a connected USB device
class PassthroughCommand(GenericDeviceCommand):
    """``passthrough`` command: print the matrix position of pressed keys.

    The device is switched into passthrough mode and every matrix report it
    sends is decoded and printed as ``r<row>c<col>`` pairs until a read
    returns ``None``.
    """

    def __init__(self):
        """Create the parser and register the passthrough options."""
        super(PassthroughCommand, self).__init__(
            'Prints the matrix row and column of keys on a USB connected device.',
            # Usage information
            epilog='device directly connected by USB. \n'
            '\n'
            'NOTE: If you have a wired/wireless '
            'split device, you will need to connect it directly by USB for this '
            'function to work.'
        )

        # NOTE: this positional is accepted but task() does not read it.
        self.arg_parser.add_argument(
            'enable', nargs='?',
            help='Set passthrough state "on" or "off"'
        )

        self.arg_parser.add_argument(
            '-t', '--timeout',
            type=int, default=20,
            help='Number of seconds to wait for a matrix report before '
            'stopping (default: %(default)s)'
        )

        self.arg_parser.add_argument(
            '-V', dest='verbose', action='store_const',
            const=True, default=False,
            help='Verbose output'
        )

    @staticmethod
    def _bitmap_to_matrix(bitmap, rows, cols, inv_col_map):
        """Unpack a passthrough bitmap into a ``rows`` x ``cols`` 0/1 matrix.

        Bit ``n`` of the bitmap (least significant bit of each byte first)
        belongs to row ``n // cols`` and physical column ``n % cols``;
        ``inv_col_map`` translates the physical column into the column index
        used in the result.
        """
        matrix = [[0] * cols for _ in range(rows)]
        for pos in range(rows * cols):
            row, physical_col = divmod(pos, cols)
            bit = (bitmap[pos // 8] >> (pos % 8)) & 1
            matrix[row][inv_col_map[physical_col]] = bit
        return matrix

    def task(self, args):
        """Enable passthrough mode and print matrix reports until timeout.

        Raises:
            SystemExit: With ``EXIT_UNSUPPORTED_FEATURE`` when the device
                rejects the passthrough command as unsupported.
            KeyplusUSBCommandError: For any other failure to enable
                passthrough mode.
        """
        kb = self.find_matching_device(args)

        with kb:
            kp_layout = kb.unpack_layout_data()

            dev = kp_layout.get_device(kb.device_info.device_id)
            pin_mapping = dev.scan_mode.generate_pin_mapping(kb.get_device_target())

            # Pairs of (col_num, pin_num), sorted by their pin_num field.
            pin_and_col = sorted(
                enumerate(pin_mapping.column_pins), key=lambda pair: pair[1]
            )
            # Mapping from physical column numbers -> logical column numbers
            inv_col_map = [col_num for (col_num, _pin_num) in pin_and_col]

            try:
                kb.set_passthrough_mode(True)
            except KeyplusUSBCommandError as err:
                if err.error_code == KeyplusUSBCommandError.ERROR_UNSUPPORTED_COMMAND:
                    print_error("Target device doesn't support matrix scanning")
                    sys.exit(EXIT_UNSUPPORTED_FEATURE)
                raise

            # From here on passthrough mode is active. The finally clause
            # switches it off again even when the loop is left through an
            # exception or Ctrl-C, so the device is not left in that mode.
            try:
                # The timeout is given in seconds on the command line and is
                # multiplied by 1000 for hid_read (presumably milliseconds).
                passthrough_timeout = args.timeout * 1000

                response = kb.hid_read(timeout=passthrough_timeout)

                dev_info = kb.get_device_info()
                rows = dev_info.scan_plan.rows
                cols = dev_info.scan_plan.cols

                # One bit per matrix position, rounded up to whole bytes.
                data_size = (rows * cols + 7) // 8

                while response is not None:
                    if response[0] == CMD_PASSTHROUGH_MATRIX:
                        bitmap = response[1:data_size + 1]
                        matrix = self._bitmap_to_matrix(
                            bitmap, rows, cols, inv_col_map
                        )
                        print()
                        if args.verbose:
                            for (row_num, row_data) in enumerate(matrix):
                                row_binary = ''.join(str(key) for key in row_data)
                                print("row{}: {}".format(row_num, row_binary))
                        pressed = ''.join(
                            "r{}c{} ".format(row_num, col_num)
                            for (row_num, row) in enumerate(matrix)
                            for (col_num, state) in enumerate(row)
                            if state
                        )
                        if pressed != '':
                            print(pressed)
                    response = kb.hid_read(timeout=passthrough_timeout)
            finally:
                kb.set_passthrough_mode(False)

# List connected USB devices
class ListCommand(GenericDeviceCommand):
    """``list`` command: print every device matching the selection options."""

    def __init__(self):
        super(ListCommand, self).__init__('List the connected keyplus devices')

        # Counting flag: the value stays None when -v is never given.
        self.arg_parser.add_argument(
            '--verbose', '-v', dest='verbosity', action='count',
            help='Verbose device information'
        )

    def task(self, args):
        """Print one HID summary line per device, plus details with -v."""
        show_details = args.verbosity is not None and args.verbosity >= 1

        for kb_device in self.find_matching_device(args, multiple_matches=True):
            print_hid_info(kb_device.hid_device)
            if show_details:
                print_device_info(kb_device.device_info)

class ReadCommand(GenericDeviceCommand):
    """``read`` command: hexdump the settings and/or layout of a device."""

    def __init__(self):
        super(ReadCommand, self).__init__('Read data from the keyboard')

        add_flag = self.arg_parser.add_argument
        add_flag(
            '--dump-settings', dest='dump_settings', action='store_true',
            help='Hexdump of the settings section'
        )
        add_flag(
            '--dump-layout', dest='dump_layout', action='store_true',
            help='Hexdump of the layout section'
        )
        add_flag(
            '-C', '--compact', dest='compact', action='store_true',
            help='Make the output more compact'
        )

    def task(self, args):
        """Read the requested sections from the device and hexdump them."""
        keyboard = self.find_matching_device(args)

        def show(raw):
            # Compact output is a single printed string; otherwise hexdump
            # writes its own formatted output. A blank line follows either.
            if args.compact:
                print(hexdump.dump(raw))
            else:
                hexdump.hexdump(raw)
            print()

        with keyboard:
            if args.dump_settings:
                print("Settings section hexdump:")
                settings = keyboard.read_settings_section()
                show(bytes(settings.to_bytes()))

            if args.dump_layout:
                print("Layout section hexdump:")
                show(bytes(keyboard.read_whole_layout()))



# Test command for controlling LEDs
class LEDCommand(GenericDeviceCommand):
    """``led`` command: set the state of one indicator LED on a device."""

    def __init__(self):
        super(LEDCommand, self).__init__('Set the state of the indicator LED')

        # Two required integer positionals, in this order.
        positionals = (
            ('led_number', "Led number to set"),
            ('led_state', "State to set the led: 0->off, 1->on"),
        )
        for arg_name, arg_help in positionals:
            self.arg_parser.add_argument(arg_name, type=int, help=arg_help)

    def task(self, args):
        """Pass the LED number and state on to the selected device."""
        keyboard = self.find_matching_device(args)
        with keyboard:
            keyboard.set_indicator_led(args.led_number, args.led_state)

# Reset the target device
class ResetCommand(GenericDeviceCommand):
    """``reset`` command: call ``reset()`` on the selected device."""

    def __init__(self):
        description = 'Reset the target device'
        super(ResetCommand, self).__init__(description)

    def task(self, args):
        """Find the single matching device and reset it."""
        keyboard = self.find_matching_device(args)
        with keyboard:
            keyboard.reset()

# Enter the devices bootloader
class BootloaderCommand(GenericDeviceCommand):
    """``bootloader`` command: call ``enter_bootloader()`` on the device."""

    def __init__(self):
        description = 'Run the bootloader of the target device'
        super(BootloaderCommand, self).__init__(description)

    def task(self, args):
        """Find the single matching device and make it enter its bootloader."""
        keyboard = self.find_matching_device(args)
        with keyboard:
            keyboard.enter_bootloader()

class HelpCommand(GenericCommand):
    """``help`` command: print the help text of another command."""

    def __init__(self):
        """Create the parser with one optional positional: the command name."""
        super(HelpCommand, self).__init__(
            'Prints help for the given command'
        )

        self.arg_parser.add_argument(
            'command_name', type=str, default=None, nargs='?',
            help='The command to provide help for'
        )

    def task(self, args):
        """Print help for ``args.command_name``.

        Without a command name the help of the ``help`` command itself is
        printed.

        Raises:
            SystemExit: With ``EXIT_NO_ERROR`` after printing the command's
                own help, or with ``EXIT_COMMAND_ERROR`` when the requested
                command is unknown.
        """
        if args.command_name is None:
            self.arg_parser.print_help()
            sys.exit(EXIT_NO_ERROR)

        command_class = KeyplusCLI.COMMAND_NAME_MAP.get(args.command_name)
        if command_class is None:
            # print_error() already adds the "Error: " prefix. Exit with a
            # failure status so scripts can detect the unknown command.
            print_error(
                "Can't find help for unknown command " + str(args.command_name)
            )
            sys.exit(EXIT_COMMAND_ERROR)

        command_class().arg_parser.print_help()


class ProgramCommand(GenericDeviceCommand):
    """``program`` command: program layout and rf settings.

    Depending on the options the command either
      * builds a firmware hex file with the settings merged in
        (``--merge-hex``),
      * erases the settings and layout of a device (``--erase``), or
      * programs the settings/layout of a connected device.
    """

    def __init__(self):
        """Create the parser and register the programming options."""
        super(ProgramCommand, self).__init__(
            'Program layout, rf settings and hex files'
        )

        self.arg_parser.add_argument(
            '-l', '--layout', dest='layout_file', type=str, default=None,
            help='The layout file to program'
        )

        self.arg_parser.add_argument(
            '-r', '--rf', dest='rf_file', type=str, default=None,
            help='The rf file to program'
        )

        self.arg_parser.add_argument(
            '-I', '--new-id', dest='new_id', type=int, default=None,
            help='Gives the device a new device ID (requires both a layout file '
            ' and a rf file)'
        )

        self.arg_parser.add_argument(
            '-x', '--fw-hex', dest='hex_file', type=str, default=None,
            help='TODO firmware hex file'
        )

        self.arg_parser.add_argument(
            '-M', '--merge-hex', dest='merge_hex',
            type=lambda x: int(x, 0), # integer of any base
            nargs='+', default=None,
            help='For use in build scripts. Don\'t actually program the device,'
            ' instead use the arguments provided to construct a firmware hexfile'
            ' with the given settings pre-programmed. In addition to the other'
            ' settings fields this flag takes 3 arguments: '
            ' [settings_addr, layout_addr, layout_size]'
            ' which determine where the layout will be placed in the result hex.'
        )

        self.arg_parser.add_argument(
            '-F', '--fw-set', dest='firmware_settings',
            action='append',
            type=str,
            help='Firmware setting in the form fw_opt=value, used together with'
            ' --merge-hex. Known options: chip_name, scan_method, max_rows.'
            ' Give the flag once per option.'
        )

        self.arg_parser.add_argument(
            '-o', '--outfile', dest='outfile',
            type=str, default=None,
            help='The file written by the -M flag. If no file is given, print'
            ' to stdoutput'
        )

        self.arg_parser.add_argument(
            '-e', '--erase',
            action='store_const',
            const=True, default=False,
            help='Erase all the settings and flash of this device. If given'
            ' with --merge-hex, generate a hex file with an erase layout.'
        )

    def task(self, args):
        """Dispatch to the merge-hex, erase or program-device task.

        When none of the programming options is given the help text is
        printed and the process exits with ``EXIT_NO_ERROR``.
        """
        # --erase defaults to False (not None), so it has to be tested for
        # truthiness; comparing it with None never matched.
        nothing_requested = (
            args.layout_file is None and
            args.hex_file is None and
            args.rf_file is None and
            args.new_id is None and
            args.merge_hex is None and
            not args.erase
        )
        if nothing_requested:
            self.arg_parser.print_help()
            sys.exit(EXIT_NO_ERROR)

        if args.merge_hex:
            self.task_mereged_hex(args)
        elif args.erase:
            self.task_erase(args)
        else:
            self.task_program_device(args)

    def task_erase(self, args):
        """Erase the settings and layout sections of the device and reset it."""
        kb = self.find_matching_device(args)

        with kb:
            kb.erase_settings_section()
            kb.erase_layout_section()
            kb.reset()

    def load_layout_file(self, args):
        """Load the layout described by ``args.layout_file``/``args.rf_file``.

        Returns:
            The populated ``KeyplusLayout``.

        Raises:
            SystemExit: With ``EXIT_BAD_FILE`` when loading raises a
                ``KeyplusError``.
        """
        try:
            kp_layout = KeyplusLayout()
            kp_layout.from_yaml_file(
                layout_file=args.layout_file,
                rf_file=args.rf_file,
                print_warnings=True,
            )
            return kp_layout
        except KeyplusError as err:
            print_error(str(err))
            sys.exit(EXIT_BAD_FILE)

    def erase_hex_region(self, fw_hex, start, size):
        """Remove the addresses ``start`` .. ``start + size - 1`` from ``fw_hex``."""
        for addr in range(start, start + size):
            fw_hex[addr] = 0 # dummy, write so del always works
            del fw_hex[addr]

    def write_hex_file(self, args, fw_hex):
        """Write ``fw_hex`` to ``args.outfile``, or to stdout if none is given."""
        if args.outfile:
            with open(args.outfile, 'w') as outfile:
                fw_hex.write_hex_file(outfile)
        else:
            fw_hex.write_hex_file(sys.stdout)

    def _parse_firmware_settings(self, options):
        """Parse the ``fw_opt=value`` strings given with the -F flag.

        Keys and values are lower-cased. Everything after the first ``=``
        is the value.

        Returns:
            A dict mapping each given option name to its value.

        Raises:
            SystemExit: With ``EXIT_COMMAND_ERROR`` for a string without
                ``=`` or an option name that is not (or no longer) known.
        """
        known_settings = [
            'chip_name', 'scan_method', 'max_rows'
        ]
        firmware_settings = {}
        for option in options:
            key, separator, value = option.partition('=')
            if not separator:
                print_error(
                    "Bad format for firmware setting. Expected format is "
                    "'fw_opt=value'."
                )
                sys.exit(EXIT_COMMAND_ERROR)
            key = key.lower()
            value = value.lower()
            if key not in known_settings:
                print_error(
                    "Unknown firmware option '{}', expected one of: {}".format(
                        key, known_settings
                    )
                )
                sys.exit(EXIT_COMMAND_ERROR)
            firmware_settings[key] = value
            # Each option may only be given once: a repeated option is no
            # longer in the list and is rejected by the check above.
            known_settings.remove(key)
        return firmware_settings

    def task_mereged_hex(self, args):
        """Build a firmware hex file with settings and layout merged in.

        Requires a layout file, rf file, firmware hex file, a device id, the
        three ``--merge-hex`` values ``[settings_addr, layout_addr,
        layout_size]`` and the firmware settings given with ``-F``. The
        result is written with ``write_hex_file`` and the process exits.

        Raises:
            SystemExit: ``EXIT_COMMAND_ERROR`` for missing or bad options,
                ``EXIT_BAD_FILE`` when the layout cannot be loaded,
                ``EXIT_INSUFFICIENT_SPACE`` when the layout does not fit in
                ``layout_size`` bytes, ``EXIT_NO_ERROR`` on success.
        """
        if (
            args.new_id is None or
            args.layout_file is None or
            args.rf_file is None or
            args.hex_file is None
        ):
            self.arg_parser.print_help()
            print_error(
                "To generate a merged hex file, need the following options: "
                "layout file (-l), rf file (-r), hex file (-x), and a device id (-I)"
            )
            sys.exit(EXIT_COMMAND_ERROR)

        if len(args.merge_hex) != 3:
            print_error("To generate a merged hex file, need to provide "
                    "[settings_addr, layout_addr, layout_size] as arguments")
            sys.exit(EXIT_COMMAND_ERROR)

        if args.firmware_settings is None:
            print_error("Must provide firmware settings with -F flag.")
            sys.exit(EXIT_COMMAND_ERROR)

        firmware_settings = self._parse_firmware_settings(args.firmware_settings)

        # TODO: build device target from command line parameters
        firmware_info = KeyboardFirmwareInfo()
        try:
            firmware_info.chip_id = get_chip_id_from_name(firmware_settings['chip_name'])
            firmware_info.set_internal_scan_method(firmware_settings['scan_method'])
            firmware_info.set_max_rows(firmware_settings['max_rows'])
        except KeyplusSettingsError as err:
            print_error(str(err))
            sys.exit(EXIT_COMMAND_ERROR)
        except KeyError as err:
            # One of the three firmware settings was not given with -F.
            print_error("firmware setting '{}' must be set".format(err.args[0]))
            sys.exit(EXIT_COMMAND_ERROR)

        # don't want to program the device, instead we want to build a
        # hexfile with the settings preprogrammed
        with open(args.hex_file) as f:
            fw_hex = intelhex.IntelHex(f)

        # Add chip id to hex file.
        #
        # NOTE:
        # Intel hex files have 6 record types. The record type 0x05 is a 4
        # byte value that is only used on 80386 processors, so here we're
        # going to use it for a different purpose. We want the firmware
        # loader to know what microcontroller the hex was compiled for.
        #
        # NOTE: The 0x03 record looks like it can also be used, but it can't
        # be used together with the 0x05 record, so we can't use it.
        #
        # Reference:
        # https://en.wikipedia.org/wiki/Intel_HEX#Record_types

        # fw_hex.start_addr = {
        #     'EIP': firmware_info.chip_id
        # }

        settings_addr, layout_addr, layout_size = args.merge_hex

        # first erase anything that is in the settings and layout sections
        self.erase_hex_region(fw_hex, settings_addr, SETTINGS_SIZE)
        self.erase_hex_region(fw_hex, layout_addr, layout_size)

        if args.erase:
            # If the erase flag is given, write the file now and exit
            self.write_hex_file(args, fw_hex)
            sys.exit(EXIT_NO_ERROR)

        device_target = KeyboardDeviceTarget(
            device_id=args.new_id,
            firmware_info=firmware_info
        )

        kp_layout = self.load_layout_file(args)

        settings_data = kp_layout.build_settings_section(device_target)
        layout_data = kp_layout.build_layout_section(device_target)

        if len(layout_data) > layout_size:
            print_error("layout data to large. Got {} bytes, but only "
                    "{} bytes available".format(
                        len(layout_data),
                        layout_size
                    ))
            sys.exit(EXIT_INSUFFICIENT_SPACE)

        settings_hex = intelhex.IntelHex()
        settings_hex.frombytes(
            settings_data,
            offset=settings_addr
        )

        layout_hex = intelhex.IntelHex()
        layout_hex.frombytes(
            layout_data,
            offset=layout_addr
        )

        fw_hex.merge(settings_hex, overlap='replace')
        fw_hex.merge(layout_hex, overlap='replace')

        self.write_hex_file(args, fw_hex)

        sys.exit(EXIT_NO_ERROR)

    def task_program_device(self, args):
        """Program the settings (and optionally the layout) of a device.

        When only a new device id is given, the existing settings section is
        rewritten with that id and the process exits. Otherwise the settings
        section is always rebuilt from the loaded files, the layout section
        is only written when a layout file was given, and the device is
        reset afterwards.
        """
        kb = self.find_matching_device(args)

        with kb:
            print("Programing start...")
            print_all_info(kb)
            print("")

            device_target = kb.get_device_target()

            # Set the target device id
            if args.new_id is not None:
                device_target.device_id = args.new_id

            # If only a new id was provided, update the old settings section
            # with just the new id.
            if (
                args.layout_file is None and
                args.rf_file is None and
                args.new_id is not None
            ):
                settings = kb.read_settings_section()
                settings.header.device_id = args.new_id
                kb.update_settings_section(settings, keep_rf=True)
                sys.exit(EXIT_NO_ERROR)

            kp_layout = self.load_layout_file(args)

            # If this command is running, we always update the settings
            # section. The rf settings on the device are kept unless an rf
            # file was given.
            settings = kp_layout.build_settings_section(device_target)
            keep_rf = not args.rf_file
            kb.update_settings_section(settings, keep_rf)

            # Only build and update the layout section if a layout file was
            # given.
            if args.layout_file:
                layout = kp_layout.build_layout_section(device_target)
                kb.update_layout_section(layout)

            kb.reset(reset_type=RESET_TYPE_SOFTWARE)


class PairCommand(GenericDeviceCommand):
    """``pair`` command: call ``enter_pairing_mode()`` on the device."""

    def __init__(self):
        description = 'Put the device into pairing mode'
        super(PairCommand, self).__init__(description)

    def task(self, args):
        """Find the single matching device and put it into pairing mode."""
        keyboard = self.find_matching_device(args)
        with keyboard:
            keyboard.enter_pairing_mode()

class KeyplusCLI(object):
    """Entry point of the tool: selects and runs the requested sub-command.

    Constructing the class parses the first command line argument as the
    command name, runs that command and exits the process.
    """

    # Command name on the command line -> class implementing the command.
    COMMAND_NAME_MAP = {
        "bootloader": BootloaderCommand,
        "debug": DebugCommand,
        "led": LEDCommand,
        "list": ListCommand,
        "passthrough": PassthroughCommand,
        "read": ReadCommand,
        "reset": ResetCommand,
        "program": ProgramCommand,
        "pair": PairCommand,
        # "layers": LayerCommand,
        "help": HelpCommand,
    }

    def __init__(self):
        # Every command is instantiated up front because its description is
        # needed for the usage text.
        commands = {
            name: command_class()
            for name, command_class in KeyplusCLI.COMMAND_NAME_MAP.items()
        }

        command_lines = [
            "    {0: <16}{1:}".format(name, command.description)
            for name, command in commands.items()
        ]
        usage_string = (
            "%(prog)s <command> [<args>]\n\n"
            "Where command is one of the options listed below.\n"
        ) + "\n".join(command_lines)

        parser = argparse.ArgumentParser(
            description='keyplus keyboard control tool',
            usage=usage_string
        )
        parser.add_argument('command', help='Subcommand to run')
        parser.add_argument(
            '-v', '--version', action='version',
            version="%(prog)s: " + __version__
        )

        # Only the first argument is parsed here; the remaining arguments
        # belong to the command and would fail validation in this parser.
        args = parser.parse_args(sys.argv[1:2])

        selected = commands.get(args.command)
        if selected is None:
            print_error('Unrecognized command: {}'.format(args.command))
            parser.print_help()
            exit(EXIT_COMMAND_ERROR)

        selected.run()
        exit(EXIT_NO_ERROR)


if __name__ == "__main__":
    # Initialise colorama before any coloured output (see print_error) is
    # written. Constructing KeyplusCLI then parses sys.argv, runs the
    # selected command and exits the process itself.
    colorama.init()
    KeyplusCLI()
