#!/usr/bin/env python3

import sys

from pathlib import Path
import urllib.request

import json

from cvutils import CV 
from cvutils import Corpora 
from cvutils import Segmenter
from cvutils import Validator
from cvutils import Alphabet 
from cvutils import Phonemiser

from cvutils import wikipedia

# Debug switch: when true, the 'validate'/'norm' report also lists the skipped
# sentences and the characters missing from the alphabet (on stderr).
# NOTE(review): currently always enabled; set to False to silence that output.
debug = True

def help():
	"""Print the usage summary for every sub-command to standard error.

	Each line names one mode followed by its positional arguments in square
	brackets.  Note that within this script the name shadows the built-in
	``help``.
	"""
	usage = (
		'covo   dump     [pages-articles.xml.bz2]',
		'       dump-url [locale]',
		# The 'alphabet' mode is implemented by the script, so list it here too.
		'       alphabet [locale]',
		'       segment  [locale]',
		'       norm     [locale]',
		'       phon     [locale]',
		'       opus     [locale]',
		'       filter   [locale]',
		'       getckp   [locale] [path]',
		'       text     [locale] [tsv file1 tsv file2 ...]',
		'       avail    [task]',
		'       missing  [task]',
	)
	for entry in usage:
		print(entry, file=sys.stderr)

# Work out the requested sub-command.  Running with no arguments is treated
# exactly like asking for 'help': print the usage text and exit with -1.
mode = sys.argv[1] if len(sys.argv) != 1 else 'help'

if mode == 'help':
	help()
	sys.exit(-1)

if mode == 'dump':
	# A dump name of '-' is replaced by /dev/stdin so the dump can be piped in.
	dump_name = '/dev/stdin' if sys.argv[2] == '-' else sys.argv[2]
	wikipedia.process(dump_name)

if mode == 'dump-url':
	# Print whatever Corpora reports as the dump URL for the given locale.
	dump_locale = sys.argv[2]
	print(Corpora(dump_locale).dump_url())

if mode == 'alphabet':
	# Print the alphabet registered for the given locale.
	locale = sys.argv[2]
	print(Alphabet(locale).get_alphabet())
		
if mode == 'segment':
	# Read text from stdin line by line and print every segment the
	# locale's Segmenter yields, one per output line.
	locale = sys.argv[2]
	s = Segmenter(locale)
	for line in sys.stdin:
		for sentence in s.segment(line):
			print(sentence)

if mode == 'validate' or mode == 'norm':
	locale = sys.argv[2]

	def cleanup(skipped, chars, alphabet, count_valid, total):
		"""Write the end-of-run report to standard error.

		skipped     -- sentences the validator rejected
		chars       -- every character seen in a rejected sentence
		alphabet    -- set of characters in the locale's alphabet
		count_valid -- number of accepted lines
		total       -- number of lines read

		With ``debug`` on, the non-blank rejected sentences are listed first,
		followed by each character that is in ``chars`` but not in
		``alphabet`` (hex code point, tab, the character itself).  The last
		line is always "valid/total (percentage)".
		"""
		if debug:
			print('', file=sys.stderr)
			print('\n'.join([s for s in skipped if s.strip() != ""]), file=sys.stderr)
			print('', file=sys.stderr)
			for c in sorted(chars - alphabet):
				print('%04x\t%s' % (ord(c), c), file=sys.stderr)
		# Empty input must not crash the report: there is no ratio to compute
		# when nothing was read, so report 0%.
		percent = (count_valid / total) * 100.0 if total else 0.0
		print('%d/%d (%.2f%%)' % (count_valid, total, percent), file=sys.stderr)

	v = Validator(locale)
	a = Alphabet(locale)
	chars = set()
	alphabet = set(a.get_alphabet())
	skipped = []
	count_valid = 0
	total = 0
	try:
		for line in sys.stdin:
			(valid, sent) = v.normalise(line)
			if valid:
				count_valid += 1
				print(sent)
			else:
				skipped.append(sent)
				chars.update(sent)
			total += 1
	except KeyboardInterrupt:
		# Ctrl-C only stops reading; the report for the lines handled so far
		# is still printed below, exactly once.
		pass

	cleanup(skipped, chars, alphabet, count_valid, total)

if mode in ('phon', 'phonemise'):
	# Phonemise stdin line by line: each space-separated token goes through
	# the locale's Phonemiser and the non-empty results are re-joined.
	locale = sys.argv[2]
	p = Phonemiser(locale)
	for line in sys.stdin:
		phonemised = (p.phonemise(word) for word in line.split(' '))
		print(' '.join(filter(None, phonemised)))


if mode == 'opus':
	# List the entries returned by Corpora.opus(), largest first.
	locale = sys.argv[2]
	crps = Corpora(locale).opus()
	if crps:
		crps.sort(reverse=True)
		for entry in crps:
			# Skip entries with a negative first field and entries whose
			# second field contains the locale code.
			if entry[0] < 0 or locale in entry[1]:
				continue
			print('%s\t%s' % (entry[3][1], entry[3][0]))

if mode == 'text':
	import csv

	# Print the normalised text of every valid sentence found in the third
	# column of the given files.
	# NOTE(review): the usage text calls these "tsv" files, but they are parsed
	# as comma-separated — confirm the expected input format.
	locale = sys.argv[2]
	v = Validator(locale)
	# TODO: do this with pathlib to allow globbing
	for fn in sys.argv[3:]:
		# newline='' is what the csv module asks for (otherwise newlines inside
		# quoted fields are mis-parsed); the encoding is pinned to UTF-8 so the
		# result does not depend on the platform default.
		with open(fn, newline='', encoding='utf-8') as csv_file:
			csv_reader = csv.reader(csv_file, delimiter=',')
			# Skip the header row; the default keeps an empty file from
			# raising StopIteration.
			next(csv_reader, None)
			for row in csv_reader:
				if not row:
					continue
				(valid, sent) = v.normalise(row[2])
				if valid:
					print(sent)
		
if mode == 'filter':
	# Filter stdin to stdout through Corpora.filter().
	locale = sys.argv[2]
	# 'umbral' (Spanish for threshold) defaults to 1 and can be overridden
	# by an optional integer given as the fourth command-line argument.
	umbral = 1
	if len(sys.argv) == 4:
		umbral = int(sys.argv[3])
	c = Corpora(locale)
	# Pass the parsed value on; it used to be ignored in favour of a literal 1.
	c.filter(sys.stdin, sys.stdout, umbral=umbral)

if mode == 'getckp':
	import shutil
	import urllib.error

	# Download the best dev checkpoint for a locale into `path`
	# (default: ./source-checkpoints, created if missing).
	locale = sys.argv[2]
	path = 'source-checkpoints'
	if len(sys.argv) == 4:
		path = sys.argv[3]
	Path(path).mkdir(parents=True, exist_ok=True)

	base_url = 'https://tepozcatl.omnilingo.cc/' + locale + '/checkpoints/'
	try:
		with urllib.request.urlopen(base_url + 'best_dev_checkpoint') as g:
			txt = g.read().strip()
	except urllib.error.HTTPError:
		print('[Checkpoint not found]', file=sys.stderr)
		sys.exit(-1)
	# read() returns bytes, so test for emptiness instead of comparing with
	# the str '' (which is never equal to a bytes object).
	if not txt:
		print('[Checkpoint not found]', file=sys.stderr)
		sys.exit(-1)
	# The checkpoint name is the last path component of the quoted value on
	# the second line of the index file.
	row = txt.decode('utf-8').split('\n')
	ref = row[1].split('"')[1].split('/')[-1]
	print('[Checkpoint found]  %s' % ref)

	for fn in ['.data-00000-of-00001', '.index', '.meta']:
		print('[Downloading] %s ' % (ref + fn))
		# Stream to disk rather than holding the whole file in memory, and
		# make sure both the response and the output file get closed.
		with urllib.request.urlopen(base_url + ref + fn) as fd, \
				open(path + '/' + ref + fn, 'wb') as op:
			shutil.copyfileobj(fd, op)

	print('[Done] Your checkpoint is in %s/' % path)
	
if mode == 'check' or mode == 'missing':
	# Report which contributable Common Voice locales have no alphabet,
	# validator, phonemiser or segmenter yet.  An optional task argument
	# restricts the report to the section whose name it starts
	# ('al', 'val', 'pho' or 'seg').
	# Locale codes filtered out of every list printed below.
	exclude = ['zh-TW', 'zh-CN', 'zh-HK', 'ja']
	check = ''
	if len(sys.argv) == 3:
		check = sys.argv[2]
	url = 'https://raw.githubusercontent.com/common-voice/common-voice/main/locales/contributable.json'
	# Close the HTTP response as soon as the body has been read.
	with urllib.request.urlopen(url) as g:
		txt = g.read().strip()
	contributable = set(json.loads(txt.decode('utf-8')))
	cv = CV()
	print('Missing: ')
	# (argument prefix, printed label, callable listing the supported locales)
	sections = (
		('al', ' Alphabets:', cv.alphabets),
		('val', ' Validators:', cv.validators),
		('pho', ' Phonemisers:', cv.phonemisers),
		('seg', ' Segmenters:', cv.segmenters),
	)
	for prefix, label, available in sections:
		if check == '' or check.startswith(prefix):
			missing = sorted(contributable - set(available()))
			print(label, ' '.join([code for code in missing if code not in exclude]))

if mode == 'avail':
	# List the locales that have an alphabet, validator, phonemiser or
	# segmenter.  An optional task argument restricts the listing to the
	# section whose name it starts ('al', 'val', 'pho' or 'seg').
	check = sys.argv[2] if len(sys.argv) == 3 else ''
	cv = CV()
	print('Available:')
	# (argument prefix, printed label, callable listing the supported locales)
	sections = (
		('al', ' Alphabets:', cv.alphabets),
		('val', ' Validators:', cv.validators),
		('pho', ' Phonemisers:', cv.phonemisers),
		('seg', ' Segmenters:', cv.segmenters),
	)
	for prefix, label, lister in sections:
		if check == '' or check.startswith(prefix):
			codes = lister()
			codes.sort()
			print(label, ' '.join(codes))

