#!python

import os
import glob
import argparse
import logging
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns
import numpy as np 
import shutil
import textwrap
import math
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import stats
from scipy.stats import mannwhitneyu, kruskal
from scipy.stats import pearsonr, spearmanr
from datetime import datetime
import itertools
import warnings
# Suppress every warning process-wide (no category given, so all are ignored).
# NOTE(review): this also hides library deprecation warnings (matplotlib,
# seaborn, pandas); consider narrowing the filter so API removals are noticed.
warnings.filterwarnings('ignore')

# Set up logging: root logger at INFO with a "time - level - message" format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

def convert_tab_to_csv(tab_file, csv_file):
    """Rewrite a tab-separated table as a comma-separated file.

    The whole table is loaded with pandas and written back out without the
    DataFrame index.

    Args:
        tab_file (str): Path of the tab-delimited input file.
        csv_file (str): Destination path for the CSV output.

    Raises:
        Exception: Any read/write error is logged and then re-raised.
    """
    try:
        table = pd.read_csv(tab_file, sep="\t")
        table.to_csv(csv_file, index=False)
    except Exception as err:
        logging.error(f"Error converting {tab_file} to CSV: {err}")
        raise
    else:
        logging.info(f"Converted {tab_file} to {csv_file}")


def load_and_merge_data(ncbi_clean_path, abricate_summary_file):
    """Load the cleaned NCBI table and the Abricate summary and left-join them.

    The assembly accession (``GCF_`` or ``GCA_`` followed by digits and a
    ``.<version>`` suffix) is extracted from the Abricate ``#File``/``#FILE``
    column and used as the join key against ``Assembly Accession`` in the
    NCBI table.

    Args:
        ncbi_clean_path (str): CSV with an ``Assembly Accession`` column.
        abricate_summary_file (str): CSV with a ``#File`` or ``#FILE`` column.

    Returns:
        pandas.DataFrame: One row per NCBI row (left join) with the matching
        Abricate columns appended; every missing value is replaced by the
        string ``"0"``.

    Raises:
        Exception: Any read/merge error is logged and re-raised.
    """
    try:
        ncbi_clean_df = pd.read_csv(ncbi_clean_path)
        abricate_summary_df = pd.read_csv(abricate_summary_file)

        # Accept either capitalisation of the file column.
        file_col = "#File" if "#File" in abricate_summary_df.columns else "#FILE"
        # Match both RefSeq (GCF_) and GenBank (GCA_) accessions; a GCF_-only
        # pattern left GCA_ assemblies unmatched and silently zero-filled.
        abricate_summary_df["Assembly Accession"] = abricate_summary_df[file_col].str.extract(
            r"(GC[AF]_\d+\.\d+)", expand=False
        )

        merged_df = pd.merge(ncbi_clean_df, abricate_summary_df, on="Assembly Accession", how="left")
        # NOTE: fills with the *string* "0"; downstream code coerces to numeric.
        merged_df.fillna("0", inplace=True)

        return merged_df
    except Exception as e:
        logging.error(f"Error loading or merging data: {e}")
        raise


def save_merged_data(merged_df, output_dir, output_filename):
    """Write *merged_df* as CSV to ``output_dir/output_filename``.

    The output directory (and any missing parents) is created first; the
    DataFrame index is not written.
    """
    destination = os.path.join(output_dir, output_filename)
    os.makedirs(output_dir, exist_ok=True)
    merged_df.to_csv(destination, index=False)
    logging.info(f"Merged file saved to: {destination}")


def _max_identity(values):
    """Return, per cell of *values*, the largest ``;``-separated number (0 if none parse)."""
    # Split multi-valued cells, parse each piece, then take the per-row max.
    pieces = values.astype(str).str.split(";").explode()
    numeric = pd.to_numeric(pieces.str.strip(), errors="coerce")
    return numeric.groupby(level=0).max().reindex(values.index).fillna(0)


def convert_to_tidy_format(df, n_id_cols=15):
    """Convert the merged wide table to long ("tidy") format.

    The first ``n_id_cols`` columns are kept as identifier columns. Every
    remaining column is treated as a gene and melted into ``Gene`` /
    ``Identity`` rows. A cell holding several ``;``-separated numbers
    (presumably one per Abricate hit; confirm against the summary output) is
    reduced to its largest value. Non-numeric cells such as ``"."`` become 0.
    ``Presence`` is 1 when ``Identity`` is non-zero, else 0.

    Args:
        df (pandas.DataFrame): Merged wide table (metadata columns first).
        n_id_cols (int): Number of leading metadata columns (default 15).

    Returns:
        pandas.DataFrame: Long table with the id columns plus ``Gene``,
        ``Identity`` and ``Presence``.

    Raises:
        ValueError: If ``df`` has no columns beyond the metadata columns.
    """
    try:
        # Require at least one gene column after the id columns. The old
        # "< 14" check disagreed with the 15 sliced id columns and let
        # gene-less frames through, yielding an empty tidy table.
        if df.shape[1] <= n_id_cols:
            raise ValueError("Not enough columns in the dataframe.")

        id_vars = df.columns[:n_id_cols]
        df_tidy = df.melt(id_vars=id_vars, var_name="Gene", value_name="Identity")
        # Plain to_numeric() would turn "98.5;100.00" into NaN -> 0 and mark
        # a present gene as absent; take the max of the pieces instead.
        df_tidy["Identity"] = _max_identity(df_tidy["Identity"])
        df_tidy["Presence"] = (df_tidy["Identity"] != 0).astype("int64")

        return df_tidy
    except Exception as e:
        logging.error(f"Error converting to tidy format: {e}")
        raise


def analyze_gene_presence(df, output_dir, base_name, fig_format):
    """Summarise per-gene presence and render the presence plots.

    ``Presence`` is summed per gene (per ``RESISTANCE`` class and gene when
    that column exists), sorted descending, and expressed as a percentage of
    the number of distinct ``Assembly BioSample Accession`` values. The table
    is written to ``<output_dir>/<base_name>_gene_presence_count_percentage.csv``
    and handed to the static (matplotlib) and interactive (Plotly) plotters.

    Args:
        df (pandas.DataFrame): Tidy table with ``Gene`` and ``Presence``.
        output_dir (str): Directory for the CSV and figures.
        base_name (str): Prefix for the CSV filename.
        fig_format (str): Output format for the static figures.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        keys = ["RESISTANCE", "Gene"] if "RESISTANCE" in df.columns else ["Gene"]

        summary = (
            df.groupby(keys)["Presence"]
            .sum()
            .reset_index()
            .sort_values(by="Presence", ascending=False)
        )
        n_biosamples = df["Assembly BioSample Accession"].nunique()
        summary["Percentage"] = (summary["Presence"] / n_biosamples) * 100

        csv_path = os.path.join(output_dir, f"{base_name}_gene_presence_count_percentage.csv")
        summary.to_csv(csv_path, index=False)
        logging.info(f"Count and percentages saved to {csv_path}")

        # Static figures first, then the interactive HTML versions.
        for static_plot in (generate_lollipop_plot, generate_percentage_bar_plot):
            static_plot(summary, output_dir, base_name, fig_format)
        for html_plot in (gene_presence_plot_plotly, generate_percentage_bar_plot_plotly):
            html_plot(summary, output_dir, base_name)

    except Exception as exc:
        logging.error(f"Error analyzing gene presence: {exc}")
        raise


def generate_lollipop_plot(gene_presence, output_dir, base_name, fig_format):
    """Plot gene presence counts and save them as ``Resistance_gene_presence.<fig_format>``.

    When ``gene_presence`` has a ``RESISTANCE`` column, one bar panel per
    resistance class is drawn in a grid of at most eight columns. All panels
    share the same y-range so counts are comparable across classes. Otherwise
    a single stem ("lollipop") chart over all genes is drawn.

    Args:
        gene_presence (pandas.DataFrame): Table with ``Gene`` and ``Presence``
            columns and optionally ``RESISTANCE``.
        output_dir (str): Directory the figure is written to.
        base_name (str): Not used here; kept for interface compatibility.
        fig_format (str): Matplotlib output format (e.g. ``"png"``).

    Raises:
        Exception: Any plotting/saving error is logged and re-raised.
    """
    try:
        if gene_presence.empty:
            # Nothing to draw (plt.subplots would otherwise fail on 0 rows).
            logging.warning("No gene presence data; skipping lollipop plot.")
            return

        plot_file = os.path.join(output_dir, f"Resistance_gene_presence.{fig_format}")

        if "RESISTANCE" in gene_presence.columns:
            sns.set(style="whitegrid")

            grouped = gene_presence.groupby("RESISTANCE")
            num_categories = grouped.ngroups

            # Cap the grid at 8 columns but never allocate more columns than
            # classes: surplus empty columns widened the figure and pushed the
            # centred suptitle away from the drawn panels.
            n_cols = min(8, num_categories)
            n_rows = math.ceil(num_categories / n_cols)
            facet_height = 4
            facet_width = 6

            fig, axes = plt.subplots(
                n_rows, n_cols,
                figsize=(n_cols * facet_width, n_rows * facet_height),
                squeeze=False,
            )
            axes = axes.flatten()
            max_presence = gene_presence["Presence"].max()
            # Avoid a zero-height y-range when no gene is present at all.
            y_top = max(max_presence, 1) * 1.25

            for ax, (res_class, data) in zip(axes, grouped):
                data = data.sort_values("Presence", ascending=False)
                sns.barplot(data=data, x="Gene", y="Presence", palette="Set2", ax=ax)

                # Label each bar with its count just above the bar top.
                for idx, presence in enumerate(data["Presence"]):
                    ax.text(
                        idx, presence + max_presence * 0.02,
                        f"{int(presence)}",
                        ha="center", va="bottom", fontsize=8, rotation=90,
                    )

                ax.set_title(res_class, fontsize=11, weight="bold")
                ax.set_xticks(range(len(data)))
                ax.set_xticklabels(data["Gene"], rotation=90, ha="center", fontsize=8)
                ax.set_ylim(0, y_top)
                ax.set_xlabel("")
                ax.set_ylabel("Presence")

            # Remove grid cells left over after the last class.
            for ax in axes[num_categories:]:
                fig.delaxes(ax)

            plt.tight_layout(h_pad=2.5, w_pad=1.5)
            plt.suptitle("Gene Presence by Resistance Category", fontsize=16, y=1.02)
            plt.savefig(plot_file, dpi=300, bbox_inches="tight", format=fig_format)
            plt.close(fig)
            logging.info(f"Faceted lollipop plot saved to {plot_file}")

        else:
            # Fallback if RESISTANCE column is missing: one stem chart.
            num_genes = len(gene_presence)
            positions = np.arange(num_genes)

            fig = plt.figure(figsize=(max(12, num_genes * 0.4), 8))
            # Plot against integer positions so stems line up with the text
            # annotations placed at x=i below. `use_line_collection` is not
            # passed: it was removed in Matplotlib 3.8 (LineCollection is the
            # default behaviour).
            markerline, stemlines, _baseline = plt.stem(
                positions, gene_presence["Presence"], basefmt=" "
            )
            plt.setp(markerline, markersize=6)
            plt.setp(stemlines, linewidth=1.5)

            plt.xlabel("Gene", fontsize=14)
            plt.ylabel("Total Presence", fontsize=14)
            plt.xticks(positions, list(gene_presence["Gene"]), rotation=90, fontsize=10)
            plt.ylim(0, max(gene_presence["Presence"].max(), 1) * 1.2)

            for i, presence in enumerate(gene_presence["Presence"]):
                plt.text(i, presence + 1, str(presence), ha="center", va="bottom", fontsize=9, rotation=90)

            plt.title("Gene Presence Across Assemblies", fontsize=16)
            plt.tight_layout()

            plt.savefig(plot_file, format=fig_format, dpi=300, bbox_inches="tight")
            plt.close(fig)
            logging.info(f"Lollipop plot saved to {plot_file}")

    except Exception as e:
        logging.error(f"Error generating lollipop plot: {e}")
        raise


def generate_percentage_bar_plot(gene_presence, output_dir, base_name, fig_format):
    """Plot gene presence percentages, faceted by ``RESISTANCE`` when available.

    With a ``RESISTANCE`` column, one bar panel per class is drawn in a grid
    of at most eight columns and saved as
    ``Resistance_gene_percentage.<fig_format>``. Each bar is labelled
    ``(<pct>%, <count>)`` and only the first column keeps y tick labels.
    Without it, a single stem chart is saved as
    ``Resistance_gene_percentage_lollipop.<fig_format>``.

    Args:
        gene_presence (pandas.DataFrame): Needs ``Gene``, ``Presence`` and
            ``Percentage`` columns, optionally ``RESISTANCE``.
        output_dir (str): Directory the figure is written to.
        base_name (str): Not used here; kept for interface compatibility.
        fig_format (str): Matplotlib output format (e.g. ``"png"``).

    Raises:
        Exception: Any plotting/saving error is logged and re-raised.
    """
    try:
        if gene_presence.empty:
            # Nothing to draw (plt.subplots would otherwise fail on 0 rows).
            logging.warning("No gene presence data; skipping percentage plot.")
            return

        if "RESISTANCE" in gene_presence.columns:
            sns.set(style="whitegrid")

            grouped = gene_presence.groupby("RESISTANCE")
            num_categories = grouped.ngroups

            # Never allocate more grid columns than there are classes.
            n_cols = min(8, num_categories)
            n_rows = math.ceil(num_categories / n_cols)
            facet_height = 4
            facet_width = 6

            fig, axes = plt.subplots(
                n_rows, n_cols,
                figsize=(n_cols * facet_width, n_rows * facet_height),
                squeeze=False,
            )
            axes = axes.flatten()
            max_percent = gene_presence["Percentage"].max()
            # Avoid a zero-height y-range when every percentage is 0.
            y_top = max(max_percent, 1) * 1.25

            for i, (ax, (res_class, data)) in enumerate(zip(axes, grouped)):
                data = data.sort_values("Percentage", ascending=False)
                sns.barplot(data=data, x="Gene", y="Percentage", palette="Set2", ax=ax)

                # Label each bar with percentage and absolute count.
                for idx, (percent, presence) in enumerate(zip(data["Percentage"], data["Presence"])):
                    ax.text(
                        idx, percent + max_percent * 0.01,
                        f"({percent:.1f}%, {presence})",
                        ha="center", va="bottom", fontsize=8, rotation=90,
                    )

                ax.set_title(res_class, fontsize=11, weight="bold")
                ax.set_xticks(range(len(data)))
                ax.set_xticklabels(data["Gene"], rotation=90, ha="center", fontsize=8)
                ax.set_ylim(0, y_top)
                ax.set_xlabel("")

                # All panels share y_top, so y tick labels are only needed on
                # the first panel of each grid row.
                if i % n_cols != 0:
                    ax.set_ylabel("")
                    ax.tick_params(axis="y", labelleft=False)
                else:
                    ax.set_ylabel("Percentage")

            # Remove grid cells left over after the last class.
            for ax in axes[num_categories:]:
                fig.delaxes(ax)

            plt.tight_layout(h_pad=2.5, w_pad=1.5)
            plt.suptitle("Gene Presence by Resistance Category", fontsize=16, y=1.02)

            plot_file = os.path.join(output_dir, f"Resistance_gene_percentage.{fig_format}")
            plt.savefig(plot_file, dpi=300, bbox_inches="tight", format=fig_format)
            plt.close(fig)
            logging.info(f"Faceted percentage plot saved to {plot_file}")

        else:
            # Fallback if no RESISTANCE column: one stem chart.
            num_genes = len(gene_presence)
            positions = np.arange(num_genes)

            fig = plt.figure(figsize=(max(12, num_genes * 0.4), 8))
            # `use_line_collection` is not passed: removed in Matplotlib 3.8.
            markerline, stemlines, _baseline = plt.stem(
                positions, gene_presence["Percentage"], basefmt=" "
            )
            plt.setp(markerline, markersize=6)
            plt.setp(stemlines, linewidth=1.5)

            plt.xlabel("Gene", fontsize=14)
            plt.ylabel("Presence (%)", fontsize=14)
            plt.xticks(positions, list(gene_presence["Gene"]), rotation=90, fontsize=10)
            plt.ylim(0, max(gene_presence["Percentage"].max(), 1) * 1.25)

            for i, (percent, presence) in enumerate(zip(gene_presence["Percentage"], gene_presence["Presence"])):
                plt.text(i, percent + 1, f"({percent:.1f}%, {presence})", ha="center", va="bottom", fontsize=9, rotation=90)

            plt.title("Gene Presence Across Assemblies", fontsize=16)
            plt.tight_layout()

            plot_file = os.path.join(output_dir, f"Resistance_gene_percentage_lollipop.{fig_format}")
            plt.savefig(plot_file, format=fig_format, dpi=300, bbox_inches="tight")
            plt.close(fig)
            logging.info(f"Percentage lollipop plot saved to {plot_file}")

    except Exception as e:
        logging.error(f"Error generating percentage plot: {e}")
        raise


def _plotly_stem_trace(genes, values, color, visible=True):
    """Return one Scatter trace drawing a vertical 0->value stem for every gene."""
    xs, ys = [], []
    for gene, value in zip(genes, values):
        # None separates consecutive segments so they are not joined.
        xs.extend([gene, gene, None])
        ys.extend([0, value, None])
    return go.Scatter(
        x=xs,
        y=ys,
        mode="lines",
        line=dict(color=color, width=2),
        visible=visible,
        showlegend=False,
        hoverinfo="skip",
    )


def gene_presence_plot_plotly(gene_presence, output_dir, base_name):
    """Write an interactive lollipop chart of gene presence counts to HTML.

    With a ``RESISTANCE`` column, one lollipop series is built per class and
    a dropdown switches between them (the first class is shown initially).
    Without it, a single series covers all genes. Saved to
    ``<output_dir>/Resistance_gene_presence.html``.

    Args:
        gene_presence (pandas.DataFrame): Needs ``Gene`` and ``Presence``
            columns, optionally ``RESISTANCE``.
        output_dir (str): Directory the HTML file is written to.
        base_name (str): Not used; kept for interface compatibility.

    Returns:
        plotly.graph_objects.Figure: The figure that was saved.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        fig = go.Figure()

        if "RESISTANCE" in gene_presence.columns:
            palette = px.colors.qualitative.Set2
            grouped = gene_presence.groupby("RESISTANCE")
            resistance_classes = list(grouped.groups.keys())

            # Each class adds exactly two traces, in order [stems, heads];
            # the dropdown visibility masks below rely on that layout.
            for i, res_class in enumerate(resistance_classes):
                data_sorted = grouped.get_group(res_class).sort_values("Presence", ascending=False)
                color = palette[i % len(palette)]
                visible = i == 0

                fig.add_trace(
                    _plotly_stem_trace(data_sorted["Gene"], data_sorted["Presence"], color, visible)
                )
                fig.add_trace(
                    go.Scatter(
                        x=data_sorted["Gene"],
                        y=data_sorted["Presence"],
                        # "markers+text" (not "markers+lines") so the count
                        # labels render and no polyline joins the heads.
                        mode="markers+text",
                        name=res_class,
                        visible=visible,
                        marker=dict(size=10, color=color, line=dict(width=2, color="white")),
                        text=[f"{int(val)}" for val in data_sorted["Presence"]],
                        textposition="top center",
                        textfont=dict(size=10),
                        showlegend=False,
                    )
                )

            n_traces = 2 * len(resistance_classes)
            dropdown_buttons = [
                dict(
                    label=res_class,
                    method="update",
                    args=[
                        {"visible": [t // 2 == i for t in range(n_traces)]},
                        {"title": f"Gene Presence - {res_class}"},
                    ],
                )
                for i, res_class in enumerate(resistance_classes)
            ]
            initial_title = (
                f"Gene Presence - {resistance_classes[0]}" if resistance_classes else "Gene Presence"
            )

            fig.update_layout(
                title=initial_title,
                xaxis_title="Gene",
                yaxis_title="Presence",
                xaxis=dict(tickangle=-45, tickfont=dict(style="italic")),
                height=600,
                updatemenus=[
                    dict(
                        buttons=dropdown_buttons,
                        direction="down",
                        showactive=True,
                        x=0.1,
                        xanchor="left",
                        y=1.15,
                        yanchor="top",
                    )
                ],
                annotations=[
                    dict(text="Select Resistance:", showarrow=False,
                         x=0.02, y=1.18, xref="paper", yref="paper")
                ],
            )

        else:
            # Fallback if RESISTANCE column is missing: one series.
            data_sorted = gene_presence.sort_values("Presence", ascending=False)

            fig.add_trace(_plotly_stem_trace(data_sorted["Gene"], data_sorted["Presence"], "blue"))
            fig.add_trace(
                go.Scatter(
                    x=data_sorted["Gene"],
                    y=data_sorted["Presence"],
                    mode="markers+text",
                    marker=dict(size=10, color="blue"),
                    text=[f"{int(val)}" for val in data_sorted["Presence"]],
                    textposition="top center",
                    name="Gene Presence",
                )
            )

            fig.update_layout(
                title="Gene Presence Across Assemblies",
                xaxis_title="Gene",
                yaxis_title="Total Presence",
                xaxis=dict(tickangle=-45, tickfont=dict(style="italic")),
                height=600,
            )

        plot_file = os.path.join(output_dir, "Resistance_gene_presence.html")
        fig.write_html(plot_file)
        logging.info(f"Interactive lollipop plot saved to {plot_file}")

        return fig

    except Exception as e:
        logging.error(f"Error generating interactive lollipop plot: {e}")
        raise

def generate_percentage_bar_plot_plotly(gene_presence, output_dir, base_name):
    """Write an interactive bar chart of gene presence percentages to HTML.

    With a ``RESISTANCE`` column, one bar trace per class is added and a
    dropdown shows one class at a time (the first class initially). Bars are
    labelled ``"<pct>%"``. Hovering shows
    ``"<gene>: <pct>% (<count> assemblies)"``. Saved to
    ``<output_dir>/Resistance_gene_percentage.html``.

    Args:
        gene_presence (pandas.DataFrame): Needs ``Gene``, ``Presence`` and
            ``Percentage`` columns, optionally ``RESISTANCE``.
        output_dir (str): Directory the HTML file is written to.
        base_name (str): Not used; kept for interface compatibility.

    Returns:
        plotly.graph_objects.Figure: The figure that was saved.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        fig = go.Figure()

        if "RESISTANCE" not in gene_presence.columns:
            # Single trace over every gene.
            ranked = gene_presence.sort_values("Percentage", ascending=False)
            fig.add_trace(go.Bar(
                x=ranked["Gene"],
                y=ranked["Percentage"],
                text=[f"{pct:.1f}%" for pct in ranked["Percentage"]],
                hovertext=[
                    f"{gene}: {pct:.1f}% ({count} assemblies)"
                    for gene, pct, count in zip(ranked["Gene"], ranked["Percentage"], ranked["Presence"])
                ],
                hoverinfo="text",
                marker_color="blue",
                textposition="outside",
            ))
            fig.update_layout(
                title="Gene Presence Across Assemblies",
                xaxis_title="Gene",
                yaxis_title="Prevalence",
                xaxis=dict(tickangle=-45),
                height=600,
            )
        else:
            palette = px.colors.qualitative.Set2
            by_class = gene_presence.groupby("RESISTANCE")
            classes = list(by_class.groups.keys())

            # One bar trace per class; only the first starts visible.
            for idx, res_class in enumerate(classes):
                subset = by_class.get_group(res_class).sort_values("Percentage", ascending=False)
                fig.add_trace(go.Bar(
                    x=subset["Gene"],
                    y=subset["Percentage"],
                    text=[f"{pct:.1f}%" for pct in subset["Percentage"]],
                    hovertext=[
                        f"{gene}: {pct:.1f}% ({count} assemblies)"
                        for gene, pct, count in zip(subset["Gene"], subset["Percentage"], subset["Presence"])
                    ],
                    hoverinfo="text",
                    name=res_class,
                    marker_color=palette[idx % len(palette)],
                    visible=(idx == 0),
                    textposition="outside",
                ))

            # Button k shows trace k only.
            buttons = [
                dict(
                    label=res_class,
                    method="update",
                    args=[
                        {"visible": [other == idx for other in range(len(classes))]},
                        {"title": f"Gene Presence - {res_class}"},
                    ],
                )
                for idx, res_class in enumerate(classes)
            ]

            fig.update_layout(
                title=f"Gene Presence - {classes[0]}",
                xaxis_title="Gene",
                yaxis_title="Prevalence",
                xaxis=dict(tickangle=-45),
                height=600,
                updatemenus=[
                    dict(
                        buttons=buttons,
                        direction="down",
                        showactive=True,
                        x=0.1,
                        xanchor="left",
                        y=1.15,
                        yanchor="top",
                    )
                ],
                annotations=[
                    dict(text="Select Resistance:", showarrow=False,
                         x=0.02, y=1.18, xref="paper", yref="paper")
                ],
            )

        html_path = os.path.join(output_dir, "Resistance_gene_percentage.html")
        fig.write_html(html_path)
        logging.info(f"Interactive percentage bar plot saved to {html_path}")

        return fig

    except Exception as err:
        logging.error(f"Error generating percentage bar plot (Plotly): {err}")
        raise


def generate_gene_identity_boxplot(tidy_file, output_dir, fig_format, genes_per_subplot=100):
    """Save a static boxplot of per-gene identity scores split over stacked panels.

    Rows with ``Identity <= 0`` are dropped. Genes are taken in order of first
    appearance and drawn ``genes_per_subplot`` per panel. The figure is saved
    as ``<output_dir>/Resistance_gene_identity_boxplot.<fig_format>``.

    Args:
        tidy_file (str): CSV with ``Gene`` and ``Identity`` columns.
        output_dir (str): Directory the figure is written to.
        fig_format (str): Matplotlib output format (e.g. ``"png"``).
        genes_per_subplot (int): Maximum number of genes per panel (>= 1).

    Raises:
        ValueError: If ``genes_per_subplot`` is smaller than 1.
        Exception: Any other error is logged and re-raised.
    """
    try:
        if genes_per_subplot < 1:
            raise ValueError("genes_per_subplot must be at least 1")

        df = pd.read_csv(tidy_file)
        df_filtered = df[df["Identity"] > 0]

        unique_genes = df_filtered["Gene"].unique()
        total_genes = len(unique_genes)
        if total_genes == 0:
            # plt.subplots(0, 1) raises; there is simply nothing to plot.
            logging.warning(f"No genes with Identity > 0 in {tidy_file}; skipping identity boxplot.")
            return

        n_subplots = (total_genes + genes_per_subplot - 1) // genes_per_subplot
        # squeeze=False gives a 2-D array even for a single panel.
        fig, axes = plt.subplots(n_subplots, 1, figsize=(15, 6 * n_subplots), squeeze=False)

        for i, ax in enumerate(axes[:, 0]):
            start_idx = i * genes_per_subplot
            end_idx = min(start_idx + genes_per_subplot, total_genes)
            genes_subset = list(unique_genes[start_idx:end_idx])
            df_subset = df_filtered[df_filtered["Gene"].isin(genes_subset)]

            # Explicit order keeps panel contents aligned with the title range.
            sns.boxplot(x="Gene", y="Identity", data=df_subset, order=genes_subset, palette="Set3", ax=ax)
            ax.tick_params(axis="x", labelrotation=90)
            ax.set_title(f"Gene Identity Distribution (Genes {start_idx+1}-{end_idx})")
            ax.grid(True, alpha=0.3)

        fig.tight_layout()

        plot_file = os.path.join(output_dir, f"Resistance_gene_identity_boxplot.{fig_format}")
        fig.savefig(plot_file, dpi=300, format=fig_format)
        plt.close(fig)
        logging.info(f"Multi-subplot boxplot saved to {plot_file}")

    except Exception as e:
        logging.error(f"Error generating gene identity boxplot: {e}")
        raise

def generate_gene_identity_boxplot_plotly(tidy_file, output_dir, genes_per_subplot=50):
    """Write an interactive HTML boxplot of gene identity scores.

    Rows with ``Identity <= 0`` are dropped. Up to 100 genes go into a single
    ``px.box`` figure. Larger gene sets are split across stacked subplots of
    ``genes_per_subplot`` genes each. Every box gets an ``n=<count>``
    annotation placed one unit below the lowest identity value of its
    figure/panel. Output: ``<output_dir>/Resistance_gene_identity_boxplot.html``.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        df = pd.read_csv(tidy_file)
        present = df[df["Identity"] > 0]
        genes = present["Gene"].unique()
        n_genes = len(genes)
        html_path = os.path.join(output_dir, "Resistance_gene_identity_boxplot.html")

        # Shared styling for the per-box "n=<count>" labels.
        label_style = dict(
            showarrow=False,
            font=dict(size=9, color="black"),
            bgcolor="rgba(255,255,255,0.9)",
            bordercolor="gray",
            borderwidth=1,
        )

        if n_genes <= 100:
            # Single interactive plot.
            fig = px.box(present, x="Gene", y="Identity",
                         title="Gene Identity Distribution (Interactive)",
                         color_discrete_sequence=px.colors.qualitative.Set3)
            counts = present.groupby("Gene")["Identity"].count()
            floor = present["Identity"].min() - 1
            for position, gene in enumerate(genes):
                fig.add_annotation(x=position, y=floor, text=f"n={counts[gene]}", **label_style)

            fig.update_xaxes(tickangle=-45, tickfont=dict(style="italic"))
            fig.update_layout(height=600, width=max(800, n_genes * 15), margin=dict(b=80))

            fig.write_html(html_path)
            logging.info(f"Interactive boxplot saved to {html_path}")
            return

        # Multiple stacked panels for large gene sets.
        n_panels = (n_genes + genes_per_subplot - 1) // genes_per_subplot
        bounds = [
            (k * genes_per_subplot, min((k + 1) * genes_per_subplot, n_genes))
            for k in range(n_panels)
        ]
        fig = make_subplots(rows=n_panels, cols=1,
                            subplot_titles=[f"Genes {lo + 1}-{hi}" for lo, hi in bounds])

        for row, (lo, hi) in enumerate(bounds, start=1):
            chunk_genes = genes[lo:hi]
            chunk = present[present["Gene"].isin(chunk_genes)]
            counts = chunk.groupby("Gene")["Identity"].count()
            floor = chunk["Identity"].min() - 1
            for position, gene in enumerate(chunk_genes):
                fig.add_trace(
                    go.Box(y=chunk[chunk["Gene"] == gene]["Identity"], name=gene, showlegend=False),
                    row=row, col=1,
                )
                fig.add_annotation(x=position, y=floor, text=f"n={counts[gene]}",
                                   row=row, col=1, **label_style)

        for row in range(1, n_panels + 1):
            fig.update_xaxes(tickangle=-45, tickfont=dict(style="italic"), row=row, col=1)

        fig.update_layout(height=400 * n_panels, title="Gene Identity Distribution (Multi-panel)",
                          margin=dict(b=80))

        fig.write_html(html_path)
        logging.info(f"Multi-panel interactive boxplot saved to {html_path}")

    except Exception as err:
        logging.error(f"Error generating interactive gene identity boxplot: {err}")
        raise

def generate_mean_arg_lollipop(tidy_file, output_dir, fig_format, group_by="Country"):
    """Plot the mean NUM_FOUND per assembly for each ``group_by`` value.

    Each assembly contributes once (its largest NUM_FOUND across rows). The
    mean over assemblies is then taken per group. Y labels read
    ``"<group> (<number of assemblies>)"``. The plot is saved as
    ``<output_dir>/Mean_ARG_by_<group_by>.<fig_format>``.

    Args:
        tidy_file (str): CSV with ``group_by``, ``Assembly Accession`` and
            ``NUM_FOUND`` columns.
        output_dir (str): Directory the figure is written to.
        fig_format (str): Matplotlib output format (e.g. ``"png"``).
        group_by (str): Column used to group assemblies (default ``"Country"``).

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        df = pd.read_csv(tidy_file)
        # Non-numeric NUM_FOUND values become NaN and are ignored by max/mean/count.
        df["NUM_FOUND"] = pd.to_numeric(df["NUM_FOUND"], errors="coerce")

        # One value per (group, assembly) so each assembly is counted once.
        df_grouped = df.groupby([group_by, "Assembly Accession"])["NUM_FOUND"].max().reset_index()

        summary = df_grouped.groupby(group_by).agg(
            Mean_NUM_FOUND=("NUM_FOUND", "mean"),
            Count=("NUM_FOUND", "count"),
        ).reset_index()

        # Build labels column-wise: a row-wise apply upcasts the whole row to
        # float when group_by is numeric, producing labels like "2020.0 (3.0)".
        summary["Label"] = summary[group_by].astype(str) + " (" + summary["Count"].astype(str) + ")"

        summary = summary.sort_values("Mean_NUM_FOUND", ascending=False)

        plt.figure(figsize=(10, max(5, len(summary) * 0.4)))
        plt.hlines(y=summary["Label"], xmin=0, xmax=summary["Mean_NUM_FOUND"], color="gray", alpha=0.7)
        plt.plot(summary["Mean_NUM_FOUND"], summary["Label"], "o", color="firebrick")

        plt.xlabel(f"Mean ARGs per {group_by}")
        plt.ylabel("")
        plt.title(f"Mean ARGs by {group_by}")
        plt.tight_layout()

        plot_file = os.path.join(output_dir, f"Mean_ARG_by_{group_by}.{fig_format}")
        plt.savefig(plot_file, dpi=300, format=fig_format)
        plt.close()
        logging.info(f"Lollipop plot with counts saved to {plot_file}")

    except Exception as e:
        logging.error(f"Error generating lollipop plot with counts for {group_by}: {e}")
        raise

def generate_resistance_barplot(tidy_file, output_dir, fig_format):
    """Generate bar plots of the share of assemblies carrying each resistance class.

    For every RESISTANCE value, the percentage is the number of distinct
    assemblies with ``Presence == 1`` for that class divided by the number of
    distinct assemblies in the file. Two outputs are written to ``output_dir``:
    a static seaborn figure ``Resistance_percentage_by_Antibiotics.<fig_format>``
    and an interactive ``Resistance_percentage_by_Antibiotics.html``.

    Args:
        tidy_file (str): Path to a tidy summary CSV with ``RESISTANCE``,
            ``Assembly Accession`` and ``Presence`` columns.
        output_dir (str): Directory the figures are written to.
        fig_format (str): Matplotlib output format for the static figure.

    Returns:
        plotly.graph_objects.Figure | None: The interactive figure, or ``None``
        when no resistance gene is present (a warning is logged and nothing is
        written).

    Raises:
        Exception: Any error is logged and re-raised.
    """
    mpl_fig = None
    try:
        df = pd.read_csv(tidy_file)

        # Distinct assemblies per resistance class among rows marked present.
        resistance_data = (
            df[df["Presence"] == 1][["RESISTANCE", "Assembly Accession"]]
            .drop_duplicates()
            .groupby("RESISTANCE")
            .agg(count=("Assembly Accession", "size"))
            .reset_index()
        )
        if resistance_data.empty:
            logging.warning("No resistance genes marked present; skipping resistance barplot.")
            return None

        total_assemblies = df["Assembly Accession"].nunique()
        resistance_data["percentage"] = (resistance_data["count"] / total_assemblies) * 100
        resistance_data = resistance_data.sort_values("percentage", ascending=True)

        def wrap_name(name, line_sep):
            # At most 3 lines of up to 30 characters; textwrap truncates the rest.
            return line_sep.join(textwrap.wrap(str(name), width=30, max_lines=3))

        # === Static matplotlib / seaborn version ===
        resistance_data["wrapped_name"] = resistance_data["RESISTANCE"].apply(lambda n: wrap_name(n, "\n"))
        mpl_fig, ax = plt.subplots(figsize=(max(10, len(resistance_data) * 0.6), 8))
        sns.barplot(
            x="wrapped_name",
            y="percentage",
            data=resistance_data,
            palette="Set3",
            order=resistance_data["wrapped_name"],  # keep the percentage sort order
            ax=ax,
        )

        # Percentage label just above each bar.
        for patch in ax.patches:
            bar_height = patch.get_height()
            ax.annotate(
                f"{bar_height:.1f}%",
                (patch.get_x() + patch.get_width() / 2.0, bar_height),
                ha="center",
                va="center",
                xytext=(0, 5),
                textcoords="offset points",
                fontsize=11,
            )

        plt.setp(ax.get_xticklabels(), rotation=90, ha="center")
        ax.set_xlabel("Type of Antibiotics")
        ax.set_ylabel("Resistance (%)")
        ax.set_title("Distribution of Resistance Across the Genomes")
        mpl_fig.tight_layout()
        mpl_fig.subplots_adjust(bottom=0.3)  # extra room for wrapped labels

        plot_file = os.path.join(output_dir, f"Resistance_percentage_by_Antibiotics.{fig_format}")
        mpl_fig.savefig(plot_file, dpi=300, format=fig_format, bbox_inches="tight")
        plt.close(mpl_fig)
        mpl_fig = None
        logging.info(f"Resistance barplot (matplotlib) saved to {plot_file}")

        # === Interactive Plotly version ===
        resistance_data["wrapped_name"] = resistance_data["RESISTANCE"].apply(lambda n: wrap_name(n, "<br>"))

        palette = px.colors.qualitative.Set3
        n_bars = len(resistance_data)
        # Cycle the Set3 palette when there are more bars than colours.
        colors = [palette[i % len(palette)] for i in range(n_bars)]

        fig_plotly = go.Figure()
        fig_plotly.add_trace(go.Bar(
            x=resistance_data["wrapped_name"],
            y=resistance_data["percentage"],
            text=[f"{val:.1f}%" for val in resistance_data["percentage"]],
            textposition="outside",
            textfont=dict(size=12),
            marker=dict(
                color=colors,
                line=dict(color="rgba(0,0,0,0.3)", width=1),
            ),
            hovertemplate="<b>%{x}</b><br>Resistance: %{y:.1f}%<br>Count: %{customdata}<extra></extra>",
            customdata=resistance_data["count"],
            name="Resistance Percentage",
        ))

        # Width grows with the number of bars; the height is fixed.
        width = max(800, n_bars * 60)
        height = 600

        fig_plotly.update_layout(
            title=dict(
                text="Distribution of Resistance Across the Genomes",
                x=0.5,
                font=dict(size=16),
            ),
            xaxis=dict(
                title="Type of Antibiotics",
                tickangle=-45,
                tickfont=dict(size=10),
                title_font=dict(size=14),
            ),
            yaxis=dict(
                title="Resistance (%)",
                title_font=dict(size=14),
                tickfont=dict(size=12),
            ),
            width=width,
            height=height,
            margin=dict(l=80, r=50, t=80, b=150),  # extra bottom margin for rotated labels
            showlegend=False,
            plot_bgcolor="white",
            paper_bgcolor="white",
            font=dict(family="Arial, sans-serif"),
        )

        fig_plotly.update_xaxes(
            showgrid=False,
            showline=True,
            linewidth=1,
            linecolor="black",
            mirror=True,
        )
        fig_plotly.update_yaxes(
            showgrid=True,
            gridwidth=1,
            gridcolor="lightgray",
            showline=True,
            linewidth=1,
            linecolor="black",
            mirror=True,
            range=[0, resistance_data["percentage"].max() * 1.25],  # headroom for bar labels
        )

        fig_plotly.add_annotation(
            text=f"Total Assemblies: {total_assemblies}<br>Resistance Types: {n_bars}",
            xref="paper", yref="paper",
            x=0.5, y=1.0,
            xanchor="left", yanchor="top",
            showarrow=False,
            font=dict(size=11),
            bgcolor="rgba(255,255,255,0.8)",
            bordercolor="gray",
            borderwidth=1,
        )

        plotly_file = os.path.join(output_dir, "Resistance_percentage_by_Antibiotics.html")
        fig_plotly.write_html(plotly_file)
        logging.info(f"Interactive resistance barplot (plotly) saved to {plotly_file}")

        return fig_plotly

    except Exception as e:
        logging.error(f"Error generating resistance barplot: {e}")
        raise
    finally:
        # Release the matplotlib figure if saving did not complete.
        if mpl_fig is not None:
            plt.close(mpl_fig)

def generate_mean_arg_lollipop_plotly(tidy_file, output_dir):
    """Generate an interactive lollipop plot of mean NUM_FOUND per group.

    A dropdown switches between the grouping columns "Geographic Location",
    "Collection Date", "Continent" and "Subcontinent". Columns missing from the
    input, or yielding no data, are skipped with a warning. Each assembly
    contributes once per group using its maximum NUM_FOUND. Each y-axis label
    reads ``"<group> (<number of assemblies>)"``.

    Args:
        tidy_file (str): Path to a tidy summary CSV with ``Assembly Accession``,
            ``NUM_FOUND`` and the grouping columns.
        output_dir (str): Directory where
            ``Mean_Frequency_Antibiotic_Resistance_genes.html`` is written.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        group_by_options = ["Geographic Location", "Collection Date", "Continent", "Subcontinent"]
        df = pd.read_csv(tidy_file)

        # Per-grouping summaries, in dropdown order.
        summaries = {}
        for group_by in group_by_options:
            if group_by not in df.columns:
                logging.warning(f"Column '{group_by}' not found in data. Skipping.")
                continue

            per_sample = df.groupby([group_by, "Assembly Accession"])["NUM_FOUND"].max().reset_index()
            summary = per_sample.groupby(group_by).agg(
                Mean_NUM_FOUND=("NUM_FOUND", "mean"),
                Count=("NUM_FOUND", "count"),
            ).reset_index()
            if summary.empty:
                logging.warning(f"No data for grouping '{group_by}'. Skipping.")
                continue

            # Column-wise labels avoid the float upcast of a row-wise apply
            # (e.g. "2019.0 (12.0)" for numeric years).
            summary["Label"] = (
                summary[group_by].astype(str) + " (" + summary["Count"].astype(str) + ")"
            )
            summaries[group_by] = summary.sort_values("Mean_NUM_FOUND", ascending=True)

        if not summaries:
            logging.warning("No data available for the interactive mean ARG plot. Skipping.")
            return

        fig = go.Figure()
        n_traces = 2 * len(summaries)  # a stem trace and a dot trace per grouping
        buttons = []

        for i, (group_by, summary) in enumerate(summaries.items()):
            labels = summary["Label"].tolist()
            means = summary["Mean_NUM_FOUND"].tolist()

            # One 0 -> mean segment per category; None breaks the line so
            # consecutive categories are not joined into a zig-zag.
            stem_x, stem_y = [], []
            for label, mean in zip(labels, means):
                stem_x.extend([0, mean, None])
                stem_y.extend([label, label, None])

            fig.add_trace(go.Scatter(
                x=stem_x,
                y=stem_y,
                mode="lines",
                line=dict(color="lightgray"),
                hoverinfo="skip",
                showlegend=False,
                visible=(i == 0),
            ))
            fig.add_trace(go.Scatter(
                x=means,
                y=labels,
                mode="markers",
                marker=dict(color="firebrick", size=8),
                name=group_by,
                visible=(i == 0),
            ))

            visibility = [False] * n_traces
            visibility[2 * i] = visibility[2 * i + 1] = True

            buttons.append(dict(
                label=group_by,
                method="update",
                args=[
                    {"visible": visibility},
                    # Dotted ".text" keys: string title shorthands are deprecated in plotly.js.
                    {"title.text": f"Average ARGs by {group_by} with Sample Counts",
                     "xaxis.title.text": f"Mean ARGs per {group_by}"},
                ],
            ))

        # The initial titles describe the grouping that is visible first.
        first_group = next(iter(summaries))
        fig.update_layout(
            template="simple_white",
            updatemenus=[dict(
                buttons=buttons,
                direction="down",
                showactive=True,
                x=0.5,
                xanchor="center",
                y=1.2,
                yanchor="top",
            )],
            title=f"Average ARGs by {first_group} with Sample Counts",
            xaxis_title=f"Mean ARGs per {first_group}",
            yaxis_title="",
            margin=dict(l=160, r=30, t=100, b=40),
            height=800,
            plot_bgcolor="white",
            paper_bgcolor="white",
        )

        output_path = os.path.join(output_dir, "Mean_Frequency_Antibiotic_Resistance_genes.html")
        fig.write_html(output_path)
        logging.info(f"Interactive ARG mean plot saved to {output_path}")

    except Exception as e:
        logging.error(f"Error generating interactive mean ARG plot: {e}")
        raise

def generate_comparison_heatmap(tidy_file, output_dir, fig_format, group_col="Geographic Location", 
                              genep_threshold=10, nseq_threshold=1, resistance_col="RESISTANCE"):
    """Generate a static gene-presence heatmap with gene and resistance-code labels.

    Each cell holds the percentage of assemblies in a group that carry a gene:
    the sum of ``Presence`` divided by the number of distinct assemblies in the
    group, rounded to whole percent. An extra "Overall" row is computed across
    all retained assemblies.

    Filtering:
        * groups with fewer than ``nseq_threshold`` distinct assemblies are
          dropped;
        * a gene is kept when it reaches ``genep_threshold`` percent in at
          least one group, and kept genes show their actual percentage in every
          group.

    Gene columns are ordered by resistance class, then gene name. Every column
    carries its gene name and a resistance code. Codes are A1, A2, ... in
    alphabetical order of resistance class and are explained in a text legend
    below the plot.

    Args:
        tidy_file (str): Path to a ``*_tidy_summary.csv`` file.
        output_dir (str): Directory the figure is written to.
        fig_format (str): Matplotlib output format.
        group_col (str): Column used to group assemblies (heatmap rows).
        genep_threshold (float): Minimum per-group percentage for a gene to be shown.
        nseq_threshold (int): Minimum number of assemblies for a group to be shown.
        resistance_col (str): Column holding each gene's resistance class.

    The figure is written to ``<base>_<group_col>_heatmap.<fig_format>``, where
    ``<base>`` is the tidy file name without ``_tidy_summary.csv``. If filtering
    leaves nothing to plot, a warning is logged and nothing is written.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    fig = None
    try:
        base_name = os.path.basename(tidy_file).replace("_tidy_summary.csv", "")
        df = pd.read_csv(tidy_file)

        # Standardize column names
        df = df.rename(columns={group_col: "Group", resistance_col: "Resistance"})

        # Short codes (A1, A2, ...) for resistance classes, alphabetical. NaN
        # classes are dropped first because sorting a mix of str and NaN raises.
        resistance_map = df[["Gene", "Resistance"]].dropna().drop_duplicates()
        resistance_categories = sorted(resistance_map["Resistance"].unique())
        resistance_dict = {cat: f"A{i + 1}" for i, cat in enumerate(resistance_categories)}
        resistance_map["Resistance_Code"] = resistance_map["Resistance"].map(resistance_dict)

        df = pd.merge(df, resistance_map, on=["Gene", "Resistance"])
        df = df[["Assembly Accession", "Group", "Gene", "Resistance", "Resistance_Code", "Presence"]].dropna()
        df = df.drop_duplicates()

        # Keep groups with enough assemblies.
        group_counts = df.groupby("Group")["Assembly Accession"].nunique()
        selected_groups = group_counts[group_counts >= nseq_threshold].index
        filtered_df = df[df["Group"].isin(selected_groups)]

        if filtered_df.empty:
            logging.warning(f"No data for {base_name} - {group_col} after filtering. Skipping.")
            return

        # Percentage of assemblies in each group carrying each gene.
        grouped_df = filtered_df.groupby(
            ["Group", "Gene", "Resistance", "Resistance_Code"], as_index=False
        )["Presence"].sum()
        group_sizes = filtered_df.groupby("Group")["Assembly Accession"].nunique()
        grouped_df["Percentage"] = grouped_df["Presence"] / grouped_df["Group"].map(group_sizes) * 100

        # Keep genes reaching the threshold in at least one group.
        selected_genes = grouped_df.loc[grouped_df["Percentage"] >= genep_threshold, "Gene"].unique()
        if len(selected_genes) == 0:
            logging.warning(f"No data for {base_name} - {group_col} after filtering genes. Skipping.")
            return
        gene_df = grouped_df[grouped_df["Gene"].isin(selected_genes)]

        # Column order (resistance class, then gene) shared by data and labels,
        # so every tick label describes the column it sits under.
        gene_info = gene_df.sort_values(["Resistance", "Gene"]).drop_duplicates("Gene")
        gene_order = gene_info["Gene"].tolist()
        resistance_labels = gene_info["Resistance_Code"].tolist()

        heatmap_data = (
            gene_df.pivot_table(index="Group", columns="Gene", values="Percentage", aggfunc="max")
            .reindex(index=group_sizes.index, columns=gene_order)  # every retained group, fixed gene order
            .fillna(0)
            .round(0)
            .astype(int)
        )

        # Overall row: share of all retained assemblies carrying each gene,
        # counted across every group (not only groups above the threshold).
        total_assembly_count = filtered_df["Assembly Accession"].nunique()
        overall = gene_df.groupby("Gene")["Presence"].sum().reindex(gene_order).fillna(0) / total_assembly_count * 100
        heatmap_data = pd.concat([heatmap_data, overall.to_frame("Overall").T])

        # Annotate each group with its number of assemblies.
        heatmap_data.index = [
            f"{group} ({group_counts.get(group, '')})" if group != "Overall" else "Overall"
            for group in heatmap_data.index
        ]

        num_genes = heatmap_data.shape[1]
        num_groups = heatmap_data.shape[0]
        fig_width = max(12, num_genes * 0.4)
        fig_height = max(6, num_groups * 0.5)

        # Scoped styling: sns.set() would reset the style to "darkgrid"
        # (overriding "white") and change rcParams for every later plot.
        with sns.axes_style("white"), sns.plotting_context("notebook", font_scale=1.2):
            fig, ax = plt.subplots(figsize=(fig_width, fig_height))
            sns.heatmap(
                heatmap_data,
                annot=False,  # don't label cell values
                fmt=".0f",
                cmap="Reds",
                linewidths=0.5,
                cbar_kws={"shrink": 0.8},
                vmin=0,
                vmax=100,
                ax=ax,
            )

            # Heatmap cell i spans [i, i + 1]; put ticks at the cell centres.
            tick_positions = np.arange(num_genes) + 0.5

            # Gene names below the heatmap, pushed down by the tick pad.
            ax.set_xticks(tick_positions)
            ax.set_xticklabels(gene_order, rotation=90, ha="center", va="top")
            ax.tick_params(axis="x", pad=25, which="both", bottom=False)

            # Resistance codes on a secondary bottom axis. With va="bottom"
            # and the same pad they sit between the heatmap and the gene names.
            secax = ax.secondary_xaxis("bottom")
            secax.set_xticks(tick_positions)
            secax.set_xticklabels(resistance_labels, rotation=90, ha="center", va="bottom")
            secax.tick_params(axis="x", pad=25, length=0)

            ax.set_ylabel(f"{group_col} (Total Submitted Sequence)", fontsize=14)
            ax.tick_params(axis="y", labelrotation=0)

            # Legend text in 3 columns, padded to equal length.
            legend_entries = [f"{code}: {cat}" for cat, code in resistance_dict.items()]
            num_cols = 3
            legend_columns = [legend_entries[i::num_cols] for i in range(num_cols)]
            n_rows = max(len(col) for col in legend_columns)
            legend_columns = [col + [""] * (n_rows - len(col)) for col in legend_columns]
            legend_text_lines = [
                f"{col1:<55} {col2:<55} {col3}" for col1, col2, col3 in zip(*legend_columns)
            ]
            legend_text = "Resistance Categories:\n\n" + "\n".join(legend_text_lines)

            # Legend below the plot, left-aligned; bbox_inches="tight" keeps it in the saved file.
            fig.text(0.01, -0.01, legend_text,
                     ha="left", va="top", fontsize=14, family="monospace",
                     bbox={"facecolor": "white", "alpha": 0.9, "pad": 3})
            fig.subplots_adjust(bottom=0.15)

            output_file = os.path.join(output_dir, f"{base_name}_{group_col.replace(' ', '_')}_heatmap.{fig_format}")
            fig.savefig(output_file, format=fig_format, dpi=300, bbox_inches="tight")

        logging.info(f"Heatmap saved to {output_file}")

    except Exception as e:
        logging.error(f"Error generating heatmap for {group_col}: {e}")
        raise
    finally:
        if fig is not None:
            plt.close(fig)

def generate_comparison_heatmap_plotly(tidy_file, output_dir, resistance_col="RESISTANCE"):
    """Generate an interactive gene-presence heatmap with switchable grouping.

    A matrix is built for each grouping column present in the data, from
    "Geographic Location", "Subcontinent", "Continent" and "Collection Date".
    Each cell is the percentage of assemblies in a group carrying a gene: the
    sum of ``Presence`` divided by the distinct assemblies in the group. An
    extra "Overall" row is computed across all assemblies. Buttons switch
    between groupings.

    Gene columns are ordered by resistance class, then gene name. Each column
    is annotated with its resistance code: A1, A2, ... in alphabetical order of
    resistance class, explained in a legend line.

    Args:
        tidy_file (str): Path to a ``*_tidy_summary.csv`` file.
        output_dir (str): Directory where
            ``Resistance_gene_distribution_heatmap.html`` is written.
        resistance_col (str): Column holding each gene's resistance class.

    Returns:
        plotly.graph_objects.Figure | None: The figure, or ``None`` when no
        grouping column yields data (an error is logged and nothing is written).

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        base_name = os.path.basename(tidy_file).replace("_tidy_summary.csv", "")
        df = pd.read_csv(tidy_file)

        group_options = ["Geographic Location", "Subcontinent", "Continent", "Collection Date"]

        # Resistance class -> short code. This is independent of the grouping,
        # so build it once. NaN classes are dropped because sorting a mix of
        # str and NaN raises.
        resistance_map = (
            df[["Gene", resistance_col]]
            .rename(columns={resistance_col: "Resistance"})
            .dropna()
            .drop_duplicates()
        )
        resistance_dict = {
            cat: f"A{i + 1}" for i, cat in enumerate(sorted(resistance_map["Resistance"].unique()))
        }
        resistance_map["Resistance_Code"] = resistance_map["Resistance"].map(resistance_dict)
        legend_text = "Resistance Categories: " + " | ".join(
            f"{code}: {cat}" for cat, code in resistance_dict.items()
        )

        def build_heatmap(group_col):
            """Return (matrix, codes aligned to matrix columns) for one grouping, or None if empty."""
            data = df.rename(columns={group_col: "Group", resistance_col: "Resistance"})
            data = pd.merge(data, resistance_map, on=["Gene", "Resistance"])
            data = data[["Assembly Accession", "Group", "Gene", "Resistance", "Resistance_Code", "Presence"]].dropna()
            data = data.drop_duplicates()
            if data.empty:
                return None

            grouped = data.groupby(["Group", "Gene", "Resistance", "Resistance_Code"], as_index=False)["Presence"].sum()
            group_sizes = data.groupby("Group")["Assembly Accession"].nunique()
            grouped["Percentage"] = grouped["Presence"] / grouped["Group"].map(group_sizes) * 100

            # Column order shared by data and code annotations.
            gene_info = grouped.sort_values(["Resistance", "Gene"]).drop_duplicates("Gene")
            gene_order = gene_info["Gene"].tolist()
            gene_codes = dict(zip(gene_info["Gene"], gene_info["Resistance_Code"]))

            # Missing (group, gene) combinations become 0%.
            matrix = (
                grouped.pivot_table(index="Group", columns="Gene", values="Percentage", aggfunc="max")
                .reindex(index=group_sizes.index, columns=gene_order)
                .fillna(0)
                .round(0)
                .astype(int)
            )

            # Overall row across every assembly of this grouping.
            total_assemblies = data["Assembly Accession"].nunique()
            overall = grouped.groupby("Gene")["Presence"].sum().reindex(gene_order).fillna(0) / total_assemblies * 100
            matrix = pd.concat([matrix, overall.to_frame("Overall").T])

            matrix.index = [
                f"{group} ({group_sizes.get(group, '')})" if group != "Overall" else "Overall"
                for group in matrix.index
            ]
            return matrix, [gene_codes[gene] for gene in matrix.columns]

        def figure_size(matrix):
            """Width/height in px: 40 px per gene column, 30 px per group row, with minimums."""
            return max(800, matrix.shape[1] * 40), max(600, matrix.shape[0] * 30)

        def layout_annotations(matrix, codes):
            """Static header/legend annotations plus one resistance code under each gene column."""
            annotations = [
                dict(
                    text="Select Grouping:",
                    showarrow=False,
                    x=0.02, y=1.25,
                    xref="paper", yref="paper",
                    font=dict(size=12),
                ),
                dict(
                    text="<b>Gene Names</b>",
                    x=0.5, y=1.02,
                    xref="paper", yref="paper",
                    showarrow=False,
                    font=dict(size=12),
                ),
                dict(
                    text=legend_text,
                    x=0.5, y=-0.18,
                    xref="paper", yref="paper",
                    showarrow=False,
                    font=dict(size=12),
                    bgcolor="rgba(255,255,255,0.8)",
                ),
            ]
            annotations.extend(
                dict(
                    x=gene,
                    y=-0.08,  # just below the heatmap
                    xref="x",
                    yref="paper",
                    text=f"<i>{code}</i>",
                    showarrow=False,
                    font=dict(size=10, color="blue"),
                    align="center",
                )
                for gene, code in zip(matrix.columns, codes)
            )
            return annotations

        groupings = {}
        for group_col in group_options:
            if group_col not in df.columns:
                logging.warning(f"Column '{group_col}' not found in data. Skipping.")
                continue
            built = build_heatmap(group_col)
            if built is None:
                logging.warning(f"No data for {base_name} - {group_col}. Skipping.")
                continue
            groupings[group_col] = built

        if not groupings:
            logging.error("No valid data for any grouping option")
            return None

        initial_group = next(iter(groupings))
        initial_matrix, initial_codes = groupings[initial_group]
        width, height = figure_size(initial_matrix)

        fig = go.Figure()
        fig.add_trace(
            go.Heatmap(
                z=initial_matrix.values,
                x=initial_matrix.columns,
                y=initial_matrix.index,
                colorscale="Reds",
                zmin=0,
                zmax=100,
                showscale=True,
                text=initial_matrix.values,
                texttemplate="%{text:.0f}",
                textfont=dict(size=10),
                hovertemplate="<b>%{y}</b><br>Gene: %{x}<br>Percentage: %{z:.1f}%<extra></extra>",
                colorbar=dict(title=dict(text="Percentage")),
            )
        )

        # One button per grouping. Sizes and annotations use the same helpers
        # as the initial view, so re-selecting it reproduces that view.
        buttons = []
        for group_col, (matrix, codes) in groupings.items():
            group_width, group_height = figure_size(matrix)
            buttons.append(
                dict(
                    label=group_col,
                    method="update",
                    args=[
                        {
                            "z": [matrix.values.tolist()],
                            "x": [matrix.columns.tolist()],
                            "y": [matrix.index.tolist()],
                            "text": [matrix.values.tolist()],
                        },
                        {
                            # Dotted ".text" keys: string title shorthands are deprecated in plotly.js.
                            "title.text": f"Gene Presence Heatmap - {group_col}",
                            "width": group_width,
                            "height": group_height,
                            "annotations": layout_annotations(matrix, codes),
                            "yaxis.title.text": f"{group_col} (Total Submitted Sequence)",
                        },
                    ],
                )
            )

        fig.update_layout(
            title=f"Gene Presence Heatmap - {initial_group}",
            xaxis=dict(
                tickangle=-45,
                side="top",
                tickfont=dict(size=10, style="italic"),
            ),
            yaxis=dict(
                title=f"{initial_group} (Total Submitted Sequence)",
                tickfont=dict(size=10),
            ),
            width=width,
            height=height,
            updatemenus=[
                dict(
                    type="buttons",
                    direction="right",
                    buttons=buttons,
                    pad={"r": 10, "t": 10},
                    showactive=True,
                    x=0.1,
                    xanchor="left",
                    y=1.4,
                    yanchor="top",
                )
            ],
            annotations=layout_annotations(initial_matrix, initial_codes),
            margin=dict(l=100, r=100, t=180, b=150),  # space for annotations
            font=dict(size=10),
        )

        output_file = os.path.join(output_dir, "Resistance_gene_distribution_heatmap.html")
        fig.write_html(output_file)
        logging.info(f"Interactive heatmap saved to {output_file}")

        return fig

    except Exception as e:
        logging.error(f"Error generating interactive heatmap: {e}")
        raise

def mean_Arg_resistance_analysis_plotly(tidy_file, output_dir):
    """
    Generate an interactive Plotly box plot of NUM_FOUND per grouping variable, with statistics.

    A dropdown switches the box plot between "Geographic Location", "Subcontinent",
    "Continent" and "Collection Date". For each grouping variable, groups with at
    least 5 observations are compared with a Kruskal-Wallis test (only when more
    than two such groups exist) and pairwise two-sided Mann-Whitney U tests. The
    reports are printed and collected into combined CSV files.

    Args:
        tidy_file (str): Path to a CSV file containing at least the columns
            "Collection Date", "Geographic Location", "Continent", "Subcontinent",
            "Assembly BioSample Accession" and "NUM_FOUND".
        output_dir (str): Output directory (created if missing). It receives
            "Resistance_gene_frequency_boxplot.html" and, when there is data for
            them, "combined_summary_statistics.csv",
            "combined_pairwise_comparisons.csv" and "combined_overall_tests.csv".

    Returns:
        plotly.graph_objects.Figure | None: The box-plot figure, or None when any
        step failed (the error is logged, not raised).
    """
    try:
        required_cols = [
            "Collection Date", "Geographic Location", "Continent", "Subcontinent",
            "Assembly BioSample Accession", "NUM_FOUND",
        ]

        # Validate the schema BEFORE sub-selecting columns; otherwise a missing
        # column surfaces as a bare KeyError and this check can never run.
        df = pd.read_csv(tidy_file)
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns in input file: {missing_cols}")
        # Rows with a missing value in ANY required column are dropped here, so
        # every grouping column used below is free of NaN.
        df = df[required_cols].dropna().drop_duplicates()

        # Available grouping options; the first one is shown initially.
        group_options = ["Geographic Location", "Subcontinent", "Continent", "Collection Date"]
        initial_group = group_options[0]

        def perform_statistical_tests(data, group_col):
            """
            Compare NUM_FOUND between the groups of ``group_col`` (groups with n>=5 only).

            Returns:
                tuple[str, dict]: A human-readable report, and a dict with keys
                'summary_stats' and 'pairwise_tests' (plus 'overall_test' when
                more than two groups qualify). Each key maps to a list of row
                dicts for CSV output. The dict is empty when fewer than two
                groups qualify.
            """
            groups = data[group_col].unique()
            if len(groups) < 2:
                return "Not enough groups for statistical comparison", {}

            # (group, values) pairs for groups with at least 5 observations.
            valid_groups = []
            for group in groups:
                values = data.loc[data[group_col] == group, 'NUM_FOUND'].values
                if len(values) >= 5:
                    valid_groups.append((group, values))

            if len(valid_groups) < 2:
                return "Not enough valid groups (n>=5) for statistical comparison", {}

            results = [f"Statistical Analysis for {group_col}:", "=" * 50]

            # Summary statistics
            results.append("\nSummary Statistics (only groups with n>=5):")
            summary_rows = []
            for group, group_vals in valid_groups:
                median_val = np.median(group_vals)
                mean_val = np.mean(group_vals)
                q1 = np.percentile(group_vals, 25)
                q3 = np.percentile(group_vals, 75)
                std_val = np.std(group_vals)  # population SD (numpy default ddof=0)
                results.append(f"{group}: n={len(group_vals)}, Median={median_val:.2f}, Mean={mean_val:.2f}, IQR=({q1:.2f}-{q3:.2f}), SD={std_val:.2f}")
                summary_rows.append({
                    'Group': group,
                    'Sample_Size': len(group_vals),
                    'Mean': mean_val,
                    'Median': median_val,
                    'Q1': q1,
                    'Q3': q3,
                    'Standard_Deviation': std_val,
                    'Min': np.min(group_vals),
                    'Max': np.max(group_vals),
                })
            csv_results = {'summary_stats': summary_rows}

            # Overall test (Kruskal-Wallis), only when more than two groups qualify
            if len(valid_groups) > 2:
                try:
                    h_stat, p_value = kruskal(*[vals for _, vals in valid_groups])
                    results.append(f"\nKruskal-Wallis Test (overall): H={h_stat:.4f}, p={p_value:.4f}")
                    if p_value < 0.05:
                        results.append("Significant difference detected between groups (p<0.05)")
                    else:
                        results.append("No significant difference between groups (p>=0.05)")
                    overall_test_data = {
                        'Test_Type': 'Kruskal-Wallis',
                        'Statistic': h_stat,
                        'P_Value': p_value,
                        'Significant': p_value < 0.05,
                        'Groups_Compared': f"All {len(valid_groups)} groups",
                    }
                except Exception as e:
                    # e.g. scipy rejects input where all values are identical
                    results.append(f"Error in Kruskal-Wallis test: {str(e)}")
                    overall_test_data = {'Error': str(e)}
                csv_results['overall_test'] = [overall_test_data]

            # Pairwise comparisons (Mann-Whitney U tests), every unordered pair once
            results.append("\nPairwise Comparisons (Mann-Whitney U tests, n>=5 per group):")
            results.append("-" * 40)
            pairwise_rows = []
            for (group1, data1), (group2, data2) in itertools.combinations(valid_groups, 2):
                try:
                    statistic, p_value = mannwhitneyu(data1, data2, alternative='two-sided')
                    significance = "***" if p_value < 0.001 else "**" if p_value < 0.01 else "*" if p_value < 0.05 else "ns"
                    results.append(f"{group1} vs {group2}: U={statistic:.2f}, p={p_value:.4f} {significance}")
                    pairwise_rows.append({
                        'Group_1': group1,
                        'Group_2': group2,
                        'Sample_Size_1': len(data1),
                        'Sample_Size_2': len(data2),
                        'Test_Type': 'Mann-Whitney U',
                        'U_Statistic': statistic,
                        'P_Value': p_value,
                        'Significant': p_value < 0.05,
                        'Significance_Level': significance,
                    })
                except Exception as e:
                    results.append(f"{group1} vs {group2}: Error - {str(e)}")
                    pairwise_rows.append({
                        'Group_1': group1,
                        'Group_2': group2,
                        'Error': str(e),
                    })
            csv_results['pairwise_tests'] = pairwise_rows

            results.append("\nLegend: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant")
            return "\n".join(results), csv_results

        def calculate_plot_width(n_categories):
            """Figure width in px: 600 + 80 per category, capped at 1600."""
            return min(600 + n_categories * 80, 1600)

        def save_combined(rows, filename, description):
            """Write ``rows`` to output_dir/filename with 'Grouping_Variable' as the first column."""
            if not rows:
                return
            table = pd.DataFrame(rows)
            ordered = ['Grouping_Variable'] + [col for col in table.columns if col != 'Grouping_Variable']
            path = os.path.join(output_dir, filename)
            table[ordered].to_csv(path, index=False)
            print(f"{description} saved to: {path}")

        # Box-plot hover text (the same for every trace).
        hover_template = (
            "<b>%{x}</b><br>" +
            "Median: %{median}<br>" +
            "Q1: %{q1}<br>" +
            "Q3: %{q3}<br>" +
            "Mean: %{mean}<br>" +
            "Sample Size (n): %{customdata}<br>" +
            "<extra></extra>"
        )

        # One box trace per grouping option; only the first is initially visible.
        fig = go.Figure()
        for i, group_option in enumerate(group_options):
            # Plotly indexes customdata per data point, so give every observation
            # the size of the group it belongs to (not one value per category).
            group_sizes = df.groupby(group_option)['NUM_FOUND'].transform('count')
            fig.add_trace(go.Box(
                y=df['NUM_FOUND'],
                x=df[group_option],
                name=group_option,
                visible=(i == 0),
                boxpoints='outliers',  # Show outliers
                jitter=0.3,  # Add some jitter to points
                pointpos=0,  # Position of points
                marker=dict(
                    size=4,
                    opacity=0.7
                ),
                line=dict(width=2),
                fillcolor='rgba(56, 128, 139, 0.7)',
                boxmean='sd',  # Show mean and standard deviation
                hovertemplate=hover_template,
                customdata=group_sizes.to_numpy()
            ))

        # Dropdown: each button shows one trace and resizes the figure to its category count.
        dropdown_buttons = []
        for i, group_option in enumerate(group_options):
            n_categories = df[group_option].nunique()
            visibility = [j == i for j in range(len(group_options))]
            dropdown_buttons.append(
                dict(
                    args=[{"visible": visibility},
                          {"title": f"Distribution of Resistance Gene Frequency by {group_option}",
                           "xaxis.title": group_option,
                           "width": calculate_plot_width(n_categories)}],
                    label=f"{group_option} ({n_categories} categories)",
                    method="update"
                )
            )

        initial_width = calculate_plot_width(df[initial_group].nunique())

        fig.update_layout(
            title=f"Distribution of Resistance Gene Frequency by {initial_group}",
            xaxis_title=initial_group,
            yaxis_title="Frequency",
            updatemenus=[
                dict(
                    buttons=dropdown_buttons,
                    direction="down",
                    showactive=True,
                    x=0.02,
                    xanchor="left",
                    y=1.15,
                    yanchor="top",
                    bgcolor="rgba(255, 255, 255, 0.9)",
                    bordercolor="rgba(0, 0, 0, 0.3)",
                    borderwidth=1
                ),
            ],
            annotations=[
                dict(text="Select Grouping Variable:",
                     showarrow=False,
                     x=0.02, y=1.30,
                     xref="paper", yref="paper",
                     align="left",
                     font=dict(size=12))
            ],
            showlegend=False,
            width=initial_width,
            height=900,
            template="plotly_white"
        )
        fig.update_xaxes(tickangle=45)
        fig.update_yaxes(gridcolor='lightgray', gridwidth=0.5)

        # Save the plot
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, "Resistance_gene_frequency_boxplot.html")
        fig.write_html(output_file)

        # Statistical reports for each grouping option
        print("Generating statistical analyses...")
        print("=" * 80)

        collected = {'summary_stats': [], 'pairwise_tests': [], 'overall_test': []}
        for group_option in group_options:
            valid_data = df[df[group_option].notna()]
            if valid_data.empty:
                continue
            stat_results, csv_data = perform_statistical_tests(valid_data, group_option)
            print(f"\n{stat_results}")
            print("\n" + "=" * 80)
            # Tag each row with its grouping variable before pooling.
            for key, bucket in collected.items():
                for row in csv_data.get(key, []):
                    row['Grouping_Variable'] = group_option
                    bucket.append(row)

        save_combined(collected['summary_stats'], "combined_summary_statistics.csv",
                      "Combined summary statistics")
        save_combined(collected['pairwise_tests'], "combined_pairwise_comparisons.csv",
                      "Combined pairwise comparisons")
        save_combined(collected['overall_test'], "combined_overall_tests.csv",
                      "Combined overall test results")

        print(f"\nInteractive box plot saved to: {output_file}")
        print(f"Combined statistical analysis CSV files saved to: {output_dir}")

        return fig

    except Exception as e:
        # Best-effort: callers receive None rather than an exception.
        logging.error(f"Error generating box plot analysis: {str(e)}")
        return None


def correlation_scatterplot_analysis(tidy_file, output_dir, group_col):
    """
    Plot NUM_FOUND against collection year per group and summarise the correlations.

    The data is reduced to the columns "Assembly BioSample Accession",
    "Collection Date", ``group_col`` and "NUM_FOUND". Rows with a missing value
    are dropped and duplicates removed. The year is the first four characters of
    "Collection Date" parsed as an int. Rows with a placeholder date ("absent",
    "none", "") or a non-numeric year prefix are discarded. Only groups with at
    least 5 remaining rows are analysed.

    For each group, the Pearson and Spearman correlations are computed. A scatter
    trace and a least-squares trend line (``np.polyfit`` degree 1) are added to a
    Plotly figure whose dropdown switches between groups.

    Args:
        tidy_file (str): Path to the tidy CSV file.
        output_dir (str): Directory (created if missing) that receives
            "<group_col>_correlation_summary.csv" and
            "<group_col>_correlation_plot.html", with spaces in ``group_col``
            replaced by underscores.
        group_col (str): Column to group by, e.g. "Continent" or "Geographic Location".

    Returns:
        plotly.graph_objects.Figure: The interactive figure.

    Raises:
        ValueError: If no group has at least 5 usable rows.
        KeyError: If one of the selected columns is missing from ``tidy_file``.
    """
    try:
        df = pd.read_csv(tidy_file)
        df = df[["Assembly BioSample Accession", "Collection Date", group_col, "NUM_FOUND"]].dropna().drop_duplicates()

        # Drop placeholder collection dates (copy so the column assignments below
        # operate on an independent frame).
        df = df[~df["Collection Date"].isin(["absent", "none", "", None])].copy()

        def parse_year(x):
            """Return the first four characters of ``x`` as an int, or NaN if they are not numeric."""
            try:
                return int(str(x)[:4])
            except ValueError:
                return np.nan

        df["Year"] = df["Collection Date"].apply(parse_year)
        df = df.dropna(subset=["Year"])
        df["Year"] = df["Year"].astype(int)

        # Groups (in sorted order) with at least 5 rows.
        valid_groups = {}
        for val in sorted(df[group_col].dropna().unique()):
            sub_df = df[df[group_col] == val]
            if len(sub_df) >= 5:
                valid_groups[val] = sub_df
        unique_values = list(valid_groups)

        if not unique_values:
            raise ValueError("No groups have sufficient data (minimum 5 samples required)")

        os.makedirs(output_dir, exist_ok=True)

        def make_title(val, rec):
            """Figure title for group ``val`` annotated with its correlation record ``rec``."""
            return (f"NUM_FOUND vs Year - {group_col}: {val}<br>"
                    f"<sub>Pearson r={rec['pearson_r']:.3f} (p={rec['pearson_p']:.3f}), "
                    f"Spearman ρ={rec['spearman_r']:.3f} (p={rec['spearman_p']:.3f})</sub>")

        records = []
        fig = go.Figure()

        for i, val in enumerate(unique_values):
            sub_df = valid_groups[val]
            x = sub_df["Year"].values
            y = sub_df["NUM_FOUND"].values

            pearson_corr, pearson_p = pearsonr(x, y)
            spearman_corr, spearman_p = spearmanr(x, y)

            records.append({
                group_col: val,
                "n_samples": len(sub_df),
                "pearson_r": pearson_corr,
                "pearson_p": pearson_p,
                "spearman_r": spearman_corr,
                "spearman_p": spearman_p
            })

            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                mode="markers",
                name=val,
                visible=(i == 0),
                marker=dict(size=8, opacity=0.6),
                text=sub_df["Assembly BioSample Accession"],
                hovertemplate=(
                    f"<b>{group_col}</b>: {val}<br>"
                    "Year: %{x}<br>"
                    "NUM_FOUND: %{y}<br>"
                    "Sample: %{text}<extra></extra>"
                )
            ))

            # Least-squares trend line, evaluated on the sorted distinct years so
            # the line trace runs left to right. Traces alternate scatter/line;
            # the dropdown visibility arrays below rely on that order.
            slope, intercept = np.polyfit(x, y, 1)
            x_line = np.unique(x)
            fig.add_trace(go.Scatter(
                x=x_line,
                y=slope * x_line + intercept,
                mode="lines",
                name=f"{val} Trend",
                line=dict(dash="dash", color="black"),
                visible=(i == 0),
                showlegend=False
            ))

        corr_df = pd.DataFrame(records)
        corr_csv = os.path.join(output_dir, f"{group_col.replace(' ', '_')}_correlation_summary.csv")
        corr_df.to_csv(corr_csv, index=False)
        logging.info(f"Saved correlation summary to {corr_csv}")

        # Each button shows the scatter (index 2*i) and trend line (2*i + 1) of one group.
        dropdown_buttons = []
        for i, (val, rec) in enumerate(zip(unique_values, records)):
            visibility = [idx // 2 == i for idx in range(2 * len(unique_values))]
            dropdown_buttons.append(dict(
                label=f"{val} (n={rec['n_samples']})",
                method="update",
                args=[
                    {"visible": visibility},
                    {"title": make_title(val, rec)}
                ]
            ))

        fig.update_layout(
            title=make_title(unique_values[0], records[0]),
            xaxis_title="Collection Year",
            yaxis_title="Number of Resistance Genes (NUM_FOUND)",
            updatemenus=[dict(
                buttons=dropdown_buttons,
                direction="down",
                showactive=True,
                x=0.02,
                y=0.98,
                xanchor="left",
                yanchor="top",
                bgcolor="rgba(255,255,255,0.8)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1
            )],
            margin=dict(t=120, l=60, r=60, b=60),
            template="plotly_white"
        )

        fig.add_annotation(
            text=f"Select {group_col}:",
            x=0.02,
            y=1.02,
            xref="paper",
            yref="paper",
            showarrow=False,
            font=dict(size=12, color="black")
        )

        html_out = os.path.join(output_dir, f"{group_col.replace(' ', '_')}_correlation_plot.html")
        fig.write_html(html_out)
        logging.info(f"Saved interactive plot to {html_out}")

        print(f"\nCorrelation Analysis Summary for {group_col}:")
        print(f"Total groups processed: {len(unique_values)}")
        for record in records:
            print(f"  {record[group_col]}: {record['n_samples']} samples, "
                  f"Pearson r={record['pearson_r']:.3f}")

        return fig

    except Exception as e:
        logging.error(f"Error in correlation analysis: {e}")
        raise

def combined_correlation_analysis(output_dir):
    """
    Merge the per-level correlation summary CSVs into one file, then delete them.

    Reads "Geographic_Location_correlation_summary.csv",
    "Continent_correlation_summary.csv" and "Subcontinent_correlation_summary.csv"
    from ``output_dir``. These are the names ``correlation_scatterplot_analysis``
    writes for those grouping columns. Missing files are skipped, so the
    summaries that do exist are still combined.

    Each row gets a 'Geographic_Level' column naming its source level, and the
    level's value column is renamed to 'Geographic_Region'. The result is written
    to "combined_geographic_correlation_summary.csv" with the columns
    Geographic_Level, Geographic_Region, n_samples, pearson_r, pearson_p,
    spearman_r and spearman_p. The per-level files are then removed.

    Errors are printed rather than raised (best-effort step).

    Args:
        output_dir (str): Directory holding the per-level summary CSV files.
    """
    # (grouping column / level name, per-level summary file name)
    levels = [
        ("Geographic Location", "Geographic_Location_correlation_summary.csv"),
        ("Continent", "Continent_correlation_summary.csv"),
        ("Subcontinent", "Subcontinent_correlation_summary.csv"),
    ]
    try:
        frames = []
        for level, filename in levels:
            path = os.path.join(output_dir, filename)
            if not os.path.exists(path):
                continue  # reported during the cleanup loop below
            level_df = pd.read_csv(path)
            level_df['Geographic_Level'] = level
            frames.append(level_df.rename(columns={level: 'Geographic_Region'}))

        if not frames:
            print("Error in combined_correlation_analysis: no correlation summary CSV files found")
            return

        combined_df = pd.concat(frames, ignore_index=True)
        combined_df = combined_df[['Geographic_Level', 'Geographic_Region', 'n_samples',
                                   'pearson_r', 'pearson_p', 'spearman_r', 'spearman_p']]

        combined_csv = os.path.join(output_dir, "combined_geographic_correlation_summary.csv")
        combined_df.to_csv(combined_csv, index=False)

        # Remove the separate per-level CSV files
        print("Removing separate correlation...CSV files...")
        for _, filename in levels:
            path = os.path.join(output_dir, filename)
            try:
                if os.path.exists(path):
                    os.remove(path)
                    print(f"  Removed: {filename}")
                else:
                    print(f"  File not found: {filename}")
            except OSError as e:
                print(f"  Error removing {filename}: {e}")

        print("Cleanup completed. Only combined correlation...CSV file remains.")
    except Exception as e:
        print(f"Error in combined_correlation_analysis: {e}")

def generate_geographic_resistance_map_plotly(tidy_file, output_dir):
    """
    Generate an interactive choropleth of per-country gene presence with a Gene | Date dropdown.

    For each (Gene, Collection Date, Geographic Location), the map shows the
    percentage of distinct "Assembly BioSample Accession" values whose summed
    "Presence" is > 0. An extra "Overall" date pools the counts over all
    collection dates. Rows whose location or date is "absent", None or "" are
    ignored. Countries are matched with Plotly's "country names" location mode.

    Args:
        tidy_file (str): CSV with the columns "Geographic Location",
            "Collection Date", "Gene", "Assembly BioSample Accession" and "Presence".
        output_dir (str): Directory (created if missing) that receives
            "Resistance_gene_geographic_distribution.html".

    Returns:
        plotly.graph_objects.Figure: The generated figure.

    Raises:
        ValueError: If required columns are missing or no usable rows remain.
    """
    try:
        df = pd.read_csv(tidy_file)

        # Validate the schema before the filters below touch these columns.
        required_cols = ["Geographic Location", "Collection Date", "Gene", "Assembly BioSample Accession", "Presence"]
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns in input file: {missing_cols}")

        df = df[~df["Geographic Location"].isin(["absent", None, ""])]
        df = df[~df["Collection Date"].isin(["absent", None, ""])]
        no_data_msg = "No usable rows (valid Geographic Location and Collection Date) in input file"
        if df.empty:
            raise ValueError(no_data_msg)

        # Per unique sample: summed presence and number of rows
        grouped = (
            df.groupby(["Gene", "Collection Date", "Geographic Location", "Assembly BioSample Accession"])
            .agg(present_count=("Presence", "sum"), total_count=("Presence", "count"))
            .reset_index()
        )
        if grouped.empty:  # e.g. every grouping key was NaN
            raise ValueError(no_data_msg)

        # Country-level stats per gene and collection date
        country_stats = (
            grouped.groupby(["Gene", "Collection Date", "Geographic Location"])
            .agg(
                samples_with_gene=("present_count", lambda x: (x > 0).sum()),
                total_samples=("present_count", "count")
            )
            .reset_index()
        )
        country_stats["Presence (%)"] = 100 * country_stats["samples_with_gene"] / country_stats["total_samples"]

        # "Overall" aggregation per Gene and Geographic Location across all dates
        overall_stats = (
            country_stats
            .groupby(["Gene", "Geographic Location"], as_index=False)
            .agg({
                "samples_with_gene": "sum",
                "total_samples": "sum"
            })
        )
        overall_stats["Collection Date"] = "Overall"
        overall_stats["Presence (%)"] = 100 * overall_stats["samples_with_gene"] / overall_stats["total_samples"]

        country_stats = pd.concat([country_stats, overall_stats], ignore_index=True)
        country_stats.sort_values(by=["Gene", "Geographic Location", "Collection Date"], inplace=True)

        genes = sorted(country_stats["Gene"].unique())
        # "Overall" sorts last; real dates keep their natural order.
        dates = sorted(country_stats["Collection Date"].unique(), key=lambda x: (x == "Overall", x))

        # Split once into (gene, date) subsets instead of re-filtering per pair;
        # groupby keeps the sorted row order inside each subset.
        subsets = dict(list(country_stats.groupby(["Gene", "Collection Date"], sort=False)))

        # One trace per (gene, date) combination that has data, in gene-major order.
        traces = []
        trace_keys = []
        for gene in genes:
            for date in dates:
                subset = subsets.get((gene, date))
                if subset is None or subset.empty:
                    continue
                traces.append(dict(
                    type="choropleth",
                    locations=subset["Geographic Location"],
                    locationmode="country names",
                    z=subset["Presence (%)"],
                    # Share the layout colour axis (Reds, 0-100, titled colorbar).
                    coloraxis="coloraxis",
                    hovertemplate='<b>%{hovertext}</b><br>' +
                                  'Presence: %{z:.1f}%<br>' +
                                  'Detected: %{customdata[0]}<br>' +
                                  'Total Sequences: %{customdata[1]}<extra></extra>',
                    hovertext=subset["Geographic Location"],
                    customdata=list(zip(subset["samples_with_gene"], subset["total_samples"])),
                    visible=False,
                    name=f"{gene} - {date}"
                ))
                trace_keys.append((gene, date))

        # Show the first combination that actually has data, so the initial map,
        # title and the dropdown's default (first) button all agree.
        traces[0]["visible"] = True
        initial_gene, initial_date = trace_keys[0]

        # One dropdown button per trace
        all_buttons = []
        for idx, (gene, date) in enumerate(trace_keys):
            all_buttons.append(dict(
                label=f"{gene} | {date}",
                method="update",
                args=[
                    {"visible": [j == idx for j in range(len(traces))]},
                    {"title": f"Gene Presence Map: {gene} ({date})"}
                ]
            ))

        layout = dict(
            title=f"Gene Presence Map: {initial_gene} ({initial_date})",
            geo=dict(showframe=False, showcoastlines=True, projection_type='equirectangular'),
            coloraxis=dict(colorbar=dict(title="Presence (%)"), colorscale="Reds", cmin=0, cmax=100),
            margin=dict(l=40, r=40, t=80, b=40),
            updatemenus=[
                dict(
                    buttons=all_buttons,
                    direction="down",
                    showactive=True,
                    x=0.1,
                    y=1.02,
                    xanchor="left",
                    yanchor="top",
                    type="dropdown"
                )
            ],
            annotations=[
                dict(text="Select Gene | Date:", x=0.03, y=1.02, xref="paper", yref="paper", showarrow=False)
            ],
        )

        plotly_fig = go.Figure(dict(data=traces, layout=layout))
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, "Resistance_gene_geographic_distribution.html")
        plotly_fig.write_html(output_file)
        logging.info(f"Interactive map saved to {output_file}")

        return plotly_fig

    except Exception as e:
        logging.error(f"Error generating interactive map: {e}")
        raise

def generate_index_html(html_dir, figures_dir, output_file="index.html"):
    """
    Write a single-page viewer that embeds the HTML plots found in ``html_dir``.

    The page has a <select> listing every ``.html`` file in ``html_dir``, sorted
    by name. Sub-directories and any file named ``output_file`` are excluded.
    An <iframe> shows the selected plot, starting with the first one. The page
    is saved as ``figures_dir/output_file``, and its links are relative to
    ``figures_dir``.

    Args:
        html_dir (str): Directory containing the plot HTML files.
        figures_dir (str): Directory in which the viewer page is written.
        output_file (str): File name of the viewer page (default "index.html").

    Returns:
        None. Nothing is written when ``html_dir`` contains no HTML files.
    """
    import html
    from urllib.parse import quote

    # Plot pages in html_dir (directories and the viewer page itself are skipped).
    html_files = sorted(
        name for name in os.listdir(html_dir)
        if name.endswith(".html")
        and name != output_file
        and os.path.isfile(os.path.join(html_dir, name))
    )

    if not html_files:
        print("No HTML files found in:", html_dir)
        return

    # The page lives in figures_dir, so links must be relative to it. URLs always
    # use "/" (os.path.join would emit "\" on Windows), and relpath also copes
    # with a trailing separator on html_dir, where basename() would return "".
    try:
        rel_dir = os.path.relpath(html_dir, figures_dir)
    except ValueError:  # no relative path exists, e.g. different drives on Windows
        rel_dir = os.path.basename(os.path.normpath(html_dir))
    rel_dir = rel_dir.replace(os.sep, "/")
    prefix = "" if rel_dir == "." else rel_dir + "/"

    def plot_src(name):
        """URL-encoded, HTML-escaped link from the viewer page to plot file ``name``."""
        return html.escape(quote(prefix + name), quote=True)

    dropdown_options = "\n".join(
        f'<option value="{plot_src(name)}">{html.escape(name)}</option>'
        for name in html_files
    )

    html_template = f"""<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>Plot Viewer</title>
  <style>
    body {{ font-family: Arial, sans-serif; text-align: center; margin: 0; padding: 20px; background: white; }}
    select {{ font-size: 16px; padding: 8px; }}
    iframe {{ width: 90vw; height: 80vh; border: none; margin-top: 20px; }}
  </style>
</head>
<body>

  <h2>Select Plot to View</h2>
  <select id="plotSelector" onchange="loadPlot()">
    {dropdown_options}
  </select>

  <iframe id="plotFrame" src="{plot_src(html_files[0])}"></iframe>

  <script>
    function loadPlot() {{
      const selector = document.getElementById("plotSelector");
      const frame = document.getElementById("plotFrame");
      frame.src = selector.value;
    }}
  </script>

</body>
</html>
"""

    output_path = os.path.join(figures_dir, output_file)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(html_template)

    print(f"✅ index.html generated at: {output_path}")

def main(ncbi_dir, abricate_dir, output_dir, fig_format, nseq, genep):
    """Run the panr pipeline for every Abricate summary file in *abricate_dir*.

    For each summary table (``*summary.csv`` or ``*summary.tab``; TAB files
    are converted to a CSV written next to them) the function:

    1. merges it with ``ncbi_dir/ncbi_clean.csv`` and saves the merge in
       ``output_dir/merged_output``;
    2. converts the merge to tidy format and, when a matching
       ``<name>_results.csv`` / ``<name>_results.tab`` exists, adds its
       ``RESISTANCE`` column, then saves the tidy table;
    3. generates static and interactive figures and statistics in
       ``output_dir/figures``;
    4. moves generated HTML files into ``figures/html_files`` (with an
       ``index.html`` viewer in ``figures``) and CSV files into
       ``figures/Stat_analysis``.

    An error while processing one summary file is logged and the next file is
    processed.

    Args:
        ncbi_dir (str): Directory containing ``ncbi_clean.csv``.
        abricate_dir (str): Directory containing Abricate summary/results files.
        output_dir (str): Base output directory.
        fig_format (str): File format of the static figures (e.g. "tiff").
        nseq (int): Minimum number of sequences per group in heatmaps.
        genep (float): Minimum % gene presence for genes shown in heatmaps.
    """
    logging.info("Starting the script.")

    # Define paths. "[ct][sa][bv]" matches the "csv" and "tab" extensions (and
    # a few other 3-letter combinations, skipped below when neither a .csv nor
    # a .tab version of the file exists).
    ncbi_clean_path = os.path.join(ncbi_dir, "ncbi_clean.csv")
    abricate_summary_files = glob.glob(os.path.join(abricate_dir, "*summary.[ct][sa][bv]"))
    abricate_results_files = glob.glob(os.path.join(abricate_dir, "*results.[ct][sa][bv]"))

    if not os.path.exists(ncbi_clean_path):
        logging.error(f"ncbi_clean.csv not found in {ncbi_dir}.")
        return

    if not abricate_summary_files:
        logging.error(f"No CSV or TAB summary files (abricate) found in {abricate_dir}.")
        return

    if not abricate_results_files:
        logging.error(f"No CSV or TAB results files (abricate) found in {abricate_dir}.")
        return

    # Create output subdirectories
    merged_output_dir = os.path.join(output_dir, "merged_output")
    figures_dir = os.path.join(output_dir, "figures")
    os.makedirs(merged_output_dir, exist_ok=True)
    os.makedirs(figures_dir, exist_ok=True)

    # Metadata columns used to group genomes in the per-group figures.
    group_columns = ("Geographic Location", "Collection Date", "Continent", "Subcontinent")
    correlation_group_columns = ("Geographic Location", "Continent", "Subcontinent")

    # A summary present both as .csv and .tab must be processed only once.
    summary_bases = sorted({os.path.splitext(path)[0] for path in abricate_summary_files})
    failed = []

    for base in summary_bases:
        abricate_summary_file = base + ".csv"
        try:
            tab_file = base + ".tab"
            if not os.path.exists(abricate_summary_file):
                if not os.path.exists(tab_file):
                    logging.warning(f"No .csv or .tab found for {os.path.basename(base)}, skipping.")
                    continue
                convert_tab_to_csv(tab_file, abricate_summary_file)

            # Dataset name, e.g. "card" for "card_summary.csv".
            summary_name = os.path.basename(abricate_summary_file)
            suffix = "_summary.csv"
            if summary_name.endswith(suffix):
                dataset = summary_name[:-len(suffix)]
            else:
                dataset = os.path.splitext(summary_name)[0]

            # Load, merge and save the merged data
            merged_df = load_and_merge_data(ncbi_clean_path, abricate_summary_file)
            output_filename = f"ncbi_{summary_name}"
            save_merged_data(merged_df, merged_output_dir, output_filename)

            # Convert to tidy format
            tidy_df = convert_to_tidy_format(merged_df)

            # Add RESISTANCE column from the corresponding results file,
            # converting a .tab results file to .csv if needed.
            results_csv = os.path.join(abricate_dir, f"{dataset}_results.csv")
            results_tab = os.path.join(abricate_dir, f"{dataset}_results.tab")
            if not os.path.exists(results_csv) and os.path.exists(results_tab):
                convert_tab_to_csv(results_tab, results_csv)
                logging.info(f"Converted {results_tab} to CSV.")

            if os.path.exists(results_csv):
                try:
                    results_df = pd.read_csv(results_csv, dtype=str)
                    gene_resistance = results_df[["GENE", "RESISTANCE"]].drop_duplicates()
                    # A left merge on a non-unique key duplicates tidy rows
                    # (inflating every downstream count), so keep a single
                    # RESISTANCE annotation per gene.
                    if gene_resistance["GENE"].duplicated().any():
                        logging.warning(
                            f"Conflicting RESISTANCE annotations in {results_csv}; "
                            "keeping the first one per gene."
                        )
                        gene_resistance = gene_resistance.drop_duplicates(subset="GENE", keep="first")

                    # Merge 'Gene' from tidy_df with 'GENE' from the results file,
                    # then drop the redundant GENE column.
                    enriched = tidy_df.merge(gene_resistance, left_on="Gene", right_on="GENE", how="left")
                    tidy_df = enriched.drop(columns=["GENE"])

                    logging.info(f"Successfully added RESISTANCE info for {dataset}.")
                except Exception as e:
                    logging.error(f"Error merging RESISTANCE data from {results_csv}: {e}")
            else:
                logging.warning(f"Results file not found for {dataset}. Skipping RESISTANCE enrichment.")

            # Derive the tidy name from the dataset name so it can never
            # collide with (and overwrite) the merged output file.
            tidy_file = os.path.join(merged_output_dir, f"ncbi_{dataset}_tidy_summary.csv")
            tidy_df.to_csv(tidy_file, index=False)
            logging.info(f"Tidied file saved to {tidy_file}")

            # Analyze gene prevalence and generate figures
            analyze_gene_presence(tidy_df, figures_dir, dataset, fig_format)

            # Static and interactive boxplots of gene identity
            generate_gene_identity_boxplot(tidy_file, figures_dir, fig_format)
            generate_gene_identity_boxplot_plotly(tidy_file, figures_dir)

            # Mean ARG lollipop plots for each grouping, in figures/mean_ARG
            mean_ARG_dir = os.path.join(figures_dir, "mean_ARG")
            os.makedirs(mean_ARG_dir, exist_ok=True)
            for group in group_columns:
                generate_mean_arg_lollipop(tidy_file, mean_ARG_dir, fig_format, group_by=group)

            # Barplots for resistance comparison
            generate_resistance_barplot(tidy_file, figures_dir, fig_format)

            # Comparison heatmaps for each grouping, in figures/heatmap; the
            # thresholds come from this function's parameters.
            heatmap_dir = os.path.join(figures_dir, "heatmap")
            os.makedirs(heatmap_dir, exist_ok=True)
            for group in group_columns:
                generate_comparison_heatmap(
                    tidy_file, heatmap_dir, fig_format,
                    group_col=group, resistance_col="RESISTANCE",
                    genep_threshold=genep, nseq_threshold=nseq,
                )

            # Plotly plots: heatmap, geographic map, mean ARG box plot, lollipop
            generate_comparison_heatmap_plotly(tidy_file, figures_dir)
            generate_geographic_resistance_map_plotly(tidy_file, figures_dir)
            mean_Arg_resistance_analysis_plotly(tidy_file, figures_dir)
            generate_mean_arg_lollipop_plotly(tidy_file, figures_dir)

            # Correlation scatterplot analyses, then combine their CSV summaries
            for group in correlation_group_columns:
                logging.info(f"Generating {group} analysis...")
                correlation_scatterplot_analysis(tidy_file, figures_dir, group_col=group)
            logging.info("Combining correlation summary CSV files...")
            combined_correlation_analysis(figures_dir)

            # Move the HTML plots into figures/html_files. index.html is the
            # viewer regenerated below and stays in figures_dir. os.replace
            # overwrites an existing destination on every platform (os.rename
            # fails on Windows in that case).
            html_dir = os.path.join(figures_dir, "html_files")
            os.makedirs(html_dir, exist_ok=True)
            for file in os.listdir(figures_dir):
                if file.endswith(".html") and file != "index.html":
                    os.replace(os.path.join(figures_dir, file), os.path.join(html_dir, file))
                    logging.info(f"Moved {file} to {html_dir}")
            # Generate index.html for easy navigation
            generate_index_html(html_dir, figures_dir)

            # Move all .csv files from figures_dir to figures_dir/Stat_analysis
            stat_analysis_dir = os.path.join(figures_dir, "Stat_analysis")
            os.makedirs(stat_analysis_dir, exist_ok=True)
            for file in os.listdir(figures_dir):
                if file.endswith(".csv"):
                    os.replace(os.path.join(figures_dir, file), os.path.join(stat_analysis_dir, file))

        except Exception as e:
            failed.append(os.path.basename(abricate_summary_file))
            logging.exception(f"Error processing {abricate_summary_file}: {e}")

    if failed:
        logging.warning(f"panr finished with errors for: {', '.join(failed)}")
    else:
        logging.info("panr run successfully.")


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them and hand the
    # values over to main().
    parser = argparse.ArgumentParser(
        description="Process NCBI and Abricate data."
    )
    parser.add_argument(
        "--ncbi-dir",
        required=True,
        help="Directory containing ncbi_clean.csv.",
    )
    parser.add_argument(
        "--abricate-dir",
        required=True,
        help="Directory containing Abricate summary CSV or TAB files.",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Base output directory.",
    )
    parser.add_argument(
        "--genep",
        type=float,
        default=10.0,
        help="Minimum %% gene presence to include in heatmap.",
    )
    parser.add_argument(
        "--nseq",
        type=int,
        default=1,
        help="Minimum number of sequences required per group in heatmaps.",
    )
    parser.add_argument(
        "--format",
        default="tiff",
        choices=["tiff", "svg", "png", "pdf"],
        help="Output format for figures (tiff, svg, png, pdf).",
    )
    parser.add_argument(
        "--version",
        action="version",
        version="PanR2 0.1.1",
    )

    # Parsed at module level so that ``args`` stays a module-level name for
    # any code in this file that reads it.
    args = parser.parse_args()

    main(
        ncbi_dir=args.ncbi_dir,
        abricate_dir=args.abricate_dir,
        output_dir=args.output_dir,
        fig_format=args.format,
        nseq=args.nseq,
        genep=args.genep,
    )
    
    
    
    
    
    

