"""Test different optimization settingss."""


import itertools
import pickle
import re
import sys
import time

import pandas as pd

import polyclonal


# parameter grid: one model fit is run for every combination of these values
params_to_test = {
    'noise': ['exact', 'noisy'],
    'sitefirst': ['sitefirst', 'nosite'],
    'collapse': ['collapsed', 'uncollapsed'],
    'reg_escape_weight': [0, 0.001, 0.005, 0.01, 0.05, 0.1],
    'reg_spread_weight': [0, 0.05, 0.25, 0.5, 1],
}

# one row per point in the Cartesian product of all parameter values,
# with columns in the same order as the keys of `params_to_test`
param_combinations = list(itertools.product(*params_to_test.values()))
params_to_test_df = pd.DataFrame.from_records(
    param_combinations,
    columns=list(params_to_test),
)

# Snakemake parameter space used to build the per-combination file patterns
paramspace = snakemake.utils.Paramspace(params_to_test_df)


rule all:
    """Target rule: request the aggregated CSV of fitting results."""
    input:
        'results/opt_settings/aggregated_fit_results.csv'

rule prep_variants_df:
    """Subset the variants CSV for one noise level to the rows used for fitting.

    Keeps only the "avg2muts" library at concentrations 0.25, 1, and 4.
    """
    input:
        csv="../notebooks/RBD_variants_escape_{noise}.csv",
    output:
        csv="results/opt_settings/variants_{noise}.csv",
    run:
        all_variants = pd.read_csv(input.csv, na_filter=None)
        # apply the library filter first, then the concentration filter
        in_library = all_variants.query('library == "avg2muts"')
        fit_variants = in_library.query('concentration in [0.25, 1, 4]')
        fit_variants.to_csv(output.csv, index=False)

rule prep_test_variants:
    """Subset the exact variants CSV to the rows used to test predictions.

    Keeps only the "avg3muts" library at concentrations 0.25, 1, and 4.
    """
    input:
        csv='../notebooks/RBD_variants_escape_exact.csv',
    output:
        csv='results/opt_settings/test_variants.csv',
    run:
        exact_variants = pd.read_csv(input.csv, na_filter=None)
        # apply the library filter first, then the concentration filter
        in_library = exact_variants.query('library == "avg3muts"')
        held_out = in_library.query('concentration in [0.25, 1, 4]')
        held_out.to_csv(output.csv, index=False)

rule fit_model:
    """Fit a model for one parameter combination and score it.

    Outputs a one-row CSV with the final loss, the wall-clock time spent in
    the fit, and two correlations (predicted vs. given escape probability on
    the test variants, and predicted vs. given per-mutation escape), plus a
    pickle of the fit model object. The optimizer log is written to the
    rule's log file.
    """
    input:
        data_csv=rules.prep_variants_df.output.csv,
        test_csv=rules.prep_test_variants.output.csv,
        mut_escape='../notebooks/RBD_mut_escape_df.csv',
    output:
        csv=f"results/opt_settings/{paramspace.wildcard_pattern}.csv",
        pickle=f"results/opt_settings/{paramspace.wildcard_pattern}.pickle",
    log: log=f"results/opt_settings/{paramspace.wildcard_pattern}.log",
    run:
        variants_df = pd.read_csv(input.data_csv, na_filter=None)

        # three epitopes, each given a wildtype activity and a single site
        # with escape 10.0
        poly_abs = polyclonal.Polyclonal(
                data_to_fit=variants_df,
                activity_wt_df=pd.DataFrame.from_records(
                           [('epitope 1', 3.0),
                            ('epitope 2', 2.0),
                            ('epitope 3', 1.0),
                            ],
                        columns=['epitope', 'activity'],
                        ),
                site_escape_df=pd.DataFrame.from_records(
                           [('epitope 1', 484, 10.0),
                            ('epitope 2', 446, 10.0),
                            ('epitope 3', 417, 10.0),
                            ],
                        columns=['epitope', 'site', 'escape'],
                        ),
                data_mut_escape_overlap='fill_to_data',
                collapse_identical_variants='mean' if wildcards.collapse == 'collapsed' else False,
                )

        print('Start model fitting')
        start_time = time.time()
        with open(log.log, 'w') as f:
            opt_res = poly_abs.fit(
                    reg_escape_weight=float(wildcards.reg_escape_weight),
                    reg_spread_weight=float(wildcards.reg_spread_weight),
                    log=f,
                    logfreq=5,
                    fit_site_level_first=(wildcards.sitefirst == 'sitefirst'),
                    )
            # stop the clock as soon as the fit returns so the reported time
            # covers only the optimization, not the evaluation done below
            elapsed = time.time() - start_time
        print('Model fitting complete')

        # correlation of predictions on test variants
        test_data = pd.read_csv(input.test_csv, na_filter=None)
        exact_vs_pred = poly_abs.prob_escape(variants_df=test_data)
        corr_prob_escape = exact_vs_pred['prob_escape'].corr(
                                exact_vs_pred['predicted_prob_escape'])

        # correlation of underlying mut escape (beta) values
        exact_mut_escape = pd.read_csv(input.mut_escape)
        corrs = []
        for _, exact_ep_escape in exact_mut_escape.groupby('epitope'):
            # drop the exact epitope label before merging so that the
            # 'epitope' column in the merged frame is the model's epitope
            exact_vs_pred = (
                exact_ep_escape
                .rename(columns={'escape': 'actual_escape'})
                [['mutation', 'actual_escape']]
                .merge(poly_abs.mut_escape_df.rename(columns={'escape': 'pred_escape'}),
                       on='mutation',
                       validate='one_to_many',
                       )
                )
            # for this exact epitope, keep the correlation with whichever
            # model epitope correlates best
            corrs.append(
                exact_vs_pred
                .groupby('epitope')
                .apply(lambda x: x['actual_escape'].corr(x['pred_escape']))
                .max()
                )
        # average the best-match correlations over the exact epitopes
        corr_mut_escape = sum(corrs) / len(corrs)

        print(f"Writing results to {output.csv}")
        with open(output.csv, 'w') as f:
            f.write('loss,time,corr_prob_escape,corr_mut_escape\n'
                    f"{opt_res.fun:.5g},{elapsed:.5g},{corr_prob_escape:.5g},"
                    f"{corr_mut_escape:.5g}\n")

        print(f"Writing model to {output.pickle}")
        with open(output.pickle, 'wb') as f:
            pickle.dump(poly_abs, f)

rule aggregate_fit_results:
    """Aggregate the per-combination fitting results into a single CSV.

    Each input CSV is read and tagged with the parameter values parsed from
    the `<param>~<value>` components of its path, then all are concatenated.
    """
    input:
        csvs=expand("results/opt_settings/{params}.csv",
                    params=paramspace.instance_patterns)
    output: csv='results/opt_settings/aggregated_fit_results.csv'
    params: wc_values=paramspace.instance
    run:
        dfs = []
        for csv_file in input.csvs:
            csv_params = {}
            for param in params_to_test:
                # a value runs up to the next path separator or, for the
                # last parameter in the path, up to the ".csv" suffix
                match = re.search(
                    rf"{re.escape(param)}~([^/]+)(/|\.csv)", csv_file)
                if match is None:
                    raise ValueError(
                        f"cannot find value of {param!r} in path {csv_file!r}")
                csv_params[param] = match.group(1)
            dfs.append(pd.read_csv(csv_file).assign(**csv_params))
        pd.concat(dfs).to_csv(output.csv, index=False)
