#!/usr/bin/env python3
"""
 CIJOE Metric Plotter
"""
import argparse
import hashlib
import pprint
import copy
import sys
import os
import matplotlib.ticker as mticker
import matplotlib.pyplot as plt
import jinja2
import yaml
import cij.analyser
import cij.util
import cij

# Hatch patterns available to plotters
HATCHES = ["o", "O", ".", "\\", "//", "-", "+", "x", "\\", "*"]

# Markers handed out, in order, to the data-sets of a line-plot
MARKERS = ["+", "o", "s", "x", "h", "d", "*", "<", ">", "1", "*", "p"]

# Metrics which can be selected for plotting, the first one is the default
METRICS = ["iops", "bwps", "lat"]

# Unit used when labelling / scaling the y-axis of a given metric
METRIC_DEFAULT_UNIT = {
    "iops": "k",
    "bwps": "MiB",
    "lat": "usec",
}

# Human-readable descriptions of metric- and context-keys, used for titles
# and axis-labels
KEY_DESCR = {
    # metrics
    "iops": "I/O operations Per Second",
    "bwps": "BandWidth Per Second",
    "lat": "Latency",
    # context attributes
    "bs": "Block-Size",
    "iodepth": "I/O-Depth",
}


def labelizer(ctx, fmt):
    """
    Render the jinja2 template-string 'fmt', exposing the entries of 'ctx' as
    template variables, and return the resulting label

    The environment is created with autoescape enabled, thus substituted values
    are HTML-escaped by jinja2.
    """

    env = jinja2.Environment(loader=jinja2.BaseLoader(), autoescape=True)
    template = env.from_string(fmt)

    return template.render(**ctx)


def collect_metrics(args):
    """
    Collect all metrics found in 'metrics.yml' of testcase _aux folders

    The directory tree rooted at 'args.output' is walked; for every directory
    whose path ends with ".sh" and which holds the file "_aux/metrics.yml", the
    content of that file is loaded and its entries are added to the result.

    A 'metrics.yml' without content is skipped instead of aborting the
    collection.

    Returns a list with the metrics of all visited testcases, empty when none
    are found
    """

    metrics = []
    for root, _, _ in os.walk(args.output):
        if not root.endswith(".sh"):        # only testcase directories
            continue

        yml_fpath = os.path.join(root, "_aux", "metrics.yml")
        if not os.path.exists(yml_fpath):
            continue

        with open(yml_fpath, 'r', encoding="UTF-8") as yml_file:
            loaded = yaml.safe_load(yml_file)

        # safe_load() yields None for an empty document, and extending the
        # list with None would raise TypeError
        if loaded is None:
            continue

        metrics.extend(loaded)

    return metrics


def metrics_to_pdata(args, metrics):
    """
    Transform the given metrics into plot-data

    Each entry of 'metrics' is a dict holding the measured value under the key
    'args.metric' and, optionally, the context of the measurement under "ctx".
    The ctx-attribute named by 'args.ctx' is converted to int and used as
    x-value, the measured value as y-value. Metrics whose remaining context is
    identical, after dropping "fname" and "timestamp", form one data-set.

    Returns a dict mapping a data-set identifier to a dict with the keys
    "xvals", "yvals", "ctx" and "label", where the x/y values are sorted by
    x-value. The label is the identifier, unless 'args.label_fmt' is given, in
    which case it is rendered by labelizer() from the data-set context.

    Raises KeyError when a metric lacks 'args.metric' or its context lacks
    'args.ctx', and ValueError/TypeError when the x-value is not convertible
    to int.
    """

    pdata = {}
    for metric in metrics:                  # Transform metrics to plot-data
        # Work on a copy, the caller's metrics must not be modified
        ctx = copy.deepcopy(metric.get("ctx", {}))

        yval = metric[args.metric]          # metric-key to use as y-value

        xval = int(ctx.pop(args.ctx))       # ctx-key to extract as x-value

        # Consider how to specify requires / queries on ctx like this
        # if ctx["bs"] == "4k":               # ctx key/value requirements
        #     continue

        # ctx-keys to ignore / drop; they are not required to be present, so
        # pop with a default instead of 'del', which raises KeyError
        for key in ("fname", "timestamp"):
            ctx.pop(key, None)

        ident = hashlib.sha224(             # Produce identifier
            pprint.pformat(ctx).encode("utf-8")
        ).hexdigest()[:8]

        dset = pdata.get(ident)
        if dset is None:
            dset = pdata[ident] = {
                "yvals": [], "xvals": [], "ctx": ctx, "label": ident
            }
            if args.label_fmt:  # overwrite default-label with user-template
                dset["label"] = labelizer(ctx, args.label_fmt)

        dset["yvals"].append(yval)
        dset["xvals"].append(xval)

    for dset in pdata.values():             # Order points by x-value
        points = sorted(zip(dset["xvals"], dset["yvals"]))
        dset["xvals"] = [xval for xval, _ in points]
        dset["yvals"] = [yval for _, yval in points]

    return pdata


def plot_line(args, pdata):
    """
    Produce a line-plot of the given pdata

    One line is drawn per data-set in 'pdata', using its "xvals", "yvals" and
    "label". Title and axis-labels are taken from 'args.title',
    'args.x_label' and 'args.y_label' when given, otherwise they are generated
    from 'args.metric' and 'args.ctx'. The axes are scaled according to
    'args.x_scale' and 'args.y_scale', and y-ticks are shown in the default
    unit of the metric.

    Returns the matplotlib.pyplot module, holding the produced figure
    """

    for idx, dset in enumerate(pdata.values()):
        plt.plot(
            dset["xvals"],
            dset["yvals"],
            label=dset["label"],
            # Wrap around when there are more data-sets than markers; pairing
            # the two with zip() would silently leave data-sets out of the
            # plot
            marker=MARKERS[idx % len(MARKERS)]
        )

    title = args.title
    if title is None:
        title = f"{KEY_DESCR[args.metric]} as a function of " \
                f"{KEY_DESCR[args.ctx]}"
    plt.title(title)

    plt.xscale(args.x_scale)
    plt.yscale(args.y_scale)

    plt.xlabel(KEY_DESCR[args.ctx] if args.x_label is None else args.x_label)
    plt.ylabel(
        METRIC_DEFAULT_UNIT[args.metric] if args.y_label is None else
        args.y_label
    )

    plt.legend()

    def fmt_tick(value, _pos):
        """Express the tick 'value' in the default unit of the metric"""

        unit = cij.analyser.UNITS[METRIC_DEFAULT_UNIT[args.metric]]

        return f"{value / unit:.1f}"

    # Installed after the scale is set, as in the original ordering
    plt.gca().yaxis.set_major_formatter(mticker.FuncFormatter(fmt_tick))

    return plt


# Available plotters; maps the value of '--plot' to the function producing it
PLTRS = {"line": plot_line}


def parse_args(argv=None):
    """
    Parse command-line arguments for CIJOE Metric Plotter

    'argv' is the list of argument strings to parse; when None, the default,
    argparse uses the arguments of the running process (sys.argv[1:]).

    Returns the parsed arguments with 'output' expanded by
    cij.util.expand_path(), or None when that path does not exist. As usual for
    argparse, invalid arguments and '--help' terminate via SystemExit.
    """

    plot_choices = sorted(PLTRS.keys())
    scale_choices = ["linear", "log", "symlog", "logit"]

    # Parse the Command-Line
    prsr = argparse.ArgumentParser(
        description='cij_plotter - CIJOE Metric Plotter',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    prsr.add_argument(
        '--plot',
        help="The plot to produce",
        choices=plot_choices,
        default=plot_choices[0]
    )

    prsr.add_argument(
        '--output',
        help="Path to test output",
        default=os.getcwd()
    )

    prsr.add_argument(
        "--x-scale",
        help="Scale of the x-axis",
        choices=list(scale_choices),
        default="linear"
    )
    prsr.add_argument(
        "--x-label",
        help="When None is given, a label is generated using plot-data",
    )

    prsr.add_argument(
        "--y-scale",
        help="Scale of the y-axis",
        choices=list(scale_choices),
        default="linear"
    )
    prsr.add_argument(
        "--y-label",
        help="When none is given, a label is generated using plot-data",
    )

    prsr.add_argument(
        "--title",
        help="When None is given, a title is generated using plot-data"
    )

    # --show / --save are 0/1 switches rather than flags, so that the default
    # of '--save' (enabled) can be turned off from the command-line
    prsr.add_argument(
        "--show",
        type=int,
        help="Show the plot produced by the plotter",
        choices=[0, 1],
        default=0
    )
    prsr.add_argument(
        "--save",
        help="Store the plot produced by the plotter",
        type=int,
        choices=[0, 1],
        default=1
    )
    prsr.add_argument(
        "--save-fmt",
        help="The format to store the plot in",
        nargs="*",
        choices=["png", "pdf"],
        default=["png"]
    )

    prsr.add_argument(
        "--label-fmt",
        help="Provide a label format-string using jinja2 template syntax"
    )

    prsr.add_argument(
        "--ctx",
        help="The context attribute to use as 'metric-as-a-function-of'",
        choices=["iodepth", "bs"],
        default="iodepth"
    )
    prsr.add_argument(
        "--metric",
        help="The metric to plot",
        choices=METRICS,
        default=METRICS[0]
    )

    args = prsr.parse_args(argv)

    args.output = cij.util.expand_path(args.output)
    if not os.path.exists(args.output):
        cij.err(f"plotter:output: {args.output}, does not exist")
        return None

    return args


def main():
    """
    Entry point for Plotter: parse the command-line, collect the metrics,
    transform them into plot-data and produce the requested plot, which is
    then stored and/or shown as requested

    Returns 0 on success and 1 when parsing the command-line failed
    """

    args = parse_args()
    if args is None:
        cij.err("plotter: failed parsing command-line args")
        return 1

    pdata = metrics_to_pdata(args, collect_metrics(args))

    plotter = PLTRS[args.plot]
    plot = plotter(args, pdata)

    # Plot-data and figures share this name, only the extension differs
    plot_name = f"{args.plot}-{args.metric}"

    if args.save:
        yml_fpath = os.path.join(args.output, f"{plot_name}.yml")
        with open(yml_fpath, 'w', encoding="UTF-8") as yml_file:
            yml_file.write(yaml.dump(
                pdata, explicit_start=True, default_flow_style=False
            ))

        for ext in args.save_fmt:
            plot.savefig(os.path.join(args.output, f"{plot_name}.{ext}"))

    if args.show:
        plt.show()

    return 0


if __name__ == "__main__":
    # Hand the status returned by main() to the interpreter as exit-code
    raise SystemExit(main())
