#!/usr/bin/env python3
#
# Copyright (C) 2017 Chris Cummins.
#
# This file is part of cldrive.
#
# Cldrive is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# Cldrive is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
# License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cldrive.  If not, see <http://www.gnu.org/licenses/>.
#
import sys

from argparse import ArgumentParser
from io import TextIOWrapper

import numpy as np

from cldrive import *


def _build_parser():
    """Construct the argument parser for the cldrive command-line script."""
    parser = ArgumentParser(description="""\
Reads an OpenCL kernel from stdin, generates data for it, executes it on a \
suitable device, and prints the outputs. Copyright (C) 2017 Chris Cummins \
<chrisc.101@gmail.com>. <https://github.com/ChrisCummins/cldrive>""")

    # admin
    parser.add_argument(
        "--version", action="store_true",
        help="show version information and exit")
    parser.add_argument(
        "--clinfo", action="store_true",
        help="list available OpenCL devices and exit")

    # environment controls
    parser.add_argument(
        "-p", "--platform", metavar="<platform name>", default=None,
        help="use OpenCL platform with name, e.g. 'NVIDIA CUDA'")
    parser.add_argument(
        "-d", "--device", metavar="<device name>", default=None,
        help="use OpenCL device with name, e.g. 'GeForce GTX 1080'")
    parser.add_argument(
        "--devtype", metavar="<devtype>", default="all",
        help="use any OpenCL device of type {all,cpu,gpu} (default: all)")

    # input opts
    parser.add_argument(
        "-s", "--size", metavar="<size>", type=int, default=64,
        help="size of input arrays to generate (default: 64)")
    parser.add_argument(
        "-i", "--generator", metavar="<{rand,arange,zeros,ones}>",
        default="arange",
        help=("input generator to use, one of: {rand,arange,zeros,ones} "
              "(default: arange)"))
    parser.add_argument(
        "--scalar-val", metavar="<float>", type=float, default=None,
        help="values to assign to scalar inputs (default: <size> argument)")

    # runtime opts
    parser.add_argument(
        "-g", "--gsize", metavar="<global size>", default="64,1,1",
        help="comma separated NDRange for global size (default: 64,1,1)")
    parser.add_argument(
        "-l", "--lsize", metavar="<local size>", default="32,1,1",
        help=("comma separated NDRange for local (workgroup) size "
              "(default: 32,1,1)"))
    parser.add_argument(
        "-t", "--timeout", metavar="<timeout>", type=int, default=-1,
        # Trailing space matters: the two literals are concatenated.
        help=("error if execution has not completed after this many seconds "
              "(default: off)"))
    parser.add_argument(
        "--no-opts", action="store_true",
        help="disable OpenCL optimizations (on by default)")
    parser.add_argument(
        "--profiling", action="store_true",
        help="enable kernel and transfer profiling")
    parser.add_argument(
        "--debug", action="store_true",
        help="enable more verbose OpenCL compilation and execution")
    parser.add_argument(
        "-b", "--binary", action="store_true",
        help="Print outputs as a pickled binary numpy array")
    return parser


def main():
    """Command-line entry point.

    Parses the command-line options, then either prints version or device
    information and exits, or reads kernel source from stdin, generates
    inputs with make_data(), runs the kernel with drive(), and prints the
    result.

    With ``--binary`` the value returned by drive() is pickled and the raw
    bytes are written to stdout. Otherwise one ``name: value`` line is
    printed for each non-local kernel argument.

    Raises:
        SystemExit: after handling ``--version`` or ``--clinfo`` (exit
            status 0), or from argparse on invalid arguments or ``--help``.
    """
    args = _build_parser().parse_args()

    if args.version:
        import cldrive
        print(f"cldrive {cldrive.__version__} made with \033[1;31m♥\033[0;0m "
              "by Chris Cummins <chrisc.101@gmail.com>.")
        sys.exit(0)

    if args.clinfo:
        clinfo()
        sys.exit(0)

    # read kernel source
    src = sys.stdin.read()

    # parse inputs from strings
    gsize, lsize = NDRange.from_str(args.gsize), NDRange.from_str(args.lsize)
    data_generator = Generator.from_str(args.generator)

    env = make_env(devtype=args.devtype,
                   platform=args.platform,
                   device=args.device)

    inputs = make_data(src=src, size=args.size, data_generator=data_generator,
                       scalar_val=args.scalar_val)

    outputs = drive(env=env, src=src, inputs=inputs, gsize=gsize, lsize=lsize,
                    optimizations=not args.no_opts, profiling=args.profiling,
                    debug=args.debug, timeout=args.timeout)

    # print result
    if args.binary:
        # Imported here so this function does not depend on pickle being
        # re-exported by the star import from cldrive.
        import pickle

        payload = pickle.dumps(outputs)
        # Write the bytes to the underlying binary stream: going through a
        # text layer can translate newlines and corrupt the pickle. Flush
        # pending text first so the output order is preserved.
        sys.stdout.flush()
        sys.stdout.buffer.write(payload)
        sys.stdout.buffer.flush()
    else:
        # NumPy rejects a NaN threshold; sys.maxsize is the documented way
        # to request untruncated array output.
        np.set_printoptions(threshold=sys.maxsize)
        # Use a distinct name so the parsed command-line namespace is not
        # shadowed by the kernel argument list.
        kernel_args = [arg for arg in extract_args(src)
                       if arg.address_space != 'local']
        assert len(kernel_args) == len(outputs), (
            "number of non-local kernel arguments does not match outputs")
        for arr, arg in zip(outputs, kernel_args):
            print(f"{arg.name}: {arr}")

# Run the command-line driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
