#!/usr/bin/env python

import json
import logging
import argparse

import torchmodels
from torchmodels import utils
from torchmodels import manager


# Command-line interface of the template generator.
parser = argparse.ArgumentParser()
# Optional package that is imported and registered before the module lookup.
parser.add_argument("--root-package", default=None)
# Name under which the target module is looked up.
parser.add_argument("--module-name", required=True)
# Path of the file the template is written to.
parser.add_argument("--save-path", required=True)
# Serialization format of the template file.
parser.add_argument(
    "--format",
    choices=["yaml", "json"],
    default="yaml",
)
# Switches on logging configuration at the lowest level when present.
parser.add_argument("--debug", action="store_true")


def save_template(args):
    """Write an argument template for a registered module to a file.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``root_package`` (optional package to import and register before
            the lookup), ``module_name`` (key looked up in the manager's
            module dictionary), ``save_path`` (output file path) and
            ``format`` (``"yaml"`` or ``"json"``).

    Raises:
        AssertionError: If no module is found under ``args.module_name``.
    """
    if args.root_package is not None:
        pkg = utils.import_module(args.root_package)
        manager.register_packages(pkg)
    clsmap = manager.get_module_dict()
    cls = clsmap.get(args.module_name)
    assert cls is not None, \
        f"module of name '{args.module_name}' not found."
    template = {
        "type": cls.name,
        "vargs": torchmodels.get_optarg_template(cls)
    }
    # Pick the serializer matching the requested output format.
    dump = utils.map_val(args.format, {
        "yaml": utils.dump_yaml,
        "json": json.dump
    }, "template format")
    with open(args.save_path, "w") as f:
        # Serialize into the open handle: json.dump needs a writable file
        # object, and passing the path string made the JSON format fail.
        # NOTE(review): assumes utils.dump_yaml also accepts (obj, stream),
        # mirroring json.dump -- confirm against torchmodels.utils.
        dump(template, f)


if __name__ == "__main__":
    # Unrecognized command-line arguments are accepted and discarded.
    args = parser.parse_known_args()[0]
    if args.debug:
        # NOTSET is numerically 0, the level the script has always passed.
        logging.basicConfig(level=logging.NOTSET)
    save_template(args)