#!python
from __future__ import print_function
import sys
from argparse import ArgumentParser, REMAINDER

from prody import parsePDB
import numpy as np
from scipy.spatial.distance import cdist

# Command-line interface: one reference structure followed by any number of
# model files whose RMSD against the reference is printed, one line per file.
parser = ArgumentParser(
    description="Compute the RMSD of each model PDB file against a "
                "reference PDB file.")
parser.add_argument("--rec", default=None,
                    help="Receptor file if using interface mode.")
parser.add_argument("--only-CA", action="store_true",
                    help="Only use C alpha atoms.")
parser.add_argument("--only-backbone", action="store_true",
                    help="Only use backbone atoms (--only-CA takes "
                         "precedence if both are given).")
parser.add_argument("--interface-only", action="store_true",
                    help="Only use interface atoms. Requires --rec.")
# The underscore spelling is kept for backward compatibility; the hyphenated
# alias matches the other options. argparse derives dest from the first long
# option, so the value is still read as ``args.interface_radius``.
parser.add_argument("--interface_radius", "--interface-radius", type=float,
                    default=10.0,
                    help="Radius around receptor to consider (same units as "
                         "the PDB coordinates).")
parser.add_argument("pdb_crys",
                    help="PDB file to compare to")
# REMAINDER gathers every argument after pdb_crys verbatim, including
# strings that look like options, so options must be given before pdb_crys.
parser.add_argument("pdb_files", nargs=REMAINDER, metavar="pdb_file",
                    help="PDB files to calculate RMSD for.")
args = parser.parse_args()

# Interface mode selects reference atoms by their distance to the receptor,
# so a receptor structure is mandatory there.
if args.interface_only and args.rec is None:
    print("--interface-only requires --rec", file=sys.stderr)
    sys.exit(1)

# Parse the reference structure up front; model files are parsed lazily, one
# at a time, as the RMSD loop consumes the generator built below.
crys = parsePDB(args.pdb_crys)
if crys is None:
    print("Error parsing pdb file {}".format(args.pdb_crys), file=sys.stderr)
    sys.exit(1)


def _select_atoms(atoms):
    """Apply the --only-CA / --only-backbone selection to *atoms*.

    --only-CA wins when both flags are given; with neither, *atoms* is
    returned unchanged. ProDy yields ``None`` for a selection that matches
    no atoms, so callers must check the result.
    """
    if args.only_CA:
        return atoms.calpha
    if args.only_backbone:
        return atoms.backbone
    return atoms


def _model_coords(paths):
    """Yield the selected coordinate array of each model file, in order.

    Exits with status 1 and a message on stderr when a file cannot be
    parsed or the selection leaves no atoms, instead of failing with an
    AttributeError on ``None``.
    """
    for path in paths:
        model = parsePDB(path)
        if model is None:
            print("Error parsing pdb file {}".format(path), file=sys.stderr)
            sys.exit(1)
        selected = _select_atoms(model)
        if selected is None:
            print("No atoms selected in pdb file {}".format(path),
                  file=sys.stderr)
            sys.exit(1)
        yield selected.getCoords()


crys = _select_atoms(crys)
if crys is None:
    print("No atoms selected in pdb file {}".format(args.pdb_crys),
          file=sys.stderr)
    sys.exit(1)

crys_coords = crys.getCoords()
pdb_coords = _model_coords(args.pdb_files)

if args.interface_only:
    # Keep only "interface" atoms: reference atoms strictly closer than
    # interface_radius to at least one receptor atom.
    rec = parsePDB(args.rec)
    if rec is None:
        print("Error parsing receptor file {}".format(args.rec),
              file=sys.stderr)
        sys.exit(1)
    rec_coords = rec.getCoords()

    # Compare squared distances with the squared radius (no sqrt needed).
    # cdist builds a full (receptor atoms x reference atoms) matrix.
    sq_radius = args.interface_radius * args.interface_radius
    dists = cdist(rec_coords, crys_coords, 'sqeuclidean')
    interface = np.flatnonzero(np.any(dists < sq_radius, axis=0))

    if interface.size == 0:
        # An empty interface would make every RMSD 0/0 (nan).
        print("No reference atoms within {} of the receptor"
              .format(args.interface_radius), file=sys.stderr)
        sys.exit(1)

    # The interface indices refer to the full reference. A model with a
    # different atom count cannot be indexed with them safely: a shorter one
    # raises IndexError, and a longer one silently picks the wrong atoms. Such
    # a model is therefore replaced by an empty array, which fails the
    # per-file atom-count check instead.
    n_ref = len(crys_coords)
    crys_coords = crys_coords[interface]
    pdb_coords = (c[interface] if len(c) == n_ref else c[:0]
                  for c in pdb_coords)

# Coordinates are compared positionally (atom i against atom i) with no
# superposition step, so each model must list the same atoms in the same
# order as the reference.
N = len(crys_coords)
for coords, f in zip(pdb_coords, args.pdb_files):
    if len(coords) != N:
        print("[ERROR] unequal number of atoms for file {}".format(f),
              file=sys.stderr)
        # Skip this model: subtracting arrays of different lengths raises a
        # broadcast error (or, for a 1-atom array, silently broadcasts).
        continue

    # RMSD = sqrt(sum of squared coordinate deviations / number of atoms).
    delta = crys_coords - coords
    np.square(delta, out=delta)
    rmsd = np.sqrt(np.sum(delta) / N)
    print("{0} {1:.4f}".format(f, rmsd))
