#!/usr/bin/env python
#Copyright (c) Aleksandra Jarmolinska 2018

"""Program for topologically conscious simplification of the 3D structure of a biological molecule.
Can also calculate the knot type using the Dowker notation."""
from __future__ import print_function
import argparse
import sys,os,ntpath


from knot_pull.reader import read_from_pdb, read_from_xyz, read_from_cif, guess_format
from knot_pull.puller import pull
from knot_pull.writer import write_xyz,print_out_last_frame
from knot_pull.unentangler import unentangle
import knot_pull.config as kpcfg

name = "knot_pull"

# Run-time switches of the command line front end; main() overwrites these
# defaults with the values taken from the parsed arguments.
CONFIG = {'CHECK_KNOT': 0, 'KEEP_ALL': 0, 'TRAJECTORY': 1, 'OUTPUT': '', "GREEDY": 0, 'QUIET': 0}


def _build_parser():
    """Create the argparse parser describing every command line option."""
    parser = argparse.ArgumentParser(description="Simplify the 3D structure of a molecule while preserving the topology")

    parser.add_argument("infile", help="Structure file which will be used (to download the structure from RCSB PDB "
                                       "give only the PDB ID)", nargs='?')
    # NOTE: store_true combined with a truthy default means -p and -t cannot be
    # switched off from the command line; kept unchanged for backward compatibility.
    parser.add_argument("-p", "--preserve_resnums", action="store_true", default=1,
                        help="Add pseudoatoms to keep consistent length throughout the smoothing trajectory (default 1)")
    parser.add_argument("-t", "--trajectory", action="store_true", default=1,
                        help="Write out all steps of the simplification [default is 1]")
    # 'cif' is accepted explicitly because main() already has a dedicated cif branch.
    parser.add_argument("-f", "--format", choices=['pdb', 'xyz', 'cif', 'guess'], default='guess',
                        help="Input file format: xyz, pdb or cif [default is guess]")
    parser.add_argument("-k", "--detect_knot", default=0, action="store_true",
                        help="Calculate the knot Dowker notation and assign the knot type")
    parser.add_argument("-c", "--chain", default='',
                        help="Specify which chain(s) you want to simplify. Multiple chains should be separated by commas.")
    parser.add_argument("-b", "--begin", default=None,
                        help="""Smooth and analyze only part of the structure: specify the first bead to be considered.
                         '-b 4' will start from the 4th bead, '-b i4' will start from the residue/coordinate with id 4.
                         When specified, only the first chain from selected/present will be used.""")
    parser.add_argument("-e", "--end", default=None,
                        help="""Smooth and analyze only part of the structure: specify the last bead to be considered.
                         '-e 104' will disregard beads from 105th forward, '-e i104' will disregard residue/coordinates
                         with id >=105. When specified, only the first chain from selected/present will be used.""")
    # nargs='?' gives three states: option absent -> '', bare "-o" -> None, "-o NAME" -> NAME.
    parser.add_argument("-o", "--output", default='', nargs='?',
                        help="""Specify the output file. Format is guessed based on extension [default pdb].
                        For stdout use "-". Empty argument makes an '_out.pdb' file.
                        If this option is missing just the last frame will be shown (sets trajectory to 0)""")
    parser.add_argument("-r", "--representative_frame", action="store_true", default=0,
                        help="Find and print to a separate file the first frame with the same number of crossings as the final one")
    parser.add_argument("-q", "--quiet", action="store_true", default=0,
                        help="Nothing except the knot type (if calculated) will be written to the screen")
    parser.add_argument("-a", "--atom_name", default=None,
                        help="Specify atom name to be used for selecting coordinates (default CA for proteins, P for RNA)")
    parser.add_argument("-x", "--write_xyz", default=False, action="store_true",
                        help="Write the output files in .xyz format")
    parser.add_argument("-R", "--rna", action="store_true", default=0,
                        help="""Used only if a file in .pdb format is provided. Will use P atoms for coordinates,
                        instead of C-alpha""")
    # Hidden options (suppressed from --help).
    # -S: bare "-S" writes a save file after smoothing; "-S FILE" runs knot detection from FILE.
    parser.add_argument("-S", "--save", default='', nargs='?', help=argparse.SUPPRESS)
    # -V: sets knot_pull.config.VERBOSE.
    parser.add_argument("-V", "--verbose", default=0, action="store_true", help=argparse.SUPPRESS)
    # -n: forwarded to the reader/puller configuration as "no_pulling".
    parser.add_argument("-n", "--no_pulling", default=0, action="store_true", help=argparse.SUPPRESS)
    return parser


def _check_bead_bound(value, label):
    """Exit with an error message unless ``value`` ('4' or 'i4' style) holds an integer.

    ``None`` and empty strings are accepted without any check.
    """
    if not value:
        return
    try:
        int(value.replace("i", ''))
    except ValueError:
        sys.exit("{} residue cannot be interpreted as an int: {}".format(label, value))


def _detect_format(infile, requested):
    """Return the input format: ``requested`` unless it is 'guess', else inferred from the file name."""
    if requested != "guess":
        return requested
    lowered = infile.lower()
    for ext in ("xyz", "pdb", "cif"):
        if "." + ext in lowered:
            return ext
    return guess_format(infile)


def _default_output_path(infile, outfmt):
    """Return '<input without extension>_out.<outfmt>'.

    Built with os.path.splitext so the result can never equal ``infile``
    (which would make opening it for writing truncate the input).
    """
    root, _ = os.path.splitext(infile)
    return "{}_out.{}".format(root, outfmt)


def _representative_path(output, infile, outfmt):
    """Return a file name for the representative frame that differs from the main output."""
    if output is sys.stdout:
        base = _default_output_path(infile, outfmt)
    else:
        base = output.name
    candidate = base.replace("_out.", "_out_repr.")
    if candidate == base:
        # Custom output name without "_out.": add the marker before the extension
        # instead of overwriting the trajectory file.
        root, ext = os.path.splitext(base)
        candidate = "{}_repr{}".format(root, ext)
    return candidate


def main():
    """Main body of the command line program.

    Parses the command line, reads the structure (local xyz/pdb/cif file, or a
    PDB ID when the argument is not an existing file), runs the smoothing via
    ``pull`` and optionally reports the knot type via ``unentangle``.

    Exits with status 1 (after printing the help) when no input is given or no
    atoms could be read, and with an error message when the -b/-e values are
    not integers.
    """
    parser = _build_parser()
    args = parser.parse_args()

    CONFIG['CHECK_KNOT'] = args.detect_knot
    CONFIG['KEEP_ALL'] = args.preserve_resnums
    CONFIG['TRAJECTORY'] = args.trajectory
    CONFIG['GREEDY'] = args.detect_knot
    CONFIG['QUIET'] = args.quiet

    if args.verbose:
        kpcfg.VERBOSE = True

    if args.save:
        # Run knot detection directly from a previously written save file.
        # NOTE(review): unlike the regular xyz path below, read_from_xyz is called
        # without a config and its result is not unpacked - confirm the signatures.
        atoms = read_from_xyz(filename=args.save)
        if not CONFIG['QUIET']: print ("Save start", len(atoms))
        chains = pull(atoms, '', greedy=0)
        unentangle(chains, outfile=0)
        sys.exit()

    infile = args.infile
    if not infile:
        print ("""No input file specified!!\n##############################################""")
        parser.print_help(sys.stderr)
        sys.exit(1)

    requested_format = args.format
    if not os.path.isfile(infile):
        # Not an existing file: treat the argument as a PDB ID to fetch online.
        _, fname = ntpath.split(infile)
        requested_format = 'online'
        infile = fname.split(".")[0]

    selected_chains = args.chain.split(",") if args.chain else []
    b_res = args.begin
    e_res = args.end
    if b_res is not None or e_res is not None:
        if len(selected_chains) != 1:
            print("""#WARNING: Begin/end residue was specified. Only one chain will be analyzed """)
        _check_bead_bound(b_res, "Begin")
        _check_bead_bound(e_res, "End")

    fformat = _detect_format(infile, requested_format)

    if not CONFIG['QUIET']: print ("#LOG: data format: {}".format(fformat))
    outfmt = "xyz" if args.write_xyz else "pdb"
    if args.atom_name is not None:
        atom_name = args.atom_name
    else:
        atom_name = "P" if args.rna else "CA"
    RW_CONFIG = {'keep_all': args.preserve_resnums, 'quiet': args.quiet, 'outfmt': outfmt, "atom_name": atom_name,
                 'begin': b_res, 'end': e_res, "trajectory": args.trajectory, "outfile": None, 'chains': selected_chains,
                 'default_atom_name': atom_name,
                 'rna': args.rna, "no_pulling": args.no_pulling}

    if fformat == 'online':
        atoms, cnames, err_mes = read_from_cif(name=infile, config=RW_CONFIG, online=True)
        # Later output names are derived from this pseudo file name.
        infile = "{}.pdb".format(infile)
    elif fformat == 'cif':
        atoms, cnames, err_mes = read_from_cif(name=infile, config=RW_CONFIG, online=False)
    elif fformat == 'pdb':
        atoms, cnames, err_mes = read_from_pdb(filename=infile, config=RW_CONFIG)
    else:
        if args.chain: print ("#WARNING: File format is .xyz yet chain was specified - it will not work, all chains will be used")
        atoms, cnames, err_mes = read_from_xyz(filename=infile, config=RW_CONFIG)

    if not atoms:
        print(err_mes)
        parser.print_help(sys.stderr)
        sys.exit(1)

    if not CONFIG['QUIET']: print ("#LOG: got coordinates: {} bead(s)".format(len(atoms)))

    RW_CONFIG['chains'] = cnames

    if args.output == "":
        # -o not given at all: only the last frame is shown on stdout.
        output = sys.stdout
        CONFIG['TRAJECTORY'] = False
        RW_CONFIG['trajectory'] = False
    elif args.output == "-":
        output = sys.stdout
    elif args.output is None:
        # Bare "-o": derive the output name from the input name.
        output = open(_default_output_path(infile, outfmt), "w")
    else:
        output = open(args.output, "w")

    RW_CONFIG["outfile"] = output

    try:
        chains, representative = pull(atoms, RW_CONFIG, cnames, get_representative=args.representative_frame,
                                      greedy=CONFIG["GREEDY"])
    finally:
        # Close only files opened here; stdout must stay usable for the prints below.
        if output is not sys.stdout:
            output.close()

    if args.representative_frame:
        repr_infile = _representative_path(output, infile, outfmt)
        with open(repr_infile, "w") as repr_out:
            RW_CONFIG['outfile'] = repr_out
            print_out_last_frame(RW_CONFIG, representative, 1, final=True)

    if not CONFIG['QUIET']:
        print ("#LOG: finished smoothing: {} chain(s)".format(len(chains)))

    if args.save is None:
        # Bare "-S": dump the smoothed chains next to the input; the SAVE_ prefix
        # goes on the file name, not on the directory part of the path.
        save_dir, save_base = os.path.split(infile)
        save_base = "SAVE_" + save_base.replace(".pdb", "_out.pdb").replace(".xyz", "_out.xyz")
        with open(os.path.join(save_dir, save_base), "w") as out:
            for c in chains:
                for alist in c.atom_lists:
                    write_xyz(out, alist)

    if atoms and CONFIG['CHECK_KNOT']:
        if not CONFIG['QUIET']: print ("Final topology: {}: ".format([len(c.atoms) for c in chains]), sep="")
        topo, dt_code = unentangle(chains)
        print (topo)
        if not CONFIG['QUIET']: print("#DT codes:", dt_code)





# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
