#!/usr/bin/env python3

import sys

from pathlib import Path
import urllib.request

import json, os
import ssl

from cvutils import CV 
from cvutils import Corpora 
from cvutils import Segmenter
from cvutils import Validator
from cvutils import Alphabet 
from cvutils import Phonemiser
from cvutils import Tokeniser

from cvutils import wikipedia

debug = False

def help():
	"""Print the covo usage summary to standard error, one sub-command per line.

	The name shadows the builtin ``help``; it is kept because the script
	calls this function under that name.
	"""
	usage = (
		'covo   dump     [pages-articles.xml.bz2]',
		'       dump-url [locale]',
		'       alphabet [locale]',
		'       segment  [locale]',
		'       tokenise [locale]',
		'       norm     [locale]',
		'       phon     [locale]',
		'       opus     [locale]',
		'       gb       [list|update] [locale]',
		'       filter   [locale]',
		'       clean    [locale] [tsv_file]',
		'       getckp   [locale] [path]',
		'       text     [locale] [tsv_file1 tsv_file2 ...]',
		'       avail    [task]',
		'       missing  [task]',
		# The export mode reads the exporter type before the locale.
		'       export   [coqui|nemo] [locale] [tsv_dir] [audio_dir]',
	)
	for line in usage:
		print(line, file=sys.stderr)

# No sub-command given: show the usage text and exit with a failure status.
if len(sys.argv) == 1:
	help()
	sys.exit(-1)

# The literal word 'debug' anywhere on the command line enables debug output.
if 'debug' in sys.argv:
	debug = True

# The first argument selects the sub-command dispatched below.
mode = sys.argv[1]

# Check the mode here

if mode in ('help', '--help'):
	# Explicit request for the usage text; exits with status -1.
	help()
	sys.exit(-1)

elif mode == 'dump':
	dump_name = sys.argv[2]
	if dump_name == '-':
		dump_name = '/dev/stdin'
	wikipedia.process(dump_name)

elif mode == 'dump-url':
	dump_locale = sys.argv[2]
	c = Corpora(dump_locale)
	print(c.dump_url())

elif mode == 'alphabet':
	locale = sys.argv[2]
	a = Alphabet(locale)
	print('\n'.join(a.get_alphabet()))
	print()
		
elif mode == 'tokenise':
	locale = sys.argv[2]
	t = Tokeniser(locale)
	line = sys.stdin.readline()
	while line:
		print(' '.join(t.tokenise(line)).strip())
		line = sys.stdin.readline()

elif mode == 'segment':
	locale = sys.argv[2]
	s = Segmenter(locale)
	line = sys.stdin.readline()
	while line:
		for sentence in s.segment(line):
			print(sentence)	
		
		line = sys.stdin.readline()

elif mode == 'validate' or mode == 'norm' or mode == 'valid':
	locale = sys.argv[2]
	def cleanup(skipped, chars, alphabet, count_valid, total):
		if debug:
			print('',file=sys.stderr)
			print('----', file=sys.stderr)
			print('\n'.join([s for s in skipped if not s.strip() == ""]), file=sys.stderr)
			print('',file=sys.stderr)
			#missing = list(chars - alphabet)
			missing = [(v, k) for (k, v) in chars.items()]
			missing.sort(reverse=True)
			a = list(alphabet)
			a.sort()
			print('Alphabet:', [(c, '%04x' % ord(c)) for c in a] , file=sys.stderr)
			print('Unhandled characters:\n----', file=sys.stderr)
			for (f, c) in missing:
				if c not in alphabet:
					print('%d\t%04x\t%s' % (f, ord(c), c), file=sys.stderr)			
		print('%d/%d (%.2f%%)' % (count_valid, total, (count_valid/total)*100.0),file=sys.stderr)
	

	v = Validator(locale)
	a = Alphabet(locale)
	line = sys.stdin.readline()
	chars = {}
	alphabet = set(a.get_alphabet())
	skipped = []
	count_valid = 0
	total = 0
	try: 
		while line:
			(valid, sent) = v.normalise(line)
			if valid:
				count_valid += 1
				print(sent)
#			elif not valid and debug:
#				print(sent + '\t---')
			else:
				skipped.append(sent)
				for c in sent:
					if c in alphabet:
						continue
					if c not in chars:
						chars[c] = 0
					chars[c] += 1
			total += 1	
			line = sys.stdin.readline()
	except KeyboardInterrupt:
		cleanup(skipped, chars, alphabet, count_valid, total)

	cleanup(skipped, chars, alphabet, count_valid, total)

elif mode == 'phon' or mode == 'phonemise':
	locale = sys.argv[2]
	p = Phonemiser(locale)
	for line in sys.stdin:
		phons = [p.phonemise(w) for w in line.split(' ')]
		print(' '.join([p for p in phons if p]))

elif mode == 'tran' or mode == 'translit' or mode == 'transliterate':
	from cvutils import Transliterator
	locale = sys.argv[2]
	t = Transliterator(locale,normalise=False)
	for line in sys.stdin:
		trans = [t.transliterate(w) for w in line.split(' ')]
		print(' '.join([t for t in trans if t]))



elif mode == 'opus':
	locale = sys.argv[2]
	c = Corpora(locale)
	crps = c.opus()
	if crps:
			
		crps.sort(reverse=True)
		for line in crps:
			if line[0] < 0:
				continue
			if locale in line[1]:
				continue
			print('%s\t%s' % (line[3][1], line[3][0]))

elif mode == 'text':
	import csv

	locale = sys.argv[2]
	v = Validator(locale)
	# Do this with pathlib to allow globbing
	for fn in sys.argv[3:]:
		with open(fn) as csv_file:
			csv_reader = csv.reader(csv_file, delimiter=',')
			next(csv_reader)
			for row in csv_reader:
				if not row:
					continue
				if len(row) != 3:
					continue
				(valid, sent) = v.normalise(row[2])
				if valid:
					print(sent)		

elif mode == 'clean':
	"""Cleans an existing .csv file"""
	import csv

	locale = sys.argv[2]
	v = Validator(locale)
	# Do this with pathlib to allow globbing
	total = 0
	valid_sents = 0
	discarded = []
	malformed = []
	with open(sys.argv[3]) as csv_file:
		csv_reader = csv.reader(csv_file, delimiter=',')
		csv_writer = csv.writer(sys.stdout, delimiter=',')
		csv_writer.writerow(next(csv_reader))
		for row in csv_reader:
			if not row:
				continue
			if len(row) != 3:
				total += 1
				malformed.append(total+1)
				continue
			(valid, sent) = v.normalise(row[2])
			if valid:
				valid_sents += 1
				row[2] = sent
				csv_writer.writerow(row)
			else:
				discarded.append(row)
			total += 1

	print('%d/%d (%.2f%%)' % (valid_sents, total, (valid_sents/total)*100), file=sys.stderr)
	missing = {}
	if debug:
		
		for row in discarded:
			(skipped, unalphabetic) = v.check(row[2])
			for c in unalphabetic:
				if c not in missing:
					missing[c] = 0
				missing[c] += 1
		missing = [(v, k) for (k, v) in missing.items()]
		missing.sort(reverse=True)
		print('Malformed rows:\n----', file=sys.stderr)
		print(' '.join([str(i) for i in malformed]), file=sys.stderr)
		print('Unhandled characters:\n----', file=sys.stderr)
		for (f, c) in missing:
			print('%d\t%04x\t%s' % (f, ord(c), c), file=sys.stderr)			


elif mode == 'filter':
	locale = sys.argv[2]
	umbral = 1
	if len(sys.argv) == 4:
		umbral = int(sys.argv[3])
	c = Corpora(locale)
	c.filter(sys.stdin, sys.stdout, umbral=1)

elif mode == 'check' or mode == 'missing':
	exclude = ['zh-TW', 'zh-CN', 'zh-HK', 'ja', 'nan-tw', 'tok', 'yue']
	check = ''
	if len(sys.argv) == 3:
		check = sys.argv[2]	
	g = urllib.request.urlopen('https://commonvoice.mozilla.org/api/v1/stats/languages')
	txt = g.read().strip()
	j_contributable = json.loads(txt.decode('utf-8'))
	contributable = set([i["locale"] for i in j_contributable if i["is_contributable"] == 1])
#	print(contributable)
	cv = CV()
	print('Missing: ')
	if check[:2] == 'al' or check == '':
		missing_alphabets = list(contributable - set(cv.alphabets()))
		missing_alphabets.sort()
		print(' Alphabets:', ' '.join([code for code in missing_alphabets if not code in exclude]))
	if check[:3] == 'val' or check == '':
		missing_validators = list(contributable - set(cv.validators()))
		missing_validators.sort()
		print(' Validators:', ' '.join([code for code in missing_validators if not code in exclude]))
	if check[:3] == 'pho' or check == '':
		missing_phonemisers = list(contributable - set(cv.phonemisers()))
		missing_phonemisers.sort()
		print(' Phonemisers:', ' '.join([code for code in missing_phonemisers if not code in exclude]))
	if check[:3] == 'seg' or check == '':
		missing_segmenters = list(contributable - set(cv.segmenters()))
		missing_segmenters.sort()
		print(' Segmenters:', ' '.join([code for code in missing_segmenters if not code in exclude]))

elif mode == 'avail':
	check = ''
	if len(sys.argv) == 3:
		check = sys.argv[2]	
	cv = CV()
	print('Available:')
	if check[:2] == 'al' or check == '':
		alphabets = cv.alphabets()
		alphabets.sort()
		print(' Alphabets:', ' '.join([code for code in alphabets]))
	if check[:3] == 'val' or check == '':
		validators = cv.validators()
		validators.sort()
		print(' Validators:', ' '.join([code for code in validators]))
	if check[:3] == 'pho' or check == '':
		phonemisers = cv.phonemisers()
		phonemisers.sort()
		print(' Phonemisers:', ' '.join([code for code in phonemisers]))
	if check[:3] == 'seg' or check == '':
		segmenters = cv.segmenters()
		segmenters.sort()
		print(' Segmenters:', ' '.join([code for code in segmenters]))

elif mode == 'gutenberg' or mode == 'guten' or mode == 'gb':
	cmd = sys.argv[2]
	HOME = os.getenv('HOME')
	covo_cache_dir = HOME + "/.covo/"

	from cvutils import gutenberg

	def gb_update(covo_cache_dir):
		if not os.path.isdir(covo_cache_dir):
			os.mkdir(covo_cache_dir, mode=0o755)
		fd = open(covo_cache_dir + "GUTINDEX.ALL", 'w')
		g = urllib.request.urlopen('https://www.gutenberg.org/dirs/GUTINDEX.ALL')
		txt = g.read().strip()
		fd.write(txt.decode('utf-8'))
		fd.close()
		return len(txt)

	if cmd == 'update':
		res = gb_update(covo_cache_dir)
		print('Downloaded index, %d bytes' % (res))
		sys.exit(0)
	elif cmd == 'list':
		locale = sys.argv[3]
		index_path = covo_cache_dir + "GUTINDEX.ALL"
		if not os.path.isfile(index_path):
			res = gb_update(covo_cache_dir)
			print('Downloaded index, %d bytes' % (res))
		gb = gutenberg.Gutenberg(index_path)
		ids = gb.get_catalogue(locale)
		for idx in ids:
			print('https://www.gutenberg.org/files/%s/%s-0.txt' % (idx, idx))
	else:
		print('Operation not supported.', file=sys.stderr)
		

elif mode == 'getckp' or mode == 'get-ckp':
	locale = sys.argv[2]
	path = 'source-checkpoints'
	if len(sys.argv) ==4:
		path = sys.argv[3]
	Path(path).mkdir(parents=True, exist_ok=True)


	ctx = ssl.create_default_context()
	ctx.check_hostname = False
	ctx.verify_mode = ssl.CERT_NONE

	try:
		g = urllib.request.urlopen('https://itml.cl.indiana.edu/models/' + locale + '/checkpoints/best_dev_checkpoint', context=ctx)	
	except urllib.error.HTTPError:
		print('[Checkpoint not found]', file=sys.stderr)
		sys.exit(-1)	
	txt = g.read().strip()
	if txt == '':
		print('[Checkpoint not found]', file=sys.stderr)
		sys.exit(-1)	
	row = txt.decode('utf-8').split('\n')
	ref = row[1].split('"')[1].split('/')[-1]
	print('[Checkpoint found]  %s' % ref)

	for fn in ['.data-00000-of-00001', '.index', '.meta']:
		print('[Downloading] %s ' % (ref + fn))
		fd = urllib.request.urlopen('https://itml.cl.indiana.edu/models/' + locale + '/checkpoints/' + ref + fn, context=ctx)
		op = open(path + '/' + ref + fn, 'wb')
		op.write(fd.read())
		op.close()

	print('[Done] Your checkpoint is in %s' % path + '/')

elif mode == 'export':
	tipus = sys.argv[2]
	locale = sys.argv[3]
	input_dir = sys.argv[4]
	output_dir = sys.argv[4] + '/clips/'
	if len(sys.argv) == 6:
		output_dir = sys.argv[5]

	a = Alphabet(locale)
	v = Validator(locale)

	if tipus == 'coqui':

		from cvutils.exporters.coqui import CoquiExporter

		cx = CoquiExporter(v, a)

		cx.process(input_dir, output_dir)
	elif tipus == 'nemo':

		from cvutils.exporters.nemo import NemoExporter

		cx = NemoExporter(v, a)

		cx.process(input_dir, output_dir)

else:
	print('Mode ' + sys.argv[1] + ' not found. Maybe try: covo help', file=sys.stderr)
	sys.exit(-1)
