#!python

from hidt.utils import *
from hidt.model import EdgeGNN
from hidt.loss import *
from hidt.evaluation import *
import pandas as pd
import numpy as np
from hidt.load_hic_format import *
import os, argparse, sys
import pkg_resources

def load_TAD(TADfile, res):
    """
    Load TAD intervals and convert their coordinates to bin indices.

    The file is parsed as whitespace-delimited text. Only the first three
    columns are kept and interpreted as chrom, x1, x2. Because no explicit
    header option is passed to pandas, the first line of the file is consumed
    as a header row. Every literal 'chr' substring is removed from the
    chromosome names, and x1/x2 are floor-divided by ``res``.

    Parameters
    ----------
    TADfile : str
        Path of the TAD file.
    res : int
        Bin size (same unit as the x1/x2 coordinates).

    Returns
    -------
    tuple
        ``(TADInfo, chrom_list)`` where ``TADInfo`` maps each chromosome name
        to a DataFrame with columns 'x1' and 'x2' (index reset), and
        ``chrom_list`` holds the chromosome names in ``groupby`` order.

    Notes
    -----
    Any error while reading or parsing (including a file with fewer than
    three columns) is printed and terminates the process with exit status 1.
    """
    try:
        # sep=r'\s+' is the documented replacement for the deprecated
        # delim_whitespace=True option.
        df = pd.read_csv(TADfile, sep=r'\s+')
        if df.shape[1] < 3:
            raise ValueError(f"TAD file '{TADfile}' has fewer than 3 columns.")
        # Work on an explicit copy so the column assignments below do not
        # write into a slice of the parsed frame.
        df = df.iloc[:, :3].copy()
        df.columns = ['chrom', 'x1', 'x2']
        df['chrom'] = df['chrom'].astype(str).str.replace('chr', '', regex=False)
        df['x1'] = df['x1'] // res
        df['x2'] = df['x2'] // res
        TADInfo = {
            chrom: group[['x1', 'x2']].reset_index(drop=True)
            for chrom, group in df.groupby('chrom')
        }
        # dicts preserve insertion order, so this matches the groupby order.
        chrom_list = list(TADInfo)
        return TADInfo, chrom_list
    except Exception as e:
        print(f"Error loading TAD file: {e}")
        sys.exit(1)

def scale_total_contacts(total_contacts, resolution):
    """
    Rescale a contact total relative to a 25000 reference resolution.

    At exactly 25000 the input is returned untouched. Above it, the total is
    multiplied by ``resolution // 25000``; otherwise it is floor-divided by
    ``25000 // resolution``. Both rescaled results are converted to ``int``.
    """
    reference = 25000
    if resolution == reference:
        return total_contacts
    if resolution > reference:
        multiplier = resolution // reference
        return int(total_contacts * multiplier)
    divisor = reference // resolution
    return int(total_contacts // divisor)

def load_intra_counts_hic(hicfile, resolution=None):
    """
    Derive a sequencing-depth class (0-7) from intra-chromosomal contacts.

    For every chromosome listed in the .hic file (except the excluded names
    below) the ``counts`` of all records returned by ``hicstraw.straw`` for
    the chromosome against itself are summed. The total is printed, rescaled
    with ``scale_total_contacts`` and then binned by fixed thresholds.

    Parameters
    ----------
    hicfile : str
        Path of the .hic file.
    resolution : int, optional
        Bin size passed to ``hicstraw.straw``. When omitted, the module-level
        ``resolution`` variable is used (this is how the script entry point
        has always supplied it).

    Returns
    -------
    int
        Depth class between 0 and 7.

    Raises
    ------
    ValueError
        If no resolution is given and no module-level ``resolution`` exists.
    """
    if resolution is None:
        # Backward compatibility: earlier versions read the resolution from a
        # global that only exists when the file is executed as a script.
        resolution = globals().get('resolution')
        if resolution is None:
            raise ValueError("resolution must be provided to load_intra_counts_hic")

    excluded = {'All', 'Y', 'MT', 'ALL', 'chrY', 'chrM', 'M'}
    hic = hicstraw.HiCFile(hicfile)
    chrom_names = [chrom.name for chrom in hic.getChromosomes()
                   if chrom.name not in excluded]

    # total intra-chromosomal counts
    total_contacts = 0
    for chrom in chrom_names:
        records = hicstraw.straw('observed', 'NONE', hicfile, chrom, chrom, 'BP', resolution)
        for record in records:
            total_contacts += record.counts

    # scale contacts based on resolution
    print(f"Hi-C file '{hicfile}' - Total intra contacts: {total_contacts}")
    total_contacts = scale_total_contacts(total_contacts, resolution)

    # Exclusive upper bounds of depth classes 0..6; anything that is not
    # below the last bound falls into class 7.
    upper_bounds = (50000000, 100000000, 200000000, 250000000,
                    450000000, 500000000, 600000000)
    for depth, bound in enumerate(upper_bounds):
        if total_contacts < bound:
            return depth
    return len(upper_bounds)


def load_total_counts_hic(hicfile, resolution=None):
    """
    Derive a sequencing-depth class (0-7) from all contacts in a .hic file.

    Chromosome pairs are taken from the chromosomes listed in the file
    (except the excluded names below). Pairs whose first name compares
    smaller than the second (string comparison) are skipped, so each
    unordered pair, including a chromosome with itself, is queried once.
    The ``counts`` of all records returned by ``hicstraw.straw`` are summed,
    the total is printed, rescaled with ``scale_total_contacts`` and binned
    by fixed thresholds.

    Parameters
    ----------
    hicfile : str
        Path of the .hic file.
    resolution : int, optional
        Bin size passed to ``hicstraw.straw``. When omitted, the module-level
        ``resolution`` variable is used (this is how the script entry point
        has always supplied it).

    Returns
    -------
    int
        Depth class between 0 and 7.

    Raises
    ------
    ValueError
        If no resolution is given and no module-level ``resolution`` exists.
    """
    if resolution is None:
        # Backward compatibility: earlier versions read the resolution from a
        # global that only exists when the file is executed as a script.
        resolution = globals().get('resolution')
        if resolution is None:
            raise ValueError("resolution must be provided to load_total_counts_hic")

    excluded = {'All', 'Y', 'MT', 'ALL', 'chrY', 'chrM', 'M'}
    hic = hicstraw.HiCFile(hicfile)
    chrom_names = [chrom.name for chrom in hic.getChromosomes()
                   if chrom.name not in excluded]

    # total counts
    total_contacts = 0
    for chrom1 in chrom_names:
        for chrom2 in chrom_names:
            if chrom1 < chrom2:
                # the mirrored pair (chrom2, chrom1) covers this combination
                continue
            records = hicstraw.straw('observed', 'NONE', hicfile, chrom1, chrom2, 'BP', resolution)
            for record in records:
                total_contacts += record.counts

    print(f"Hi-C file '{hicfile}' - Total contacts: {total_contacts}")
    total_contacts = scale_total_contacts(total_contacts, resolution)

    # Exclusive upper bounds of depth classes 0..6; anything that is not
    # below the last bound falls into class 7.
    upper_bounds = (50000000, 100000000, 200000000, 300000000,
                    400000000, 650000000, 900000000)
    for depth, bound in enumerate(upper_bounds):
        if total_contacts < bound:
            return depth
    return len(upper_bounds)


def check_file(path, desc, suffix=None):
    """
    Validate that ``path`` is an existing regular file, optionally with a suffix.

    ``desc`` is a human-readable label used in the error messages.

    Raises
    ------
    FileNotFoundError
        If ``path`` is not an existing file.
    ValueError
        If a non-empty ``suffix`` is given and ``path`` does not end with it.
    """
    exists = os.path.isfile(path)
    if not exists:
        raise FileNotFoundError(f"{desc} not found: {path}")
    if not suffix:
        # no suffix constraint requested
        return
    if not path.endswith(suffix):
        raise ValueError(f"{desc} must be a {suffix} file: {path}")

def check_resolution_in_hic(hic_path, resolution):
    """
    Ensure ``resolution`` is among the resolutions reported by a .hic file.

    The file is opened with ``hicstraw.HiCFile`` and its ``getResolutions()``
    result is used for the membership check.

    Raises
    ------
    ValueError
        If ``resolution`` is not in the reported resolutions; the message
        lists the ones that are available.
    """
    supported = hicstraw.HiCFile(hic_path).getResolutions()
    if resolution in supported:
        return
    raise ValueError(f"Resolution {resolution} not found in {hic_path}. "
                     f"Available: {supported}")

def getargs():
    """
    Parse the command-line arguments of the differential-TAD script.

    When the script is started without any argument, the help text is shown
    (``-h`` is injected) and argparse exits. Options the script cannot run
    without are marked as required, and ``--depth`` is restricted to the two
    modes the script implements, so invalid invocations fail with a usage
    message instead of an unrelated error later on.

    Returns
    -------
    tuple
        ``(args, commands)``: the parsed ``argparse.Namespace`` and the list
        of raw argument strings that was parsed.
    """
    ## Construct an ArgumentParser object for command-line arguments
    parser = argparse.ArgumentParser(description='Identify differential TADs from Hi-C contact maps.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--hicfile1', required=True,
                        help='Hi-C file (in .hic format) for condition 1.')
    parser.add_argument('--hicfile2', required=True,
                        help='Hi-C file (in .hic format) for condition 2.')
    parser.add_argument('--TADfile', required=True,
                        help='TAD boundary file used for differential analysis.')
    parser.add_argument('--res', type=int, required=True,
                        help='Resolution of the Hi-C contact maps (e.g., 25000 for 25 kb).')
    parser.add_argument('--depth', type=str, default='intra', choices=['intra', 'total'],
                        help='Method to compute sequencing depth: "intra" for intra-chromosomal counts or "total" for all contacts.')
    parser.add_argument('--output', required=True,
                        help='Path to the output result file.')

    ## Parse the command-line arguments
    commands = sys.argv[1:]
    if not commands:
        # no arguments at all: show the help text rather than a bare error
        commands.append('-h')
    args = parser.parse_args(commands)
    return args, commands


# Script entry point: score every TAD from the TAD file for differences
# between two Hi-C maps with the pretrained EdgeGNN model and write one
# tab-separated line (chrom, start, end, score) per retained TAD.
if __name__ == '__main__':
    args, commands = getargs()
    check_file(args.TADfile, "TAD file")
    check_file(args.hicfile1, "Hi-C file 1")
    check_file(args.hicfile2, "Hi-C file 2")
    check_resolution_in_hic(args.hicfile1, args.res)
    check_resolution_in_hic(args.hicfile2, args.res)

    # NOTE: `resolution` has to remain a module-level name; the depth loaders
    # called below may read it as a global.
    resolution = args.res
    result_file = args.output
    TADInfo, chroms = load_TAD(args.TADfile, resolution)
    hicfile1 = args.hicfile1
    hicfile2 = args.hicfile2

    # Pick the sequencing-depth estimator; reject unknown modes here instead
    # of failing later on an undefined depth variable.
    if args.depth == 'intra':
        depth_loader = load_intra_counts_hic
    elif args.depth == 'total':
        depth_loader = load_total_counts_hic
    else:
        raise ValueError(f"Unknown --depth value '{args.depth}'; expected 'intra' or 'total'.")
    hic1_depth = depth_loader(hicfile1)
    hic2_depth = depth_loader(hicfile2)

    chrom_bins = load_binsNum(hicfile1, resolution)
    exclude_chroms = {'All', 'Y', 'MT', 'ALL', 'chrY', 'chrM', 'M'}
    adjs = []
    labels = []
    depths = []
    chrom_list = []
    start_list = []
    end_list = []
    for chrom in chroms:
        if chrom in exclude_chroms:
            continue
        print(f"Processing chromosome: {chrom}")
        normed_mat_1 = dumpMatrix(chrom, chrom_bins[chrom], resolution, hicfile1)
        normed_mat_2 = dumpMatrix(chrom, chrom_bins[chrom], resolution, hicfile2)
        positions = TADInfo[chrom]
        for start, end in zip(positions['x1'], positions['x2']):
            # start/end are bin indices; cut the TAD sub-matrix from both maps
            normed_M1 = normed_mat_1[start:end, start:end].toarray()
            normed_M2 = normed_mat_2[start:end, start:end].toarray()
            # keep the TAD only when both sub-matrices pass filter_matrix
            if filter_matrix(normed_M1) and filter_matrix(normed_M2):
                # save TAD information (converted back from bins to coordinates)
                chrom_list.append(chrom)
                start_list.append(start * resolution)
                end_list.append(end * resolution)
                # save graphs
                adjs.append((normed_M1, normed_M2))
                depths.append((hic1_depth, hic2_depth))
                # every pair gets the label -1 in this script
                labels.append(-1)

    # run model
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    edge_feature_dim = 8
    model = EdgeGNN(node_hidden_dims=[8, 64, 16, 64],
                    edge_feature_dim=edge_feature_dim,
                    dropout=0.5,
                    nheads=[1],
                    alpha=0.2)
    model_path = pkg_resources.resource_filename('hidt', 'pretrained_model.pth')
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    batch_size = 40
    simis = []
    with torch.no_grad():
        for k_iter in range(0, len(adjs), batch_size):
            batch_graphs, batch_features, batch_labels, graphs_idx, depth_idx = generate_valid_batches(
                k_iter,
                adjs,
                labels,
                depths,
                batch_size)
            batch_labels = np.array(batch_labels)
            cur_edge_features, cur_node_features, cur_graphs_idx, cur_depth_idx, cur_batch_labels = get_graphs(
                batch_graphs,
                batch_features,
                graphs_idx,
                depth_idx,
                batch_labels,
                edge_feature_dim)

            graph_states, edges = model(cur_node_features.to(device),
                                        cur_edge_features.to(device),
                                        cur_graphs_idx.to(device),
                                        cur_depth_idx.to(device))
            x, y = reshape_and_split_tensor(graph_states, 2)
            similarity = compute_similarity(x, y)
            # the stored score is the negated similarity of each pair
            simis.extend(-elem.item() for elem in similarity)

    # A TAD counts as differential when its score exceeds 1.
    total_tads = len(simis)
    differential_tads = sum(1 for score in simis if score > 1)
    # Guard against an empty result set (no TAD passed the filters).
    percentage = differential_tads / total_tads * 100 if total_tads else 0.0
    print(
        f"Total TADs: {total_tads}, "
        f"Differential TADs: {differential_tads}, "
        f"Percentage: {percentage:.2f}%"
    )
    # Save result
    with open(result_file, 'w') as out:
        for tad_chrom, tad_start, tad_end, score in zip(chrom_list, start_list, end_list, simis):
            out.write(f"{tad_chrom}\t{tad_start}\t{tad_end}\t{score}\n")
    print(f"Results have been saved to '{result_file}'.")
