#!python

import argparse
import collections
import logging
import json

from tqdm import tqdm

from opusfilter.pipeline import FilterPipeline
from opusfilter.util import file_open, yaml

logger = logging.getLogger(__name__)


def read_lines(infs):
    """Yield groups of parallel segments with trailing whitespace stripped.

    Iterates over ``infs`` wrapped in a tqdm progress bar. For every item,
    a list is yielded that holds each of its segments after ``rstrip()``,
    which removes the trailing newline along with any other trailing
    whitespace.
    """
    progress = tqdm(infs)
    for group in progress:
        stripped = [text.rstrip() for text in group]
        yield stripped


# Emit INFO-level messages by default, but only warnings and above from the
# 'mosestokenizer.tokenizer.MosesTokenizer' logger.
logging.basicConfig(level=logging.INFO)
_moses_logger = logging.getLogger('mosestokenizer.tokenizer.MosesTokenizer')
_moses_logger.setLevel(logging.WARNING)

# Command-line interface: one or more input files plus optional filter
# configuration (--yaml / --add) and optional JSONL outputs.
parser = argparse.ArgumentParser(
    prog='opusfilter-test',
    description='Test filters on parallel text data',
)

parser.add_argument(
    'files',
    nargs='+',
    metavar='FILE',
    help='parallel text input file(s)',
)
parser.add_argument(
    '--yaml',
    metavar='FILE',
    help='load YAML configuration file for the filters to test',
)
# May be repeated; every occurrence contributes one (CLASS, JSON) pair.
parser.add_argument(
    '--add',
    nargs=2,
    action='append',
    default=[],
    metavar=('CLASS', 'JSON'),
    help='add filter of CLASS with JSON parameters object to the test',
)
parser.add_argument(
    '--removed',
    metavar='FILE',
    help='write removed segments to JSONL file',
)
parser.add_argument(
    '--scores',
    metavar='FILE',
    help='write filter scores to JSONL file',
)

args = parser.parse_args()

# Build the filter configuration: start from the YAML file (if given) and
# extend it with every filter specified via --add.
if args.yaml:
    # Use a context manager so the configuration file is closed right after
    # parsing instead of being left open for the rest of the run.
    # NOTE(review): `yaml` is imported from opusfilter.util; confirm that its
    # loader is safe to use on untrusted configuration files.
    with open(args.yaml) as yaml_fobj:
        config = yaml.load(yaml_fobj)
    if config is None:
        # Presumably what the loader returns for an empty document; treat it
        # as "no filters" so that --add can still append to the list.
        config = []
else:
    config = []
for name, jsonstr in args.add:
    # Each --add pair becomes a single-key mapping {CLASS: parameters}.
    config.append({name: json.loads(jsonstr)})

filter_pipe = FilterPipeline.from_config(config)
# Run every filter of the pipeline separately over the parallel input files,
# report how many segments each one removes, and optionally write the removed
# segments (--removed) and the filter scores (--scores) as JSON lines.
infs = []
removed_fobj = None
try:
    # Open the inputs inside the try block so that files opened before a
    # failure are still closed by the finally clause.
    for infile in args.files:
        infs.append(file_open(infile))

    # First pass: count the segments to get the denominator for percentages.
    logger.info("Calculating total")
    total = sum(1 for _ in read_lines(zip(*infs)))
    counter = collections.Counter()

    if args.removed:
        removed_fobj = file_open(args.removed, 'w')

    for filter_ in filter_pipe.filters:
        # Rewind all inputs so that each filter sees the data from the start.
        for inf in infs:
            inf.seek(0)
        name = filter_.name if filter_.name is not None else filter_.__class__.__name__
        logger.info("Testing %s", name)
        # NOTE(review): filters that share a name overwrite each other's
        # count here; confirm whether names are guaranteed to be unique.
        counter[name] = 0
        # filterfalse yields the segments that this filter rejects.
        for segments in filter_.filterfalse(read_lines(zip(*infs))):
            counter[name] += 1
            if removed_fobj is not None:
                obj = {'filter': name, 'segments': segments}
                removed_fobj.write(json.dumps(obj, sort_keys=True)+'\n')

    for key, value in counter.items():
        # Empty input files give total == 0; report 0.0% instead of raising
        # ZeroDivisionError.
        percentage = 100 * value / total if total else 0.0
        print("{}: Removes {} / {} ({:.1f}%)".format(key, value, total, percentage))

    if args.scores:
        logger.info("Collecting scores")
        for inf in infs:
            inf.seek(0)
        with open(args.scores, 'w') as fobj:
            for score_obj in filter_pipe.score(read_lines(zip(*infs))):
                fobj.write(json.dumps(score_obj, sort_keys=True)+'\n')
finally:
    # Always release the file handles, including the removed-segments output
    # that was previously never closed.
    if removed_fobj is not None:
        removed_fobj.close()
    for inf in infs:
        inf.close()
