#!python
"""Estimates the probability of rows having a specific value for a column.

Input rows are read from a CSV file, or the population model's data if no file
is specified. The estimated probabilities are appended to the data rows as an
additional column, and written to a new CSV file.
"""

import argparse
import logging
import math
import pandas
import sys

import edpanalyst

logger = logging.getLogger(__name__)
MAX_RETRIES = 3
BATCH_SIZE = 1000


def parse_args(argv=None):
    """Parse command-line flags for this script.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse falls back to sys.argv[1:]. Passing an explicit
            list makes this function directly testable.

    Returns:
        An argparse.Namespace with the parsed flag values.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--pmid', required=True, help='Population model ID.')
    parser.add_argument(
        '--input_csv', help='CSV containing rows for which to predict values. '
        'If not specified, the population data is used.')
    parser.add_argument(
        '--target_column', required=True,
        help='Column whose values to predict. Must be categorical.')
    parser.add_argument(
        '--target_value', required=True,
        help='Value of `target_column` for which to predict probability.')
    parser.add_argument(
        '--ignore_existing_values', action='store_true',
        help='If set, disregard existing values of `target_column` and '
        'use the model to predict. By default, existing values for '
        '`target_column` will produce output probabilities of 0 or 1.')
    parser.add_argument(
        '--predicted_prob_column', required=True,
        help='Column with predicted probability to write to the CSV output.')
    parser.add_argument('--output_csv', required=True,
                        help='File name of output CSV.')
    parser.add_argument('--quiet', action='store_true',
                        help='If set, suppress informational output.')
    return parser.parse_args(argv)


def is_present(v):
    """Return True if `v` counts as an actual (non-missing) value."""
    # Pandas represents missing cells as float NaN, even with dtype=str.
    if isinstance(v, float):
        return not math.isnan(v)
    if v is None or v == '':
        return False
    return True


def is_modeled(schema, col):
    """Return True if `col` exists in `schema` and has a modeled stat type."""
    try:
        modeled = schema[col].stat_type != 'void'
    except KeyError:
        # Columns absent from the schema are treated as unmodeled.
        modeled = False
    return modeled


def convert_value(schema, col, value):
    """Coerce `value` to the Python type the model expects for `col`."""
    return float(value) if schema[col].is_real() else str(value)


def probabilities_for_rows(pm, schema, rows, target_column, target_value,
                           ignore_existing):
    """Return P(target_column == target_value) for every row in `rows`.

    Rows that already carry a value for `target_column` get an exact 0.0 or
    1.0 (unless `ignore_existing` is set). All rows are scored with a single
    batched `joint_probability` call, conditioned on every modeled column
    that is present in at least one row; exact values then override the
    model's estimates.
    """
    # First pass: collect the columns usable as givens, and pin down rows
    # whose target value is already known.
    given_columns = set()
    exact_probs = {}
    for i, row in enumerate(rows):
        for col, val in row.items():
            if col == target_column:
                if is_present(val) and not ignore_existing:
                    exact_probs[i] = 1.0 if val == target_value else 0.0
            elif is_present(val) and is_modeled(schema, col):
                given_columns.add(col)

    # Cache each given column's known categorical values so that unknown
    # values in data rows can be detected efficiently.
    cat_values = {
        col: {v.value for v in (schema[col].values or [])}
        for col in given_columns
    }

    def given_or_none(col, val):
        # Missing values, and categorical values the model has never seen
        # (perhaps absent from the data the model was built from), carry no
        # information for this row.
        # TODO(bnenning): Come up with a better way of handling unknown
        # values. Ideally the model should be rebuilt with all expected
        # values.
        if not is_present(val):
            return None
        if cat_values[col] and val not in cat_values[col]:
            return None
        return convert_value(schema, col, val)

    # Second pass: one list of givens per column, one entry per row.
    givens = {
        col: [given_or_none(col, row[col]) for row in rows]
        for col in given_columns
    }

    target = {target_column: [target_value] * len(rows)}
    response = None
    failures = 0
    while response is None:
        try:
            response = pm.joint_probability(target, givens=givens)
        except Exception:
            # Allow for transient network or server failures.
            failures += 1
            if failures > MAX_RETRIES:
                raise
            logger.exception('Retrying')

    probs = list(response['p'])
    # Rows with a known target value override the model's estimate.
    for i, p in exact_probs.items():
        probs[i] = p
    return probs


def main(args):
    """Estimate target-value probabilities for input rows and write a CSV.

    Reads rows from `args.input_csv` (or the population model's data when no
    file is given), scores them in batches of BATCH_SIZE, and writes the
    input rows plus a predicted-probability column to `args.output_csv`.
    Exits with status 1 if the EDP call fails or no probabilities were
    computed.
    """
    logging.basicConfig(level=logging.WARNING if args.quiet else logging.INFO)
    session = edpanalyst.Session()
    pm = session.popmod(args.pmid)
    if args.input_csv:
        logger.info('Reading CSV from %s', args.input_csv)
        # Force strings since some categorical values look like numbers.
        rows = pandas.read_csv(args.input_csv, dtype=str)
    else:
        logger.info('Using data in population')
        rows = pm.select()

    records = rows.to_dict('records')
    logger.info('Read %d rows', len(rows))

    probs = []
    for start in range(0, len(records), BATCH_SIZE):
        batch = records[start:start + BATCH_SIZE]
        try:
            batch_probs = probabilities_for_rows(
                pm, pm.schema, batch, args.target_column, args.target_value,
                args.ignore_existing_values)
        # Catch Exception rather than a bare `except:` so SystemExit and
        # KeyboardInterrupt still propagate; transient failures were already
        # retried inside probabilities_for_rows.
        except Exception:
            logger.exception('Error calling EDP')
            sys.exit(1)
        probs.extend(batch_probs)
        processed = start + BATCH_SIZE
        if processed % 5000 == 0:
            logger.info('Processed %d rows', processed)

    assert len(probs) == len(rows)
    nonnull_probs = [p for p in probs if p is not None]
    if not nonnull_probs:
        # logger.warn is deprecated; warning is the supported spelling.
        logger.warning('No probabilities were computed')
        sys.exit(1)
    logger.info('Mean probability: %.3f',
                sum(nonnull_probs) / len(nonnull_probs))
    output_df = pandas.concat(
        [rows, pandas.Series(probs, name=args.predicted_prob_column)], axis=1)
    logger.info('Writing to %s', args.output_csv)
    with open(args.output_csv, 'wt') as f:
        f.write(output_df.to_csv(index=False))


# Entry point: parse command-line flags and run when executed as a script.
if __name__ == '__main__':
    main(parse_args())
