from pathlib import Path
import os
import json
import random
import tempfile

import polars as pl

# -------------------------------------------------------------------------------------- #
#                                     Configuration                                    #
# -------------------------------------------------------------------------------------- #

# Load configuration; every `config.get(...)` / `config["paths"]` lookup below
# reads from the dictionary populated by this directive
configfile: "db_builder/config.yaml"

# --- Base Directories ---
# Use config values with defaults if they weren't specified in the config file
# These are needed to format the paths loaded from the config below
OUTPUT_DIR = Path(config.get("output_dir", "out/alignment_db"))
WORK_DIR = Path(config.get("work_dir", "tmp/alignment_pipeline"))
# NOTE(review): mkdtemp() runs every time this file is evaluated, so each
# invocation creates a new directory and every path templated on `tmp_dir`
# differs between invocations. Confirm nothing under TMP_DIR is expected to be
# reused by a later run.
TMP_DIR = Path(tempfile.mkdtemp(prefix="flatprot_snk_")) # Use a unique temp dir per run

# --- Convenience function to format paths from config ---
def format_path(key: str, **kwargs):
    """Fill in the ``config["paths"]`` template stored under ``key``.

    The template is expanded with ``str.format``. The three base directories
    (``work_dir``, ``output_dir``, ``tmp_dir``) are always supplied; any extra
    keyword arguments are forwarded as additional fields. Callers use this to
    pass either concrete values or literal wildcard placeholders such as
    ``sf_id="{sf_id}"``.

    Raises:
        KeyError: if ``key`` is not present in ``config["paths"]`` or the
            template references a field that was not supplied.
    """
    template = config["paths"][key]
    base_dirs = {
        "work_dir": WORK_DIR,
        "output_dir": OUTPUT_DIR,
        "tmp_dir": TMP_DIR,
    }
    return template.format(**base_dirs, **kwargs)

# --- Pipeline Settings ---
# Each value falls back to the default shown when the key is absent from the config.
# Passed to the download_pdb rule as `retry_count` / `timeout` params
# (timeout unit is presumably seconds; confirm in pdb_download.py)
RETRY_COUNT = config.get("pdb_retry_count", 3)
TIMEOUT = config.get("pdb_network_timeout", 20)
# In test mode only NUM_FAMILIES superfamilies are processed, picked by a
# shuffle seeded with RANDOM_SEED (see get_all_superfamily_ids)
TEST_MODE = config.get("test_mode", False)
NUM_FAMILIES = config.get("num_families", 401)
RANDOM_SEED = config.get("random_seed", 42)
# Used as the `network_connections` resource of the download_pdb rule
CONCURRENT_DOWNLOADS = config.get("concurrent_downloads", 5)
# Foldseek executable and the thread counts for the per-superfamily rules
# (small) and the final database build (large)
FOLDSEEK_PATH = config.get("foldseek_path", "foldseek")
FOLDSEEK_SMALL_THREADS = config.get("foldseek_small_threads", 4)
FOLDSEEK_LARGE_THREADS = config.get("foldseek_large_threads", 10)

# -------------------------------------------------------------------------------------- #
#                             Directory Creation                                       #
# -------------------------------------------------------------------------------------- #

# Resolve the pipeline's sub-directories from the config path templates
LOG_DIR = Path(format_path("log_dir"))
PDB_FILES_DIR = Path(format_path("pdb_files_dir"))
DOMAIN_FILES_DIR = Path(format_path("domain_files_dir"))  # Base domain dir
REP_DOMAINS_DIR = Path(format_path("representative_domains_dir"))
MATRICES_DIR = Path(format_path("matrices_dir"))
REPORT_DIR = Path(format_path("report_dir"))
FOLDSEEK_DB_DIR = Path(format_path("foldseek_db_dir"))

# Ensure every output/working directory exists before any rule runs.
# TMP_DIR is not listed because tempfile.mkdtemp already created it.
for _required_dir in (
    OUTPUT_DIR,
    WORK_DIR,
    LOG_DIR,
    PDB_FILES_DIR,
    DOMAIN_FILES_DIR,
    REP_DOMAINS_DIR,
    MATRICES_DIR,
    REPORT_DIR,
    FOLDSEEK_DB_DIR,
):
    os.makedirs(_required_dir, exist_ok=True)

# -------------------------------------------------------------------------------------- #
#                                Wildcard Constraints                                    #
# -------------------------------------------------------------------------------------- #

# Global regex constraints applied to the wildcards used in the path patterns.
# NOTE(review): pdb_id and chain accept any run of non-"/" characters; whether
# that is unambiguous depends on the separators used in the config path
# patterns, which are not visible here.
wildcard_constraints:
    sf_id=r"\d+",      # digits only
    pdb_id="[^/]+",    # anything except a path separator
    chain="[^/]+",     # anything except a path separator
    start=r"-?\d+",    # integer, optionally negative
    end=r"-?\d+"       # integer, optionally negative

# -------------------------------------------------------------------------------------- #
#                               Helper Functions                                         #
# -------------------------------------------------------------------------------------- #

def get_all_superfamily_ids():
    """Return the superfamily IDs the pipeline should process.

    Reads the tab-separated superfamilies table (config path key
    ``"superfamilies"``) and returns the distinct values of its ``sf_id``
    column in ascending order. When ``TEST_MODE`` is set, the IDs are shuffled
    with ``RANDOM_SEED`` and only the first ``NUM_FAMILIES`` are returned.

    Returns:
        list: the selected ``sf_id`` values, or ``[]`` if the table does not
        exist yet or could not be read (the error is printed in that case).
    """
    superfamilies_path = format_path("superfamilies")
    try:
        df = pl.read_csv(superfamilies_path, separator="\t")
        # unique() does not guarantee an output order. Sort first so the seeded
        # shuffle below picks the same subset on every call; this function is
        # called from several places that must agree on the selection.
        sf_ids = df["sf_id"].unique().sort()
        if TEST_MODE:
            return sf_ids.shuffle(seed=RANDOM_SEED).to_list()[:NUM_FAMILIES]
        return sf_ids.to_list()
    except FileNotFoundError:
        return [] # Expected during initial DAG construction
    except Exception as e:
        # Best-effort: report the problem and fall back to "nothing to process"
        print(f"Error reading superfamilies file {superfamilies_path}: {e}")
        return []

def get_valid_superfamily_ids_from_file():
    """Load the valid superfamily IDs from the intermediate list file.

    The file (config path key ``"valid_sf_list"``) is read line by line; blank
    lines are skipped and every other line is parsed as an integer.

    Returns:
        list[int]: the parsed IDs, or ``[]`` if the file does not exist yet or
        any other error occurs while reading/parsing (the error is printed).
    """
    valid_sf_list_path = format_path("valid_sf_list")
    try:
        with open(valid_sf_list_path, 'r') as handle:
            stripped_lines = (raw.strip() for raw in handle)
            return [int(entry) for entry in stripped_lines if entry]
    except FileNotFoundError:
        # Expected while the DAG is first being built
        return []
    except Exception as e:
        print(f"Error reading valid superfamilies file {valid_sf_list_path}: {e}")
        return []

def get_all_structure_ids():
    """Return the distinct PDB IDs belonging to the superfamilies being processed.

    The superfamily selection comes from ``get_all_superfamily_ids()`` (so it
    honours ``TEST_MODE``); the superfamilies table is then filtered to those
    IDs and the unique values of its ``pdb_id`` column are returned.

    Returns:
        list: unique ``pdb_id`` values, or ``[]`` when no superfamilies are
        selected, the table does not exist yet, or reading it fails (the error
        is printed in the last case).
    """
    superfamilies_path = format_path("superfamilies")
    try:
        selected_sf_ids = get_all_superfamily_ids()
        if selected_sf_ids:
            table = pl.read_csv(superfamilies_path, separator="\t")
            # Keep only rows of the superfamilies we are actually processing
            in_selection = pl.col("sf_id").is_in(selected_sf_ids)
            return table.filter(in_selection)["pdb_id"].unique().to_list()
        # No superfamilies means no structures
        return []
    except FileNotFoundError:
        # Expected while the DAG is first being built
        return []
    except Exception as e:
        print(f"Error reading superfamilies file {superfamilies_path}: {e}")
        return []

def get_domains_for_superfamily(wildcards):
    """Input function: list the domain files expected for one superfamily.

    Reads the superfamilies table, keeps the rows whose ``sf_id`` equals
    ``int(wildcards.sf_id)`` and expands the ``"domain_file_pattern"`` config
    template once per row using that row's ``pdb_id``, ``chain``,
    ``start_res`` and ``end_res`` columns.

    Args:
        wildcards: Snakemake wildcards object; only ``sf_id`` is read.

    Returns:
        list[str]: one path per matching row, or ``[]`` if the table does not
        exist yet or cannot be read (the error is printed in the latter case).
    """
    superfamilies_path = format_path("superfamilies")

    try:
        table = pl.read_csv(superfamilies_path, separator="\t")
    except FileNotFoundError:
        # The superfamilies table has not been produced yet
        return []
    except Exception as e:
        print(f"Error reading SUPERFAMILIES file {superfamilies_path}: {e}")
        return []

    # Rows describing domains of the requested superfamily
    members = table.filter(pl.col("sf_id") == int(wildcards.sf_id))

    return [
        format_path(
            "domain_file_pattern",
            sf_id=wildcards.sf_id,
            pdb_id=entry["pdb_id"],
            chain=entry["chain"],
            start=entry["start_res"],
            end=entry["end_res"],
        )
        for entry in members.iter_rows(named=True)
    ]

# -------------------------------------------------------------------------------------- #
#                                      Target rule                                       #
# -------------------------------------------------------------------------------------- #

# Default target: the final Foldseek database, the alignment database and its
# accompanying info file.
rule all:
    input:
        format_path("foldseek_db"),
        format_path("alignment_db"),
        format_path("database_info")

# -------------------------------------------------------------------------------------- #
#                                Process SCOP information                                #
# -------------------------------------------------------------------------------------- #

# Download the latest SCOP classification table from EBI.
rule download_scop:
    output:
        scop_file = format_path("scop_file")
    log:
        # Built from the same "log_dir" config template that LOG_DIR is derived from
        format_path("log_dir") + "/download_scop.log"
    # wget's stderr is merged into stdout and piped through tee, so the
    # transfer output goes both to the console and to the log file
    shell:
        """
        wget -O {output.scop_file} https://www.ebi.ac.uk/pdbe/scop/files/scop-cla-latest.txt 2>&1 | tee {log}
        """

# Turn the downloaded SCOP file into the superfamilies table that the helper
# functions above read, plus a report file registered under "SCOP Analysis".
# The actual parsing lives in scop_parse.py (not visible here).
rule parse_scop:
    input:
        scop_file = format_path("scop_file")
    output:
        superfamilies = format_path("superfamilies"),
        report = report(format_path("scop_report_pattern"), category="SCOP Analysis")
    log:
        format_path("log_dir") + "/parse_scop.log"
    script:
        "scop_parse.py"

# -------------------------------------------------------------------------------------- #
#                           Download PDBs and extract domains                            #
# -------------------------------------------------------------------------------------- #

# Download one PDB structure (wildcard: pdb_id) via pdb_download.py, producing
# the structure file and a status file.
rule download_pdb:
    output:
        # Use config patterns directly
        struct_file = format_path("pdb_struct_pattern", pdb_id="{pdb_id}"),
        status = format_path("pdb_status_pattern", pdb_id="{pdb_id}")
    params:
        pdb_id = "{pdb_id}",
        retry_count = RETRY_COUNT,
        timeout = TIMEOUT,
        # Pass proxy from main config if needed by script
        # (None when "proxy_url" is not set in the config)
        proxy_url = config.get("proxy_url")
    resources:
        # NOTE(review): this makes EACH job claim CONCURRENT_DOWNLOADS units of
        # `network_connections`. Parallelism is only limited relative to the
        # total given on the command line (--resources network_connections=N),
        # so with N == CONCURRENT_DOWNLOADS downloads would run one at a time.
        # Confirm this is the intended way to cap concurrency.
        network_connections = CONCURRENT_DOWNLOADS # Control concurrency here
    log:
        format_path("log_dir") + "/download_pdb_{pdb_id}.log"
    script:
        "pdb_download.py"

# Aggregate step over all PDB downloads: waits for every status file, then
# writes a flag file and a download report (via pdb_report.py).
checkpoint download_all_pdbs:
    input:
        # Depend on the status files, not structure files, as status indicates attempt completion
        # NOTE(review): expand() runs while this file is parsed, and
        # get_all_structure_ids() returns [] when the superfamilies table does
        # not exist yet, so on a fresh run this list is empty. Confirm the
        # workflow is meant to be re-invoked after parse_scop has run.
        status_files = expand(format_path("pdb_status_pattern", pdb_id="{pdb_id}"), pdb_id=get_all_structure_ids()),
        superfamilies = format_path("superfamilies")
    output:
        flag = touch(format_path("pdb_flag_pattern")),
        report = report(format_path("pdb_report_pattern"), category="PDB Download Report")
    params:
        # Pass the directory path needed by the report script
        pdb_dir = format_path("pdb_files_dir")
    log:
        format_path("log_dir") + "/download_all_pdbs_report.log"
    script:
        "pdb_report.py"

# Extract a single domain from a downloaded structure via domain_extract.py.
# The domain is identified by the wildcards sf_id, pdb_id, chain, start and end.
rule extract_domain:
    input:
        struct_file = format_path("pdb_struct_pattern", pdb_id="{pdb_id}")
    output:
        # Use config pattern
        domain_file = format_path("domain_file_pattern", sf_id="{sf_id}", pdb_id="{pdb_id}", chain="{chain}", start="{start}", end="{end}")
    params:
        # Wildcard values forwarded to the script (pdb_id is not forwarded here)
        sf_id = "{sf_id}",
        chain = "{chain}",
        start = "{start}",
        end = "{end}"
    log:
        format_path("log_dir") + "/extract_domain_{sf_id}_{pdb_id}_{chain}_{start}_{end}.log"
    script:
        "domain_extract.py"

# Per-superfamily aggregation: requests every domain file that
# get_domains_for_superfamily lists for {sf_id}, reports how many of them are
# present, and touches a flag file that downstream rules depend on.
rule extract_all_domains_for_superfamily:
    input:
        domain_files = get_domains_for_superfamily
    output:
        # Use config pattern, touch creates the file
        flag = touch(format_path("domain_flag_pattern", sf_id="{sf_id}"))
    log:
        format_path("log_dir") + "/extract_all_domains_{sf_id}.log"
    run:
        # Build a single summary message for this superfamily
        if not input.domain_files:
            message = f"Warning: No domain files were expected for superfamily {wildcards.sf_id} based on SCOP data."
        else:
            # Count how many of the expected files are actually on disk
            num_expected = len(input.domain_files)
            num_existing = sum(1 for f in input.domain_files if os.path.exists(f))
            if num_existing == 0:
                message = f"Warning: No domain files were successfully extracted for superfamily {wildcards.sf_id}. Flag will be created, but directory may be empty."
            else:
                message = f"Completed domain extraction checks for superfamily {wildcards.sf_id}. {num_existing}/{num_expected} expected files present."

        # print() inside a run block is NOT redirected to the declared log
        # file, so write the message there explicitly as well as to stdout
        print(message)
        with open(log[0], "w") as log_f:
            log_f.write(message + "\n")
        # Flag file is touched by Snakemake on successful completion of the run block

# -------------------------------------------------------------------------------------- #
#                         Identify Valid Superfamilies                                   #
# -------------------------------------------------------------------------------------- #

# Write the list of superfamilies whose extraction flag exists and whose
# domain sub-directory contains at least one *.cif file.
rule filter_valid_superfamilies:
    input:
        # One extraction flag per superfamily selected by get_all_superfamily_ids()
        flags = expand(format_path("domain_flag_pattern", sf_id="{sf_id}"), sf_id=get_all_superfamily_ids())
    output:
        valid_list = format_path("valid_sf_list")
    log:
        format_path("log_dir") + "/filter_valid_superfamilies.log"
    run:
        # Re-read the candidate IDs from the initial SCOP parse
        candidate_ids = get_all_superfamily_ids()
        accepted = []

        with open(log[0], "w") as log_f:
            log_f.write(f"Checking {len(candidate_ids)} potential superfamilies...\n")
            for sf_id in candidate_ids:
                flag_file = format_path("domain_flag_pattern", sf_id=sf_id)
                domain_subdir = Path(format_path("domain_subdir_pattern", sf_id=sf_id))

                # No flag: the extraction rule never finished for this superfamily
                if not os.path.exists(flag_file):
                    log_f.write(f"Superfamily {sf_id}: Invalid (flag file {flag_file} does not exist).\n")
                    continue

                # Flag present: additionally require at least one extracted .cif file
                if domain_subdir.is_dir() and any(domain_subdir.glob("*.cif")):
                    accepted.append(str(sf_id))
                    log_f.write(f"Superfamily {sf_id}: Valid (flag exists, domains found).\n")
                else:
                    log_f.write(f"Superfamily {sf_id}: Invalid (flag exists, but no *.cif files found in {domain_subdir}).\n")

            # One accepted ID per line
            with open(output.valid_list, "w") as f:
                f.writelines(f"{sf_id}\n" for sf_id in accepted)
            log_f.write(f"Found {len(accepted)} valid superfamilies with extracted domains.\n")
            log_f.write(f"List saved to {output.valid_list}\n")

# -------------------------------------------------------------------------------------- #
#                  Create Foldseek DBs and Align Domains                                 #
# -------------------------------------------------------------------------------------- #

# Build a temporary Foldseek database from the extracted domains of one
# superfamily (wildcard: sf_id).
rule create_foldseek_db_per_sf:
    input:
        # Depend on the extraction flag for this SF
        flag = format_path("domain_flag_pattern", sf_id="{sf_id}"),
        # Depend on the *existence* of cif files (implicitly via domain_subdir)
        # NOTE(review): the glob only sees files already on disk when Snakemake
        # evaluates this input function; it is the flag above that actually
        # orders this rule after extraction.
        domains = lambda wildcards: list(Path(format_path("domain_subdir_pattern", sf_id=wildcards.sf_id)).glob("*.cif"))
    output:
        # Foldseek creates multiple files, target the main db file
        db = format_path("tmp_foldseek_db_pattern", sf_id="{sf_id}")
    params:
        foldseek_bin = FOLDSEEK_PATH,
        domain_dir = format_path("domain_subdir_pattern", sf_id="{sf_id}"),
        # Same template as output.db; used as the database prefix on the command line
        tmp_db_prefix = format_path("tmp_foldseek_db_pattern", sf_id="{sf_id}"),
        tmp_db_dir = format_path("tmp_foldseek_dir_pattern", sf_id="{sf_id}") # Dir for temp files
    threads:
        FOLDSEEK_SMALL_THREADS
    log:
        format_path("log_dir") + "/create_foldseek_db_per_sf_{sf_id}.log"
    # The command creates tmp_db_dir, then runs `foldseek createdb` on the
    # whole domain directory; stdout and stderr both go to the log file
    shell:
        """
        # Ensure the temporary directory for this SF exists
        mkdir -p {params.tmp_db_dir}
        # Create the Foldseek database from CIF files in the domain subdirectory
        # Use the specified temp directory for intermediate files
        {params.foldseek_bin} createdb {params.domain_dir} {params.tmp_db_prefix} --threads {threads} > {log} 2>&1
        """

# Run an all-vs-all Foldseek search of one superfamily's database against
# itself and write the tabular result to the superfamily's matrix file.
rule align_domains:
    input:
        # Depends on the temporary Foldseek DB for this SF
        db = format_path("tmp_foldseek_db_pattern", sf_id="{sf_id}"),
        # Include domains as input to ensure they exist before alignment
        # (lists whatever *.cif files are on disk when the input function is evaluated)
        domains = lambda wildcards: list(Path(format_path("domain_subdir_pattern", sf_id=wildcards.sf_id)).glob("*.cif"))
    output:
        alignment_file = format_path("matrix_file_pattern", sf_id="{sf_id}")
    params:
        foldseek_bin = FOLDSEEK_PATH,
        # Database prefix, used as both query and target in the command below
        tmp_db_prefix = format_path("tmp_foldseek_db_pattern", sf_id="{sf_id}"),
        tmp_dir = format_path("tmp_foldseek_dir_pattern", sf_id="{sf_id}") # Pass tmp dir
    threads:
        FOLDSEEK_SMALL_THREADS
    log:
        format_path("log_dir") + "/align_domains_{sf_id}.log"
    # `-e inf` is passed as the e-value setting; stdout and stderr both go to the log file
    shell:
        """
        # Perform all-vs-all alignment within the superfamily DB
        # Output format includes query, target, qstart, qend, tstart, tend, tseq, prob, alntmscore
        {params.foldseek_bin} easy-search {params.tmp_db_prefix} {params.tmp_db_prefix} {output.alignment_file} {params.tmp_dir}/aln_tmp_{wildcards.sf_id} \
            --format-output query,target,qstart,qend,tstart,tend,tseq,prob,alntmscore -e inf --threads {threads} > {log} 2>&1
        """

# -------------------------------------------------------------------------------------- #
#                   Select Representatives and Build Final DBs                           #
# -------------------------------------------------------------------------------------- #

# Choose the representative domain of one superfamily from its alignment
# results; the selection logic lives in domain_select.py (not visible here).
rule select_representative:
    input:
        # Depends on the alignment results for this SF
        alignment_file = format_path("matrix_file_pattern", sf_id="{sf_id}"),
        # Depends on the valid SF list being created
        valid_list = format_path("valid_sf_list")
    output:
        representative_domain = format_path("representative_domain_pattern", sf_id="{sf_id}")
    params:
        # Pass the directory where original domain CIFs are located
        domain_dir = format_path("domain_subdir_pattern", sf_id="{sf_id}")
    log:
        format_path("log_dir") + "/select_representative_{sf_id}.log"
    script:
        "domain_select.py"

# Aggregate step: waits for the representative domain of every valid
# superfamily, records how many were collected and touches a flag file.
checkpoint collect_representatives:
    input:
        # Depend on all representative domain files being created for *valid* superfamilies
        # NOTE(review): expand() runs while this file is parsed, and
        # get_valid_superfamily_ids_from_file() returns [] if the valid list
        # has not been written yet, so this list can be empty on a fresh run.
        representatives = expand(
            format_path("representative_domain_pattern", sf_id="{sf_id}"),
            sf_id=get_valid_superfamily_ids_from_file()
        ),
        valid_list = format_path("valid_sf_list")
    output:
        flag = touch(format_path("representative_flag_pattern"))
    log:
        format_path("log_dir") + "/collect_representatives.log"
    run:
        num_representatives = len(input.representatives)
        summary = f"Collected {num_representatives} representative domains."
        # print() in a run block is not redirected to the declared log file,
        # so write the summary there explicitly as well as to stdout
        print(summary)
        with open(log[0], "w") as log_f:
            log_f.write(summary + "\n")


# Build the final Foldseek database from the representative domains.
rule create_foldseek_database:
    input:
        # Depend on the checkpoint flag ensuring all representatives exist
        flag = format_path("representative_flag_pattern"),
        # Explicitly list representative files as input for Foldseek
        # NOTE(review): these files are dependencies only; the command below
        # is given the whole representative directory, so any other structure
        # file present in that directory would also end up in the database.
        representatives = expand(
            format_path("representative_domain_pattern", sf_id="{sf_id}"),
            sf_id=get_valid_superfamily_ids_from_file()
        )
    output:
        # Foldseek creates multiple files, specify the prefix
        db = format_path("foldseek_db")
    params:
        foldseek_bin = FOLDSEEK_PATH,
        # Pass the directory containing the representative CIF files
        representative_dir = format_path("representative_domains_dir"),
    threads:
        FOLDSEEK_LARGE_THREADS
    log:
        format_path("log_dir") + "/create_foldseek_database.log"
    # stdout and stderr of foldseek both go to the log file
    shell:
        """
        # Create the final Foldseek database from all representative domains
        # Note: foldseek createdb takes the *directory* containing structures
        {params.foldseek_bin} createdb {params.representative_dir} {output.db} --threads {threads} > {log} 2>&1
        """

# Build the alignment database and its info file from the representative
# domains; the work is done by db_create.py (not visible here).
rule create_alignment_database:
    input:
        # Depend on the checkpoint flag ensuring all representatives exist
        flag = format_path("representative_flag_pattern"),
        # Get the list of representative files (needed by the script)
        # (expand() is evaluated at parse time from the valid-superfamily list file)
        representatives = expand(
            format_path("representative_domain_pattern", sf_id="{sf_id}"),
            sf_id=get_valid_superfamily_ids_from_file()
        ),
        # Pass the superfamilies file if needed by the script for metadata
        superfamilies = format_path("superfamilies")
    output:
        database = format_path("alignment_db"),
        database_info = format_path("database_info")
    params:
        # Pass the directory containing representatives to the script
        representative_dir = format_path("representative_domains_dir"),
        # Version label handed to the script; defaults to "1.0" when not configured
        db_version = config.get("db_version", "1.0")
    log:
        format_path("log_dir") + "/create_alignment_database.log"
    script:
        "db_create.py"

# -------------------------------------------------------------------------------------- #
#                                     Cleanup Rule (Optional)                          #
# -------------------------------------------------------------------------------------- #
# Example of how to add the cleanup script as a rule, if desired
# rule clean:
#     run:
#         shell("python db_builder/clean_pipeline.py")

# -------------------------------------------------------------------------------------- #
#                                    On Success/Error Hooks                             #
# -------------------------------------------------------------------------------------- #

onsuccess:
    # Clean up the per-run temporary directory after a successful run.
    # shutil.rmtree is used instead of shelling out to `rm -rf`: the path was
    # previously interpolated unquoted into a shell command (and formatted a
    # second time by shell()), which misbehaves if it contains spaces, shell
    # metacharacters or braces.
    import shutil

    if TMP_DIR.exists():
        print(f"Workflow successful. Cleaning up temporary directory: {TMP_DIR}")
        shutil.rmtree(TMP_DIR)

onerror:
    # Provide instructions on error
    # TMP_DIR is deliberately left in place here (only onsuccess removes it)
    # so its contents can be inspected.
    print(f"Workflow failed. Check logs in: {LOG_DIR}")
    print(f"Temporary files kept for debugging in: {TMP_DIR}")
    print(f"Consider running 'python db_builder/clean_pipeline.py' and retrying.")

# -------------------------------------------------------------------------------------- #
#                                     Report Generation (Optional)                     #
# -------------------------------------------------------------------------------------- #

# report: format_path("report_dir") + "/pipeline_report.html"
