#!/usr/bin/env python
import os
import sys
import shutil
from glob import glob
import argparse

import numpy as np
import pandas as pd

from mdkit.utility import mol2

# Name prefixes used to recognize ligand, target and isomer directories.
# Path components are matched against them with str.startswith() when the
# "lig*/target*/isomer*" folder layout is inspected.
ligdir_prefix = 'lig'
tardir_prefix = 'target'
isodir_prefix = 'isomer'

# Command-line interface: which directories are analysed and how the best
# poses are selected among them.
parser = argparse.ArgumentParser(
    description="Extract best docking poses after rundbx finished.")

# prefixes substituted into the folder-layout note of the "combine" switches
_layout = (ligdir_prefix, tardir_prefix, isodir_prefix)

parser.add_argument(
    '-all-targets',
    action='store_true',
    default=False,
    dest='combine_targets',
    help=('Select best poses over all the targets. If not specified, '
          'extract best pose separately for each target. '
          'A "%s/%s/%s" architecture of the folders is assumed' % _layout))

parser.add_argument(
    '-all-isomers',
    action='store_true',
    default=False,
    dest='combine_isomers',
    help=('Select best poses over all the isomers. If not specified, '
          'extract best pose separately for every isomer. '
          'A "%s/%s/%s" architecture of the folders is assumed' % _layout))

parser.add_argument(
    '-csv',
    dest='csvfile',
    type=str,
    metavar='FILE',
    help='.csv filename with compounds. Used to add names of compounds. Default: none')

parser.add_argument(
    '-cutoff',
    dest='cutoff',
    type=float,
    default=2.0,
    metavar='RMSD_VALUE',
    help='RMSD cutoff used for consensus docking or score-based consensus docking. Default: 2.0 A')

parser.add_argument(
    '-d',
    dest='docking_programs',
    nargs='+',
    metavar=('PRGM1', 'PRGM2'),
    help='Docking programs (instances) to be considered when extracting best poses')

parser.add_argument(
    '-dirs',
    dest='dirs',
    nargs='+',
    default=['.'],
    metavar=('DIR1', 'DIR2'),
    help='Directories considered for analysis. Should contain a folder called "poses". Default: curr. dir')

parser.add_argument(
    '-r',
    dest='resultsdir',
    required=False,
    default='results',
    metavar='DIRECTORY NAME',
    help='Name of results directory. Default: results')

# the three extraction methods cannot be combined with each other
group = parser.add_mutually_exclusive_group(required=False)

group.add_argument(
    '-s',
    dest='sf',
    nargs='+',
    metavar='FUNC',
    help='Scoring functions used to extract the best pose (combination of scores)')

group.add_argument(
    '-cd',
    dest='cd',
    nargs='+',
    metavar='PRGM',
    help='Docking programs used for standard consensus docking')

group.add_argument(
    '-sbcd',
    dest='sbcd',
    nargs='+',
    metavar='FUNC',
    help='Scoring functions used for score-based consensus docking')

# parse the command line
args = parser.parse_args()

def compute_rmsd(file1, file2):
    """Compute the RMSD between two poses without any fitting.

    The coordinates are compared as they are: since the protein structures
    are the same, no alignment on the protein is performed prior to the RMSD.
    Hydrogen atoms are left out (keep_h=False).

    Parameters
    ----------
    file1, file2 : paths handed to mol2.get_coordinates (one pose each)

    Returns
    -------
    The root of the summed squared coordinate differences divided by the
    number of atoms of the first pose.

    Raises
    ------
    ValueError
        If the two poses do not yield coordinate arrays of the same shape,
        or if they contain no atom.
    """
    # load coordinates of both poses (non-hydrogen atoms)
    coords1 = np.array(mol2.get_coordinates(file1, keep_h=False))
    coords2 = np.array(mol2.get_coordinates(file2, keep_h=False))

    # without this check numpy may silently broadcast arrays of different
    # sizes (e.g. 1 atom against N atoms) and return a meaningless value
    if coords1.shape != coords2.shape:
        raise ValueError("Cannot compute RMSD between %s and %s: coordinate arrays differ in shape (%s vs %s)"
            %(file1, file2, coords1.shape, coords2.shape))

    natoms = coords1.shape[0]
    # an empty pose would give 0/0 = NaN, which never exceeds any cutoff
    if natoms == 0:
        raise ValueError("Cannot compute RMSD between %s and %s: no atom coordinates found"%(file1, file2))

    rmsd = np.sqrt(np.sum((coords1-coords2)**2)/natoms)
    return rmsd

def check_architecture(directory):
    """Detect which levels of the lig*/target*/isomer* layout a path follows.

    The path is split on '/' and its trailing components are compared with
    the module-level prefixes (ligdir_prefix, tardir_prefix, isodir_prefix),
    starting from the last component:

    * last component is an isomer folder: its parent may be a target folder
      (whose own parent may be a ligand folder) or directly a ligand folder;
    * last component is a target folder: its parent may be a ligand folder;
    * last component is a ligand folder.

    Parameters
    ----------
    directory : path to inspect

    Returns
    -------
    (isligID, istargetID, isisomerID) : tuple of bool
        All three are False when `directory` is not an existing directory or
        when its last component matches none of the prefixes.
    """
    isligID = istargetID = isisomerID = False

    # bind the flags before any early exit so that a path which is not a
    # directory yields a well-defined result instead of an UnboundLocalError
    if not os.path.isdir(directory):
        return isligID, istargetID, isisomerID

    dir_split = directory.split('/')

    def has_prefix(position, prefix):
        # True when the component `position` steps from the end exists and
        # starts with `prefix`
        return len(dir_split) >= position and dir_split[-position].startswith(prefix)

    if has_prefix(1, isodir_prefix):
        isisomerID = True
        if has_prefix(2, tardir_prefix):
            istargetID = True
            isligID = has_prefix(3, ligdir_prefix)
        else:
            # no target level: the ligand folder, if any, is the direct parent
            isligID = has_prefix(2, ligdir_prefix)
    elif has_prefix(1, tardir_prefix):
        istargetID = True
        isligID = has_prefix(2, ligdir_prefix)
    else:
        isligID = has_prefix(1, ligdir_prefix)

    return isligID, istargetID, isisomerID

def get_IDs(directory, isligID, istargetID, isisomerID):
    """Read the ligand, target and isomer IDs from the end of a path.

    The path is split on '/' and walked backwards from its last component.
    Each level flagged as present consumes one component, in the order
    isomer, target, ligand; levels that are not flagged get None.

    Parameters
    ----------
    directory : path whose trailing components hold the IDs
    isligID, istargetID, isisomerID : flags telling which levels are present

    Returns
    -------
    (ligID, targetID, isomerID) : each one a path component or None
    """
    components = directory.split('/')
    depth = 0  # number of trailing components consumed so far

    isomerID = None
    if isisomerID:
        depth += 1
        isomerID = components[-depth]

    targetID = None
    if istargetID:
        depth += 1
        targetID = components[-depth]

    ligID = None
    if isligID:
        depth += 1
        ligID = components[-depth]

    return ligID, targetID, isomerID

def check_directories(dirs):
    """Check that all directories follow the same lig*/target*/isomer* layout.

    Parameters
    ----------
    dirs : list of directory paths

    Returns
    -------
    (iscwd, isligID, istargetID, isisomerID) : tuple of bool
        iscwd is True only when `dirs` is exactly ['.'], in which case the
        three layout flags are False. Otherwise the flags are those reported
        by check_architecture for the directories (all False for an empty
        list).

    Raises
    ------
    ValueError
        If two directories do not report the same layout flags.
    """
    if dirs == ['.']:
        return True, False, False, False

    reference = None  # flags of the first directory, compared with all others
    for directory in dirs:
        flags = check_architecture(directory)
        if reference is None:
            reference = flags
        elif flags != reference:
            raise ValueError("%s*/%s*/%s* architecture inconsistent between folders!"%(ligdir_prefix,tardir_prefix,isodir_prefix))

    # an empty list leaves nothing to inspect: report no layout at all
    if reference is None:
        reference = (False, False, False)

    isligID, istargetID, isisomerID = reference
    return False, isligID, istargetID, isisomerID

# Every requested directory must hold a "poses" folder; accepted directories
# are stored as paths relative to the current working directory.
dirs = []
for candidate in args.dirs:
    if not os.path.isdir(candidate+'/poses'):
        raise ValueError('directory '+candidate+'/poses does not exist!')
    dirs.append(os.path.relpath(candidate))

# layout flags shared by all the directories
iscwd, isligID, istargetID, isisomerID = check_directories(dirs)

# Validate the options relative to best poses extraction. resolve_with is the
# score column used to decide which pose to keep when selecting over all
# targets and isomers.
scoring_functions_all = []
if args.sbcd:
    if len(args.sbcd) < 2:
        raise ValueError('Number of functions for score-based consensus docking should be at least 2!')
    scoring_functions = args.sbcd
    resolve_with = scoring_functions[0]
elif args.cd:
    if len(args.cd) < 2:
        raise ValueError('Number of programs for consensus docking should be at least 2!')
    scoring_functions, resolve_with = None, 'score'
elif args.sf:
    scoring_functions, resolve_with = args.sf, 'score_multi'

# columns describing a single pose
features = ['file_l', 'file_r', 'site', 'program', 'instance', 'index_pose']

# the compounds file, when given, has to exist
if args.csvfile and not os.path.isfile(args.csvfile):
    raise IOError("csvfile %s not found!"%args.csvfile)

# ID columns, listed in ligand/target/isomer order and restricted to the
# levels found in the directory layout
features_ids = []
for _column, _present in (('ligID', isligID), ('targetID', istargetID), ('isomerID', isisomerID)):
    if _present:
        features_ids.append(_column)

# Build, for every directory, a table with one row per pose (appended to
# `poses`) and select the pose(s) retained by the requested extraction method
# (appended to `best_poses`).
poses = []
best_poses = []
for jdx, dirpath in enumerate(dirs):
    posedir = dirpath + '/poses'
    ligID, targetID, isomerID = get_IDs(dirpath, isligID, istargetID, isisomerID)

    # one list per column of the per-directory table
    info_dir = {}
    for ft in features_ids + features + ['score']:
        info_dir[ft] = []

    # get location of poses and receptor files
    with open(posedir+'/info.dat', 'r') as inff:
        # skip the first two lines; the next() builtin is used because the
        # file.next() method only exists on Python 2
        next(inff)
        next(inff)

        for line in inff:
            # each record reads: program,nposes,firstidx,site
            program, nposes, firstidx, site = line.strip().split(',')
            firstidx = int(firstidx)
            nposes = int(nposes)
            instance = program
            if site:
                instance += '.' + site
            poses_idxs = range(firstidx, firstidx+nposes)

            nposes = len(poses_idxs)
            for index, idx in enumerate(poses_idxs):
                file_l = posedir + '/lig-%s.mol2'%idx
                if os.path.isfile(file_l):
                    info_dir['file_l'].append(os.path.relpath(file_l))
                else:
                    raise IOError("File %s does not exist!"%file_l)
                info_dir['file_r'].append(os.path.relpath(posedir+'/rec.pdb'))
                info_dir['site'].append(site)
                info_dir['program'].append(program)
                info_dir['instance'].append(instance)
                # rank of the pose within its instance (0-based)
                info_dir['index_pose'].append(index)
                if isligID:
                    info_dir['ligID'].append(ligID)
                if istargetID:
                    info_dir['targetID'].append(targetID)
                if isisomerID:
                    info_dir['isomerID'].append(isomerID)

            nscores = 0
            # extract original scores (one per line of score.out)
            with open(dirpath+'/'+instance+'/score.out', 'r') as sout:
                for line_s in sout:
                    nscores += 1
                    info_dir['score'].append(float(line_s.strip()))
                if nscores != nposes:
                    raise ValueError("Number of poses different from number of scores (%s/%s)"%(dirpath,instance))

    # extract all scores; the first directory defines the reference list of
    # rescoring functions
    for score_file in sorted(glob(dirpath+'/rescoring/*.score')):
        sf = os.path.basename(score_file).split('.')[0]
        if jdx == 0:
            scoring_functions_all.append(sf)
        elif sf not in scoring_functions_all:
            raise ValueError("%s scores not computed in every directory!"%sf)
        info_dir[sf] = []
        with open(score_file, 'r') as sout:
            for line_s in sout:
                info_dir[sf].append(float(line_s))

    df_dir = pd.DataFrame(info_dir)
    if args.docking_programs:
        # explicit copy: a column may be added to this table below
        df_dir = df_dir[df_dir['instance'].isin(args.docking_programs)].copy()
    # the very same object is stored, so the 'score_multi' column added below
    # is also visible in `poses`
    poses.append(df_dir)

    # stays None when no pose is retained for this directory (no consensus
    # reached, or no extraction method requested)
    best_poses_dir = None

    # TODO: enable to use combination of scores or consensus docking over multiple targets and isomers
    if args.sf:
        # extract best pose from best linear combination of scores
        df_dir['score_multi'] = df_dir[scoring_functions].sum(axis=1)
        # idxmin has to be called: .loc expects the label of the best row
        best_poses_dir = df_dir.loc[df_dir['score_multi'].idxmin()].to_frame().T

        # for some reason, converting to DataFrame make everything become an object, so I need to convert everything back
        best_poses_dir = best_poses_dir.astype({ft:str for ft in features})
        best_poses_dir = best_poses_dir.astype({sf:float for sf in ['score','score_multi']+scoring_functions_all})

    elif args.sbcd:
        consensus = True
        file_l_ref = None
        # extract best pose if consensus between poses with best score after rescoring is reached
        for idx, sf in enumerate(scoring_functions):
            row = df_dir.loc[df_dir[sf].idxmin()]
            if idx == 0:
                # pose picked by the first function serves as reference
                firstrow = row.copy()
                file_l_ref = row['file_l']
            else:
                rmsd = compute_rmsd(row['file_l'], file_l_ref)
                if rmsd > args.cutoff:
                    consensus = False
                    break
        if consensus:
            best_poses_dir = pd.DataFrame(firstrow).T
            # for some reason, converting to DataFrame make everything become an object, so I need to convert everything back
            best_poses_dir = best_poses_dir.astype({ft:str for ft in features})
            best_poses_dir = best_poses_dir.astype({sf:float for sf in ['score']+scoring_functions_all})

    elif args.cd:
        # extract best pose if consensus between poses with best docking score is met
        df_dir_prgms = df_dir[df_dir['program'].isin(args.cd)]
        df_dir_prgms_groupby = df_dir_prgms.groupby(['program'])
        # best-scored pose of every program
        rows = df_dir_prgms.loc[df_dir_prgms_groupby['score'].idxmin()]

        consensus = True
        file_l_ref = None
        for idx, (index, row) in enumerate(rows.iterrows()):
            if idx == 0:
                file_l_ref = row['file_l']
            else:
                rmsd = compute_rmsd(row['file_l'], file_l_ref)
                if rmsd > args.cutoff:
                    consensus = False
                    break
        if consensus:
            # the pose of the first program given on the command line is kept
            best_poses_dir = rows[rows['program']==args.cd[0]]

    if best_poses_dir is not None:
        best_poses.append(best_poses_dir)

def add_names(csvfile, df):
    """Attach compound names to a table of poses.

    The compounds are read from `csvfile`; when that file has an 'isomer'
    column, only its rows with isomer == 1 are used. The 'ligID' and 'name'
    columns are then merged into `df` on 'ligID'.

    Parameters
    ----------
    csvfile : path of the .csv file with the compounds
    df : DataFrame with a 'ligID' column

    Returns
    -------
    The merged DataFrame, i.e. `df` with an additional 'name' column.
    """
    compounds = pd.read_csv(csvfile)
    if 'isomer' in compounds:
        first_isomer = compounds['isomer'] == 1
        compounds = compounds[first_isomer]
    name_table = compounds[['ligID', 'name']]
    return df.merge(name_table, on='ligID')

def _keep_best_rows(df, groupby_columns, column):
    """Return the rows of df with the lowest `column` value.

    One row is kept per group of `groupby_columns`; when that list is empty
    a single row is kept overall. A DataFrame is returned in both cases.
    """
    if groupby_columns:
        return df.loc[df.groupby(groupby_columns)[column].idxmin()]
    # list indexer: keeps a one-row DataFrame instead of collapsing to a Series
    return df.loc[[df[column].idxmin()]]

# merge the per-directory tables of all poses
if poses:
    poses = pd.concat(poses).reset_index()
    if args.csvfile and isligID:
        poses = add_names(args.csvfile, poses)
else:
    sys.exit("No poses to extract!")

# merge the per-directory best poses; best_poses ends up as a DataFrame, or
# as None when no pose was retained
if best_poses:
    # reset_index gives unique row labels, needed for label-based selection
    best_poses = pd.concat(best_poses).reset_index()
    if args.csvfile and isligID:
        best_poses = add_names(args.csvfile, best_poses)
    # combine results over all isomers.
    if args.combine_isomers:
        if isisomerID:
            groupby_columns = []
            if isligID:
                groupby_columns += ['ligID']
            if istargetID:
                groupby_columns += ['targetID']
            best_poses = _keep_best_rows(best_poses, groupby_columns, resolve_with)
    # combine results over all targets.
    if args.combine_targets:
        if istargetID:
            groupby_columns = []
            if isligID:
                groupby_columns += ['ligID']
            # isomers stay separate unless they were merged just above
            if isisomerID and not args.combine_isomers:
                groupby_columns += ['isomerID']
            best_poses = _keep_best_rows(best_poses, groupby_columns, resolve_with)
else:
    best_poses = None

# start from a fresh results directory (any previous content is removed)
shutil.rmtree(args.resultsdir, ignore_errors=True)
os.mkdir(args.resultsdir)

# columns written to the .csv files
features_csv = list(features_ids)
# the 'name' column is only merged into the tables when a compounds file is
# given AND ligand IDs are available, so it is only requested in that case
if args.csvfile and isligID:
    features_csv.append('name')
features_csv += features + scoring_functions_all + ['score']

if args.sf:
    features_csv.append('score_multi')

features_csv.remove('instance')
features_csv.remove('index_pose')

# save poses to .csv file
poses[features_csv].to_csv(args.resultsdir+'/poses.csv', index=False, float_format='%.3f')

# levels merged by the "combine" options no longer get their own subfolder
if args.combine_isomers and isisomerID:
    features_ids.remove('isomerID')
if args.combine_targets and istargetID:
    features_ids.remove('targetID')

if best_poses is not None:
    best_poses[features_csv].to_csv(args.resultsdir+'/best_poses.csv', index=False, float_format='%.3f')

    for idx, row in best_poses.iterrows():
        # one subfolder per remaining ID level, e.g. results/<ligID>/<targetID>
        newdir = args.resultsdir + '/' + '/'.join(row[features_ids])
        if not os.path.isdir(newdir):
            os.makedirs(newdir)
        file_l = row['file_l']
        file_r = row['file_r']
        instance = row['instance']
        index = row['index_pose']

        shutil.copyfile(file_l, newdir+'/ligand.mol2')

        # file_l reads <origindir>/poses/lig-*.mol2. os.path is used so that
        # an empty origindir (poses taken from the current directory) gives a
        # relative path rather than one anchored at the filesystem root
        origindir = os.path.dirname(os.path.dirname(file_l))
        instancedir = os.path.join(origindir, instance)
        origin_subdir = os.path.join(instancedir, 'origin')

        # the pose number is only needed when an "origin" folder exists
        if os.path.isdir(origin_subdir):
            poses_idxs = []
            for filename in glob(os.path.join(instancedir, 'lig-*.mol2')):
                poses_idxs.append(int((filename.split('.')[-2]).split('-')[-1]))
            poses_idxs = sorted(poses_idxs)
            # index_pose is the rank of the pose among the sorted pose numbers
            pose_idx = poses_idxs[int(index)]
            shutil.copyfile(os.path.join(origin_subdir, 'lig-%i.mol2'%pose_idx), newdir+'/ligand_orig.mol2')
        shutil.copyfile(file_r, newdir+'/protein.pdb')
else:
    sys.exit("No poses found for the selected method of extraction!")
