#! /usr/bin/env python3
"""
Convenience script for gathering metrics data

"""
import argparse
import logging
import os
import shutil
from datetime import datetime, timedelta

from indy_common.config_util import getConfig
from plenum.common.constants import KeyValueStorageType
from plenum.common.metrics_collector import MetricsName
from plenum.common.metrics_stats import load_metrics_from_kv_store
from storage.helper import initKeyValueStorage

# Silence logging for this script: strip every handler from the root logger
# and disable it. `logging.getLogger()` with no name returns the root logger.
logger = logging.getLogger()
logger.handlers = []
logger.propagate = False
logger.disabled = True


def read_args(argv=None):
    """
    Parse command line options of the metrics script.

    :param argv: list of argument strings to parse; when None (the default)
        argparse reads them from ``sys.argv[1:]``
    :return: ``argparse.Namespace`` with the attributes ``data_dir``,
        ``node_name``, ``network``, ``output``, ``ts_fmt``, ``min_ts``,
        ``max_ts`` and ``step``
    """
    parser = argparse.ArgumentParser(description="Process metrics")

    parser.add_argument(
        '--data_dir',
        required=False,
        help="Path to metrics data, replaces the location derived from --node_name and --network")
    parser.add_argument('--node_name', required=False, help="Node's name")
    parser.add_argument('--network', required=False, help="Network name to read metrics from")
    parser.add_argument('--output', required=False, default=None, help="Output CSV file name with metrics")
    parser.add_argument('--ts_fmt', required=False, default="%Y-%m-%d %H:%M:%S", help="Output timestamp format")

    parser.add_argument(
        '--min_ts',
        required=False,
        default=None,
        help="Process all metrics starting from given timestamp (beginning by default)")
    parser.add_argument(
        '--max_ts',
        required=False,
        default=None,
        help="Process all metrics till given timestamp (end by default)")
    parser.add_argument(
        '--step',
        type=int,
        required=False,
        default=60,
        help="Build timelog with given timestep, seconds (60 by default)"
    )

    return parser.parse_args(argv)


def get_data_dir(node_name=None, network=None):
    """
    Build the path of a node's data directory for the given network.

    The result is ``<config.LEDGER_DIR>/<network>/data/<node_name>``.

    :param node_name: name of the node's sub-directory; when falsy, the first
        entry (in sorted order) of the network's 'data' folder is used
    :param network: network name; when falsy, ``config.NETWORK_NAME`` is used
    :return: path of the node's data directory (its existence is not checked)
    :raises SystemExit: with status 1, after printing an explanation, when the
        network's 'data' folder does not exist or is empty while no node name
        was given
    """
    config = getConfig()
    _network = network if network else config.NETWORK_NAME
    data_dir = os.path.join(config.LEDGER_DIR, _network, 'data')

    if not os.path.exists(data_dir):
        print("No such file or directory: {}".format(data_dir))
        print("Please check, that network: '{}' was used ".format(_network))
        # Non-zero status so that callers can tell the failure from success
        raise SystemExit(1)

    if not node_name:
        # os.listdir returns entries in arbitrary order; sort them so the
        # default node is chosen deterministically
        dirs = sorted(os.listdir(data_dir))
        if not dirs:
            print("Node's 'data' folder not found: {}".format(data_dir))
            raise SystemExit(1)
        node_name = dirs[0]

    return os.path.join(data_dir, node_name)


def make_copy_of_data(data_dir):
    """
    Copy a data directory to a sibling '<data_dir>-read-copy' directory.

    A copy left over from a previous run is removed first.

    :param data_dir: path of the directory to copy
    :return: path of the freshly created copy
    """
    # Strip trailing separators: otherwise 'dir/' + '-read-copy' would point
    # inside the source directory and the tree would be copied into itself.
    # Fall back to the original value if stripping leaves nothing.
    base_dir = data_dir.rstrip(os.sep) or data_dir
    read_copy_data_dir = base_dir + '-read-copy'
    if os.path.exists(read_copy_data_dir):
        shutil.rmtree(read_copy_data_dir)
    shutil.copytree(data_dir, read_copy_data_dir)
    return read_copy_data_dir


def _format_ratio(numerator, denominator):
    """
    Format ``numerator / denominator`` with two decimal places.

    :return: the formatted quotient, or "n/a" when the numerator is None or
        the denominator is zero/None (instead of raising ZeroDivisionError)
    """
    if numerator is None or not denominator:
        return "n/a"
    return "{:.2f}".format(numerator / denominator)


def _write_metrics_csv(stats, path, ts_fmt):
    """
    Write one CSV row per stats frame, ordered by frame timestamp.

    :param stats: object returned by ``load_metrics_from_kv_store``
    :param path: name of the CSV file to (over)write
    :param ts_fmt: ``strftime`` format used for the timestamp column
    """
    columns = [
        "timestamp",
        "avg_looper_run_time_spent",
        "avg_three_pc_batch_size",
        "avg_transport_batch_size",
        "avg_incoming_node_message_size",
        "avg_outgoing_node_message_size",
        "avg_incoming_client_message_size",
        "avg_outgoing_client_message_size",
        "avg_master_validation_time",
        "avg_master_validation_time_per_req",
        "avg_monitor_avg_throughput",
        "avg_monitor_avg_latency",
        "avg_master_monitor_avg_throughput",
        "avg_master_monitor_avg_latency",
        "avg_node_stack_messages_processed",
        "avg_client_stack_messages_processed",
        "node_stack_messages_processed_per_sec",
        "client_stack_messages_processed_per_sec",
        "master_three_pc_batch_count_per_sec",
        "master_requests_count_per_sec",
        "master_ordered_three_pc_batch_count_per_sec",
        "master_ordered_requests_count_per_sec",
        "incoming_node_messages_per_sec",
        "incoming_node_traffic_per_sec",
        "outgoing_node_messages_per_sec",
        "outgoing_node_traffic_per_sec",
        "incoming_client_messages_per_sec",
        "incoming_client_traffic_per_sec",
        "outgoing_client_messages_per_sec",
        "outgoing_client_traffic_per_sec",
        "AVG_NODE_PROD_TIME",
        "AVG_SERVICE_REPLICAS_TIME",
        "AVG_SERVICE_NODE_MSGS_TIME",
        "AVG_SERVICE_CLIENT_MSGS_TIME",
        "AVG_SERVICE_ACTIONS_TIME",
        "AVG_SERVICE_LEDGER_MANAGER_TIME",
        "AVG_SERVICE_VIEW_CHANGER_TIME",
        "AVG_SERVICE_OBSERVABLE_TIME",
        "AVG_SERVICE_OBSERVER_TIME",
        "AVG_FLUSH_OUTBOXES_TIME",
        "AVG_SERVICE_NODE_LIFECYCLE_TIME",
        "AVG_SERVICE_CLIENT_STACK_TIME"
    ]

    with open(path, 'w') as f:
        f.write(",".join(columns))
        f.write("\n")

        # Frame length in seconds; identical for every row, so compute it once
        step_seconds = stats.timestep.total_seconds()

        # Sort on the timestamp only, so frames themselves are never compared
        for ts, frame in sorted(stats.frames(), key=lambda v: v[0]):
            row = [
                ts.replace(microsecond=0).strftime(ts_fmt),
                frame.get(MetricsName.LOOPER_RUN_TIME_SPENT).avg,
                frame.get(MetricsName.THREE_PC_BATCH_SIZE).avg,
                frame.get(MetricsName.TRANSPORT_BATCH_SIZE).avg,
                frame.get(MetricsName.INCOMING_NODE_MESSAGE_SIZE).avg,
                frame.get(MetricsName.OUTGOING_NODE_MESSAGE_SIZE).avg,
                frame.get(MetricsName.INCOMING_CLIENT_MESSAGE_SIZE).avg,
                frame.get(MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE).avg,
                frame.get(MetricsName.MASTER_REQUEST_PROCESSING_TIME).avg,
                # + 0.01 keeps the division defined when nothing was ordered
                frame.get(MetricsName.MASTER_REQUEST_PROCESSING_TIME).sum /
                (frame.get(MetricsName.MASTER_ORDERED_BATCH_SIZE).sum + 0.01),
                frame.get(MetricsName.MONITOR_AVG_THROUGHPUT).avg,
                frame.get(MetricsName.MONITOR_AVG_LATENCY).avg,
                frame.get(MetricsName.MASTER_MONITOR_AVG_THROUGHPUT).avg,
                frame.get(MetricsName.MASTER_MONITOR_AVG_LATENCY).avg,
                frame.get(MetricsName.NODE_STACK_MESSAGES_PROCESSED).avg,
                frame.get(MetricsName.CLIENT_STACK_MESSAGES_PROCESSED).avg,
                frame.get(MetricsName.NODE_STACK_MESSAGES_PROCESSED).sum / step_seconds,
                frame.get(MetricsName.CLIENT_STACK_MESSAGES_PROCESSED).sum / step_seconds,
                frame.get(MetricsName.MASTER_3PC_BATCH_SIZE).count / step_seconds,
                frame.get(MetricsName.MASTER_3PC_BATCH_SIZE).sum / step_seconds,
                frame.get(MetricsName.MASTER_ORDERED_BATCH_SIZE).count / step_seconds,
                frame.get(MetricsName.MASTER_ORDERED_BATCH_SIZE).sum / step_seconds,
                frame.get(MetricsName.INCOMING_NODE_MESSAGE_SIZE).count / step_seconds,
                frame.get(MetricsName.INCOMING_NODE_MESSAGE_SIZE).sum / step_seconds,
                frame.get(MetricsName.OUTGOING_NODE_MESSAGE_SIZE).count / step_seconds,
                frame.get(MetricsName.OUTGOING_NODE_MESSAGE_SIZE).sum / step_seconds,
                frame.get(MetricsName.INCOMING_CLIENT_MESSAGE_SIZE).count / step_seconds,
                frame.get(MetricsName.INCOMING_CLIENT_MESSAGE_SIZE).sum / step_seconds,
                frame.get(MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE).count / step_seconds,
                frame.get(MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE).sum / step_seconds,
                frame.get(MetricsName.NODE_PROD_TIME).avg,
                frame.get(MetricsName.SERVICE_REPLICAS_TIME).avg,
                frame.get(MetricsName.SERVICE_NODE_MSGS_TIME).avg,
                frame.get(MetricsName.SERVICE_CLIENT_MSGS_TIME).avg,
                frame.get(MetricsName.SERVICE_ACTIONS_TIME).avg,
                frame.get(MetricsName.SERVICE_LEDGER_MANAGER_TIME).avg,
                frame.get(MetricsName.SERVICE_VIEW_CHANGER_TIME).avg,
                frame.get(MetricsName.SERVICE_OBSERVABLE_TIME).avg,
                frame.get(MetricsName.SERVICE_OBSERVER_TIME).avg,
                frame.get(MetricsName.FLUSH_OUTBOXES_TIME).avg,
                frame.get(MetricsName.SERVICE_NODE_LIFECYCLE_TIME).avg,
                frame.get(MetricsName.SERVICE_CLIENT_STACK_TIME).avg
            ]
            # Any falsy value (None, 0, 0.0) is written as "0"
            f.write(",".join(str(v if v else "0") for v in row))
            f.write("\n")


def _print_summary(stats):
    """
    Print the covered time range and the totals of the collected metrics.

    :param stats: object returned by ``load_metrics_from_kv_store``
    """
    print("Start time: {}".format(stats.min_ts))
    print("End time: {}".format(stats.max_ts))
    print("Duration: {}".format(stats.max_ts - stats.min_ts))
    print("")

    total = stats.total

    print("Seconds passed between looper runs:")
    print("   {}".format(total.get(MetricsName.LOOPER_RUN_TIME_SPENT)))
    print("")

    print("Profiling info:")
    # "+ 1" makes the range inclusive, so SERVICE_CLIENT_STACK_TIME (which the
    # CSV export also reports) is no longer skipped.
    # NOTE(review): assumes the enum values between the two bounds are
    # contiguous, as the original loop already did - confirm against MetricsName.
    for i in range(MetricsName.NODE_PROD_TIME, MetricsName.SERVICE_CLIENT_STACK_TIME + 1):
        mn = MetricsName(i)
        print("   {} : {}".format(str(mn), total.get(mn)))
    print("")

    # (title, metric) pairs printed one after another in this order.
    # Titles are kept verbatim (including the "tranport" spelling) so the
    # output stays compatible with anything that parses it.
    sections = [
        ("Number of node messages processed in one looper run:",
         MetricsName.NODE_STACK_MESSAGES_PROCESSED),
        ("Number of client messages processed in one looper run:",
         MetricsName.CLIENT_STACK_MESSAGES_PROCESSED),
        ("Number of messages in one tranport batch:",
         MetricsName.TRANSPORT_BATCH_SIZE),
        ("Number of requests in one 3PC batch:",
         MetricsName.THREE_PC_BATCH_SIZE),
        ("Number of requests ordered:",
         MetricsName.ORDERED_BATCH_SIZE),
        ("Time spent on requests processing:",
         MetricsName.REQUEST_PROCESSING_TIME),
        ("Number of requests in one 3PC batch on master instance:",
         MetricsName.MASTER_3PC_BATCH_SIZE),
        ("Number of requests ordered on master instance:",
         MetricsName.MASTER_ORDERED_BATCH_SIZE),
        ("Time spent on requests processing on master instance:",
         MetricsName.MASTER_REQUEST_PROCESSING_TIME),
        ("Incoming node message size, bytes:",
         MetricsName.INCOMING_NODE_MESSAGE_SIZE),
        ("Outgoing node message size, bytes:",
         MetricsName.OUTGOING_NODE_MESSAGE_SIZE),
        ("Incoming client message size, bytes:",
         MetricsName.INCOMING_CLIENT_MESSAGE_SIZE),
        ("Outgoing client message size, bytes:",
         MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE),
    ]
    for title, metric in sections:
        print(title)
        print("   {}".format(total.get(metric)))
        print("")

    print("Additional statistics:")
    three_pc = total.get(MetricsName.MASTER_ORDERED_BATCH_SIZE)
    process = total.get(MetricsName.MASTER_REQUEST_PROCESSING_TIME)
    node_in = total.get(MetricsName.INCOMING_NODE_MESSAGE_SIZE)
    node_out = total.get(MetricsName.OUTGOING_NODE_MESSAGE_SIZE)
    client_in = total.get(MetricsName.INCOMING_CLIENT_MESSAGE_SIZE)
    client_out = total.get(MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE)
    node_messages = node_in.count + node_out.count
    node_traffic = node_in.sum + node_out.sum
    client_messages = client_in.count + client_out.count
    client_traffic = client_in.sum + client_out.sum

    # Every ratio goes through _format_ratio: a zero denominator (e.g. no
    # client traffic or no ordered batches) prints "n/a" instead of crashing.
    print("   Node incoming/outgoing: {} messages, {} traffic"
          .format(_format_ratio(node_in.count, node_out.count),
                  _format_ratio(node_in.sum, node_out.sum)))
    print("   Client incoming/outgoing: {} messages, {} traffic"
          .format(_format_ratio(client_in.count, client_out.count),
                  _format_ratio(client_in.sum, client_out.sum)))
    print("   Node/client: {} messages, {} traffic"
          .format(_format_ratio(node_messages, client_messages),
                  _format_ratio(node_traffic, client_traffic)))
    print("   Node messages per batch: {}".format(_format_ratio(node_messages, three_pc.count)))
    print("   Node traffic per batch: {}".format(_format_ratio(node_traffic, three_pc.count)))
    print("   Node messages per request: {}".format(_format_ratio(node_messages, three_pc.sum)))
    print("   Node traffic per request: {}".format(_format_ratio(node_traffic, three_pc.sum)))
    print("   Request processing time: {} ms".format(_format_ratio(1000.0 * process.sum, three_pc.sum)))


def process_storage(storage, args):
    """
    Load metrics from a key-value storage, optionally export them to CSV and
    print a summary to stdout.

    :param storage: key-value storage handed to ``load_metrics_from_kv_store``
    :param args: parsed command line options; the attributes read are
        ``min_ts`` / ``max_ts`` (optional strings in "%Y-%m-%d %H:%M:%S"
        format), ``step`` (frame length in seconds), ``output`` (CSV file
        name or None to skip the export) and ``ts_fmt`` (timestamp format
        used in the CSV)
    :raises ValueError: when ``min_ts`` or ``max_ts`` does not match the
        expected timestamp format
    """
    # Input timestamps always use this fixed format; args.ts_fmt only
    # controls how timestamps are written to the CSV
    fmt = "%Y-%m-%d %H:%M:%S"
    min_ts = datetime.strptime(args.min_ts, fmt) if args.min_ts else None
    max_ts = datetime.strptime(args.max_ts, fmt) if args.max_ts else None
    stats = load_metrics_from_kv_store(storage, min_ts, max_ts, timedelta(seconds=args.step))

    if args.output is not None:
        _write_metrics_csv(stats, args.output, args.ts_fmt)

    _print_summary(stats)


if __name__ == '__main__':
    cli_args = read_args()

    if cli_args.data_dir is not None:
        # An explicit path is always opened as RocksDB with an empty db name
        explicit_storage = initKeyValueStorage(KeyValueStorageType.Rocksdb, cli_args.data_dir, "")
        process_storage(explicit_storage, cli_args)
    else:
        node_dir = get_data_dir(cli_args.node_name, cli_args.network)
        working_on_copy = False
        try:
            cfg = getConfig()
            # Every storage type other than RocksDB is read from a
            # throw-away copy of the node's data directory
            if config_uses_copy := (cfg.METRICS_KV_STORAGE != KeyValueStorageType.Rocksdb):
                node_dir = make_copy_of_data(node_dir)
                working_on_copy = config_uses_copy

            metrics_storage = initKeyValueStorage(
                cfg.METRICS_KV_STORAGE,
                node_dir,
                cfg.METRICS_KV_DB_NAME)

            process_storage(metrics_storage, cli_args)
        finally:
            # Remove the temporary copy even when processing failed
            if working_on_copy:
                shutil.rmtree(node_dir)
