#!python

import argparse
import os
from os import path
import zipfile

from opustools_pkg import OpusRead
from random import Random
from xml.parsers.expat import ExpatError

# Names of the OPUS collections this script knows about, in the order they
# are offered to the user and processed.
all_collections = (
  'ada83 Bianet bible-uedin Books CAPES DGT DOGC ECB EhuHac Elhuyar EMEA '
  'EUbookshop EUconst Europarl Finlex fiskmo giga-fren GlobalVoices GNOME '
  'hrenWaC JRC-Acquis KDE4 KDEdoc MBS memat MontenegrinSubs MPC1 MultiUN '
  'News-Commentary OfisPublik OpenOffice OpenSubtitles ParaCrawl PHP QED '
  'RF sardware SciELO SETIMES SPC Tanzil Tatoeba TED2013 TedTalks TEP '
  'TildeMODEL Ubuntu UN UNPC wikimedia Wikipedia WikiSource WMT-News '
  'XhosaNavy'
).split()

# Values accepted by `--collections`; `ALL` is shorthand for every collection.
collection_choices = ['ALL', *all_collections]

# Command-line interface. Options are registered in the order they should
# appear in `--help`.
parser = argparse.ArgumentParser(
  description='All aboard the OPUS Express! Create test/dev/train sets from OPUS data.',
)

# General behaviour and language pair.
parser.add_argument(
  '-f', '--force',
  action='store_true',
  help='suppress warnings (default: False)',
)
parser.add_argument(
  '-s', '--src-lang',
  type=str,
  required=True,
  metavar='lang_id',
  help='source language (e.g. `en\')',
)
parser.add_argument(
  '-t', '--tgt-lang',
  type=str,
  required=True,
  metavar='lang_id',
  help='target language (e.g. `pt\')',
)

# Which data to read and where to find it.
parser.add_argument(
  '-c', '--collections',
  nargs='*',
  choices=collection_choices,
  default=['OpenSubtitles'],
  metavar='coll_name',
  help='OPUS collection(s) to fetch (default: `OpenSubtitles\')\nCollections list:\n%s' % str(collection_choices),
)
parser.add_argument(
  '--root-dir',
  type=str,
  default='/proj/nlpl/data/OPUS',
  metavar='/path/to/OPUS',
  help='Root directory for OPUS (default:`/proj/nlpl/data/OPUS\')',
)
parser.add_argument(
  '--test-override',
  type=str,
  default=None,
  metavar='/path/to/file',
  help='path to file containing resource IDs to reserve for the test set (default: None)',
)

# Split sizes and sampling strategy.
parser.add_argument(
  '--test-quota',
  type=int,
  default=10000,
  metavar='num_sents',
  help='test set size in sentences (default: 10000)',
)
parser.add_argument(
  '--dev-quota',
  type=int,
  default=10000,
  metavar='num_sents',
  help='development set size in sentences (default: 10000)',
)
parser.add_argument(
  '--doc-bounds',
  action='store_true',
  help='preserve document blocks (also marks document boundaries) (default: False)',
)
parser.add_argument(
  '--quality-aware',
  action='store_true',
  help='reserve one-to-one aligned samples with high overlap for test/dev sets (incompatible with `--doc-bounds\') (default: False)',
)
parser.add_argument(
  '--overlap-threshold',
  type=float,
  default=0.8,
  metavar='min_pct',
  help='threshold for alignment overlap in `--quality-aware\' mode (default: 0.8)',
)
parser.add_argument(
  '--shuffle',
  action='store_true',
  help='shuffle samples (incompatible with `--doc-bounds\') (default: False)',
)

# Output filename stubs.
parser.add_argument(
  '--test-set',
  type=str,
  default='test',
  metavar='filename',
  help='filename stub for output test set (default: `test\')',
)
parser.add_argument(
  '--dev-set',
  type=str,
  default='dev',
  metavar='filename',
  help='filename stub for output development set (default: `dev\')',
)
parser.add_argument(
  '--train-set',
  type=str,
  default='train',
  metavar='filename',
  help='filename stub for output training set (default: `train\')',
)

args = parser.parse_args()

# `--doc-bounds` is documented as incompatible with `--shuffle` and
# `--quality-aware`, and the prompts below promise that it overrides them.
# Both of those modes reorder or re-bin individual samples before documents
# are compared, so the override has to actually switch them off; asking the
# question alone is not enough. With `--force` the override is applied
# without prompting; otherwise the user may abort instead.
if args.doc_bounds:
  for flag, attr in (('--shuffle', 'shuffle'), ('--quality-aware', 'quality_aware')):
    if not getattr(args, attr):
      continue

    if not args.force:
      answer = input('Using --doc-bounds will override %s. Continue? (y/n) ' % flag)

      if answer != 'y':
        exit()

    # The user agreed (or forced): --doc-bounds wins, drop the other mode.
    setattr(args, attr, False)

# The language pair is put in sorted order; the archive file names and
# OpusRead arguments built further down use the pair in this order.
src_lang, tgt_lang = sorted((args.src_lang, args.tgt_lang))

# `ALL` expands to every known collection.
collections = all_collections if 'ALL' in args.collections else args.collections

# Optional list of resource IDs (one integer per non-blank line) whose
# samples are reserved for the test set.
test_override_path = args.test_override
test_override = set()

if test_override_path:
  print('Caching test override document IDs...')

  with open(test_override_path, 'r') as test_override_file:
    test_override.update(
      int(entry) for entry in map(str.strip, test_override_file) if entry
    )

  print('...done!')

# One fixed-seed generator drives every random choice in the script.
seed = 8675309
fixed = Random(seed)

# Intermediate bins filled while reading collections: test-override samples,
# high-quality samples, and everything else.
dump_test, dump_hiqu, dump_loqu = [], [], []

# Final splits, each with a running sample count.
test, dev, train = [], [], []
test_size = dev_size = train_size = 0

#########################
## OPUS READ OVERRIDES ##
#########################

class OpusReadPipe(OpusRead):
  """OpusRead subclass whose output hooks collect data in memory.

  `sendPairOutput` and `sendIdOutput` store what they are given in
  `self.bin`, a list of two-element rows `[pair, id_details]`, so the
  caller can read the results back from the reader object afterwards.
  """

  def __init__(self, arguments):
    # The buffer is created before the parent initialiser runs.
    self.bin = []
    super().__init__(arguments)

  def _pipe_store(self, column, payload):
    """Store `payload` in `column` of the newest row.

    A fresh `[None, None]` row is started first when there is no row yet,
    or when that column of the newest row already holds a truthy value.
    """
    if not self.bin:
      self.bin = [[None, None]]
    elif self.bin[-1][column]:
      self.bin.append([None, None])

    self.bin[-1][column] = payload

  def sendPairOutput(self, wpair):
    """Record a sentence pair in column 0 of the buffer."""
    self._pipe_store(0, wpair)

  def sendIdOutput(self, id_details):
    """Record alignment ID details in column 1 of the buffer."""
    self._pipe_store(1, id_details)

#####################
## DATA PROCESSING ##
#####################

# Read every requested collection through OpusReadPipe and sort its samples
# into the three intermediate bins (test override / high quality / the rest).
for collection in collections:
  archive_path = args.root_dir + '/%s/latest/xml/%s-%s.xml.gz' % (collection, src_lang, tgt_lang)

  if not os.path.isfile(archive_path):
    print('Skipping %s (no %s-%s)...' % (collection, src_lang, tgt_lang))
    continue

  print('Checking out %s...' % collection)

  # Scratch file name passed to OpusRead as its output target. The number is
  # drawn from the seeded generator, so it also advances the RNG state that
  # later shuffling depends on.
  tmp = 'tmp-%d' % fixed.randint(10000000, 99999999)

  try:
    # Arguments are passed as a list so that a root directory containing
    # whitespace is not split into several arguments.
    reader = OpusReadPipe([
      '-rd', args.root_dir + '/',
      '-a', 'overlap',
      '-d', collection,
      '-s', src_lang, '-t', tgt_lang,
      '-wm', 'moses', '-w', tmp, tmp, '-id', tmp,
    ])

    reader.printPairs() # Doesn't actually print (since overridden), but passes on data.

    if not reader.bin:
      print('...skipping %s (no %s-%s).' % (collection, src_lang, tgt_lang))

    else:
      print('...%d samples processed.' % len(reader.bin))

      print('...collating samples...')

      for wpair, id_details in reader.bin:
        src_sent, tgt_sent = wpair

        src_sent = src_sent.strip()
        tgt_sent = tgt_sent.strip()

        src_uri, tgt_uri, src_align, tgt_align, overlap = id_details

        if collection == 'OpenSubtitles':
          # The third path component of the source URI is used as the
          # resource ID. Override IDs are stored as ints, so convert when
          # possible; a str would never compare equal to them.
          resource_id = src_uri.split('/')[2]

          try:
            resource_id = int(resource_id)
          except ValueError:
            pass # Non-numeric ID: keep the string (it cannot match an override).
        else:
          resource_id = 0

        one_to_one = len(src_align) == len(tgt_align) == 1
        overlap = float(overlap) if overlap is not None else float('-inf')

        if resource_id in test_override:
          dump = dump_test
        elif args.quality_aware and one_to_one and overlap >= args.overlap_threshold:
          dump = dump_hiqu
        else:
          dump = dump_loqu

        stats_line = '\t'.join([src_uri, tgt_uri,
            ' '.join(src_align), ' '.join(tgt_align), str(overlap)])

        dump.append((src_sent, tgt_sent, stats_line))
  except ExpatError:
    print('...skipping (ill-formatted XML).')
  except FileNotFoundError:
    print('...skipping (broken links / XML archive not found).')
  except KeyError:
    print('...skipping (broken links / XMLs missing in archive).')
  except Exception as err:
    # Best effort: carry on with the next collection, but say what went wrong.
    print('...skipping (unknown exception: %r).' % (err,))
  else:
    print('...done!')
  finally:
    # Remove the scratch file whether or not reading succeeded.
    try:
      os.remove(tmp)
    except FileNotFoundError:
      pass # Nothing was created, so there is nothing to clean up.
    except OSError as err:
      print('...warning: could not remove %s (%s).' % (tmp, err))
    
# Stop early when none of the collections yielded any samples.
if not (dump_test or dump_hiqu or dump_loqu):
  # The backslash is written as `\\` because `\_` is not a valid escape
  # sequence; the printed text is unchanged.
  print('Alas, OPUS Express turned up with no data!\n¯\\_(ツ)_/¯')
  exit()

# Optionally randomise the order inside each intermediate bin before the
# bins are drained into the final splits.
if args.shuffle:
  print('Pre-shuffling bins...')

  for pool in (dump_test, dump_hiqu, dump_loqu):
    fixed.shuffle(pool)

  print('...done!')

print('Splitting data into test/dev/train sets...')

def _doc_key(ids_line):
  """Return the (source URI, target URI) pair from an IDs line.

  IDs lines are tab-separated and start with the two document URIs; the
  remaining fields (alignment IDs, overlap) differ from sample to sample,
  so they must not take part in a same-document comparison.
  """
  return tuple(ids_line.split('\t')[:2])

# Drain the bins in priority order: test-override samples first, then
# high-quality samples, then the rest. Each sample goes to the test set until
# its quota is met, then to the dev set, then to the training set.
for source_bin in (dump_test, dump_hiqu, dump_loqu):
  for item in source_bin:
    if test_size < args.test_quota:
      target = test
    elif dev_size < args.dev_quota:
      target = dev
    else:
      target = train

    if args.doc_bounds:
      # Keep a document together: a sample follows the previous sample of
      # the same document into whichever set that one landed in, even if
      # that set's quota has been reached meanwhile.
      doc = _doc_key(item[2])

      if test and _doc_key(test[-1][2]) == doc:
        target = test
      elif dev and _doc_key(dev[-1][2]) == doc:
        target = dev
      elif train and _doc_key(train[-1][2]) == doc:
        target = train
      elif target:
        # A new document starts in a non-empty set: mark the boundary
        # with an empty sample (not counted towards the quota).
        target.append(('', '', ''))

    target.append(item)

    # Identity, not equality: `==` would compare the lists' contents.
    if target is test:
      test_size += 1
    elif target is dev:
      dev_size += 1
    else:
      train_size += 1

print('...done!')

# Optionally randomise the order inside each finished split.
if args.shuffle:
  print('Post-shuffling test/dev/train sets...')

  for split in (test, dev, train):
    fixed.shuffle(split)

  print('...done!')

def _split_paths(stub):
  """Return the [source, target, ids] output paths for filename stub `stub`."""
  return ['%s.%s' % (stub, ext) for ext in (src_lang, tgt_lang, 'ids')]

def _confirm_stub(option, stub):
  """Settle on a filename stub whose output files may be written.

  Unless `--force` is set, the user is prompted for as long as any of the
  three output files for the current stub already exists. Typing OVERWRITE
  accepts the current stub; any other non-empty answer becomes the new stub.
  `option` is the command-line flag named in the prompt.

  Returns the final stub and its [source, target, ids] paths.
  """
  paths = _split_paths(stub)

  if not args.force:
    while any(path.isfile(candidate) for candidate in paths):
      answer = input('Using `%s %s\' will cause files to be overwritten. Please input another name, or type OVERWRITE to continue: ' % (option, stub))

      if answer == 'OVERWRITE':
        break

      if not answer:
        # An empty stub would produce files named like `.en`; ask again.
        continue

      stub = answer
      paths = _split_paths(stub)

  return stub, paths

def _write_split(label, stub, paths, samples):
  """Write `samples` to the three files in `paths`, one line per sample each.

  `samples` is a sequence of (source line, target line, ids line) tuples;
  `label` and `stub` are only used in the progress messages.
  """
  src_path, tgt_path, ids_path = paths

  with open(src_path, mode='w', encoding='utf-8') as src_file, \
       open(tgt_path, mode='w', encoding='utf-8') as tgt_file, \
       open(ids_path, mode='w', encoding='utf-8') as ids_file:
    print('Writing %s data to `%s.{%s,%s,%s}\'...' % (label, stub, src_lang, tgt_lang, 'ids'))

    for num_written, (src_line, tgt_line, ids_line) in enumerate(samples, 1):
      src_file.write(src_line + '\n')
      tgt_file.write(tgt_line + '\n')
      ids_file.write(ids_line + '\n')

      if num_written % 1000000 == 0:
        print('%d/%d lines written...' % (num_written, len(samples)))

    print('...done!')

# Each set is confirmed and written before the next one is confirmed, so a
# later prompt also sees the files an earlier set has just produced.
test_set, (test_src_path, test_tgt_path, test_ids_path) = _confirm_stub('--test-set', args.test_set)
_write_split('test', test_set, (test_src_path, test_tgt_path, test_ids_path), test)

dev_set, (dev_src_path, dev_tgt_path, dev_ids_path) = _confirm_stub('--dev-set', args.dev_set)
_write_split('development', dev_set, (dev_src_path, dev_tgt_path, dev_ids_path), dev)

train_set, (train_src_path, train_tgt_path, train_ids_path) = _confirm_stub('--train-set', args.train_set)
_write_split('training', train_set, (train_src_path, train_tgt_path, train_ids_path), train)

