#!python

import re
import os
import sys
import argparse
from sqlport import node
from sqlport.util import pretty_print
from sqlport.logger import Logger

def parse_args():
    """Build the command-line parser and parse the process arguments.

    Two mutually exclusive groups are defined: one choosing where output
    goes (``--outfile`` pattern, ``--outdir`` directory, or ``--replace``,
    which stores ``"#"`` into ``outfile``), and one choosing where the input
    file names come from (positional ``INFILE`` values or ``--file-list``).

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args()``.
    """
    # argparse applies %-formatting to help text, hence the doubled "%".
    outfile_help = """
    output file path pattern with place holders:
    "#" = input file path;
    "%%" = input file path with last file extension removed;
    "%%%%" = input file path with last two file extensions removed;
    ..."""

    cli = argparse.ArgumentParser(description='Ports SQL code to another dialect.')

    # Output destination: at most one of these may be given.
    destination = cli.add_mutually_exclusive_group()
    destination.add_argument('--outfile', '-o', default='-', help=outfile_help)
    destination.add_argument('--outdir', '-d', action='store', help="output base directory")
    destination.add_argument(
        '--replace', '-r',
        action='store_const', const="#", dest="outfile",
        help="replace input file",
    )

    # Input selection: explicit file names or a list read from a file/stdin.
    source = cli.add_mutually_exclusive_group()
    source.add_argument('infile', metavar="INFILE", default=['-'], nargs='*')
    source.add_argument(
        '--file-list', '-f',
        metavar="FILE", const='-', nargs='?',
        help="read file list from file or stdin",
    )

    # Repeatable flags; each occurrence increments the stored count.
    counted_flags = (
        ('--quiet', '-q', "do not output anything"),
        ('--verbose', '-v', "verbose output"),
        ('--debug', '-D', "debugging output"),
    )
    for long_name, short_name, description in counted_flags:
        cli.add_argument(long_name, short_name, default=0, action='count', help=description)

    # Plain on/off switches.
    boolean_flags = (
        ('--parse-tree', '-T', "show parse tree"),
        ('--lex', '-L', "show lexer output"),
        ('--informix', '-i', "generate informix SQL"),
    )
    for long_name, short_name, description in boolean_flags:
        cli.add_argument(long_name, short_name, action='store_true', help=description)

    cli.add_argument('--max-errors', '-e', type=int, default=1000, help="stop after this many errors")

    return cli.parse_args()

# Matches one "#" placeholder or a run of one or more "%" placeholders.
ofile_regex = re.compile(r"#|%+")

def map_outfile(ifile, ofile):
    """Expand the placeholders in an output path pattern.

    Args:
        ifile: path of the input file.
        ofile: output path pattern. ``#`` is replaced by ``ifile``; a run of
            N ``%`` characters is replaced by ``ifile`` with its last N file
            extensions removed.

    Returns:
        ``ofile`` with every placeholder substituted.

    Only extensions of the final path component are stripped, so dots in
    directory names (``dir.d/file``) or a leading ``./`` are left alone. If
    the name has fewer than N extensions, stripping stops once none remain
    instead of producing an empty path.
    """
    def strip_extensions(count):
        stem = ifile
        for _ in range(count):
            stem, ext = os.path.splitext(stem)
            if not ext:
                # Nothing left to remove; keep the remaining stem as is.
                break
        return stem

    def handler(match):
        token = match.group()
        if token == "#":
            return ifile
        # A run of "%": its length is the number of extensions to drop.
        return strip_extensions(len(token))

    return ofile_regex.sub(handler, ofile)

def read_file(infile):
    """Return the full text content of an input source.

    Args:
        infile: a file path, or ``'-'`` to read from standard input.

    Returns:
        Everything read from the source, as a string.
    """
    if infile == '-':
        return sys.stdin.read()
    # Context manager so the handle is closed as soon as the text is read.
    with open(infile) as fh:
        return fh.read()

class ErrorCount:
    """Callable counter that tallies how many times it has been invoked."""

    def __init__(self):
        # Number of calls made to this instance so far.
        self.value = 0

    def __call__(self, obj, t):
        """Register one more call; both arguments are accepted but unused."""
        self.value = self.value + 1

def main():
    """Command-line entry point: process every input file in turn.

    For each input the text is read, the output path is derived from the
    ``--outfile`` pattern (or ``--outdir``), and the text is either lexed
    (``--lex``) or parsed. A parsed tree is pretty-printed with
    ``--parse-tree``; otherwise it is rendered to the output unless
    ``--quiet`` was given.

    Exits with status 1 when parsing aborts with ``TooManyErrors`` or when
    the error callback was invoked at least once.
    """
    args = parse_args()
    error_count = ErrorCount()
    if args.debug:
        Logger.level += args.debug
    # NOTE(review): the engine is imported only after Logger.level has been
    # adjusted; presumably the ordering is intentional -- keep it.
    from sqlport.engine import lex, parse, TooManyErrors
    if args.informix:
        from sqlport.writers import informix
        node.writer = informix.writer
    if args.outdir:
        args.outfile = "{}/#".format(args.outdir)
    seen_outfiles = set()
    if args.file_list:
        # One path per line; blank lines are skipped.
        args.infile = [x for x in read_file(args.file_list).split('\n') if x]
    for infile in args.infile:
        # Read the input before opening the output: with --replace both
        # are the same file and opening for writing truncates it.
        text = read_file(infile)
        outfile = map_outfile(infile, args.outfile)
        if outfile == '-':
            outfh = sys.stdout
        else:
            dirpath = os.path.dirname(outfile)
            if dirpath:
                os.makedirs(dirpath, exist_ok=True)
            # Append when several inputs map to the same output path so
            # later inputs do not overwrite earlier results.
            outfh = open(outfile, 'a' if outfile in seen_outfiles else 'w')
        seen_outfiles.add(outfile)
        try:
            if args.lex:
                lex(text, outfh, args.verbose, onerror=error_count)
            else:
                try:
                    tree = parse(text, outfh, args.verbose, onerror=error_count, maxerrors=args.max_errors)
                except TooManyErrors:
                    # Aborting on too many errors is a failure, so report
                    # a non-zero status like the ordinary error path does.
                    sys.exit(1)
                if args.parse_tree:
                    pretty_print(tree)
                elif not args.quiet:
                    tree.render(outfh)
        finally:
            # Close files opened here (also on sys.exit); never close stdout.
            if outfh is not sys.stdout:
                outfh.close()
    if error_count.value > 0:
        sys.exit(1)

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
