#!python

import argparse
import json
import keyword
import os
import os.path
import sys
import re
import functools


class CodeGen:
    """Write Python source declaring ``sgqlc.types`` classes for a schema.

    The generator walks the ``types`` list of a GraphQL introspection
    result and emits, through a ``writer`` callable, one declaration per
    type: first scalars and enumerations, then input objects, then
    interfaces and objects (interfaces before their implementors) and
    finally unions.

    :param schema_name: Python identifier used for the generated
      ``sgqlc.types.Schema()`` instance and for every ``__schema__``
      attribute.
    :param schema: introspection ``__schema`` mapping; must contain the
      ``types`` key.
    :param writer: callable receiving each chunk of generated text
      (for instance the ``write`` method of an open file).
    """

    def __init__(self, schema_name, schema, writer):
        self.schema_name = schema_name
        self.schema = schema
        self.types = schema['types']
        self.query_type = self.get_path('queryType', 'name')
        self.mutation_type = self.get_path('mutationType', 'name')
        self.subscription_type = self.get_path('subscriptionType', 'name')
        self.directives = schema.get('directives', [])
        self.analyze()
        self.writer = writer
        # Names already emitted; those can be referenced by identifier,
        # anything else is referenced by its quoted name.
        self.written_types = set()

    def get_path(self, *path, fallback=None):
        """Follow ``path`` keys through nested mappings of the schema.

        The walk stops as soon as a falsy value is reached.

        :return: the value reached, or ``fallback`` if that value is
          ``None``.
        """
        d = self.schema
        for key in path:
            if not d:
                break
            d = d.get(key)

        if d is None:
            return fallback
        return d

    builtin_types = ('Int', 'Float', 'String', 'Boolean', 'ID')
    datetime_types = ('DateTime', 'Date', 'Time')
    relay_types = ('Node', 'PageInfo')

    def analyze(self):
        """Scan type names to decide which optional imports are needed.

        Sets ``uses_datetime`` and ``uses_relay``; stops scanning once
        both are known to be true.
        """
        self.uses_datetime = False
        self.uses_relay = False
        for t in self.types:
            name = t['name']
            if name in self.datetime_types:
                self.uses_datetime = True
                if self.uses_relay:
                    break
            elif name in self.relay_types:
                self.uses_relay = True
                if self.uses_datetime:
                    break

    def write(self):
        """Emit the whole module: header followed by every type."""
        self.write_header()
        self.write_types()

    def write_banner(self, text):
        """Emit a three-line comment banner containing ``text``."""
        bar = '#' * 72
        self.writer('''
%(bar)s
# %(text)s
%(bar)s
''' % {
            'bar': bar,
            'text': text,
        })

    @staticmethod
    def has_iface(ifaces, name):
        """Return whether any entry of ``ifaces`` is named ``name``."""
        return any(name == iface['name'] for iface in ifaces)

    # fields without interfaces first, then order the interface
    # implementor after the interface declaration
    @staticmethod
    def depend_cmp(a, b):
        """Compare two introspection types by interface dependency.

        Types without interfaces sort before types with interfaces.
        When both have interfaces, a type that lists the other among its
        interfaces sorts after it; unrelated types compare equal.

        :return: negative, zero or positive, as expected by
          ``functools.cmp_to_key``.
        """
        # A missing or null 'interfaces' entry means "no interfaces".
        a_ifaces = a.get('interfaces')
        b_ifaces = b.get('interfaces')
        if not a_ifaces and b_ifaces:
            return -1
        elif a_ifaces and not b_ifaces:
            return 1
        elif not a_ifaces and not b_ifaces:
            return 0

        has_iface = CodeGen.has_iface
        if len(a_ifaces) < len(b_ifaces):
            if has_iface(a_ifaces, b['name']):
                return 1
            elif has_iface(b_ifaces, a['name']):
                return -1
            else:
                return 0
        else:
            if has_iface(b_ifaces, a['name']):
                return -1
            elif has_iface(a_ifaces, b['name']):
                return 1
            else:
                return 0

    @staticmethod
    def get_depend_sort_key():
        """Return a sort key built from :meth:`depend_cmp`."""
        return functools.cmp_to_key(CodeGen.depend_cmp)

    @staticmethod
    def _write_kinds(todo, mapper):
        """Write each type of ``todo`` whose kind is in ``mapper``.

        :return: list with the types that were not handled, in their
          original order.
        """
        remaining = []
        for t in todo:
            mapped = mapper.get(t['kind'])
            if mapped is None:
                remaining.append(t)
            else:
                mapped(t)
        return remaining

    def write_types(self):
        """Emit every type, grouped by kind in dependency-safe order."""
        mapper_basic = {
            'ENUM': self.write_type_enum,
            'SCALAR': self.write_type_scalar,
        }
        mapper_input_containers = {
            'INPUT_OBJECT': self.write_type_input_object,
        }
        mapper_output_containers = {
            'INTERFACE': self.write_type_interface,
            'OBJECT': self.write_type_object,
        }
        mapper_post = {
            'UNION': self.write_type_union,
        }

        self.write_banner('Scalars and Enumerations')
        remaining = self._write_kinds(self.types, mapper_basic)

        self.write_banner('Input Objects')
        remaining = self._write_kinds(remaining, mapper_input_containers)

        # Interfaces must be declared before the objects using them as
        # base classes.
        remaining.sort(key=self.get_depend_sort_key())
        self.write_banner('Output Objects and Interfaces')
        remaining = self._write_kinds(remaining, mapper_output_containers)

        self.write_banner('Unions')
        remaining = self._write_kinds(remaining, mapper_post)

        assert not remaining, 'unhandled type kinds: %r' % (
            sorted({t['kind'] for t in remaining}),)

    builtin_enum_names = ('__TypeKind', '__DirectiveLocation')

    def write_type_enum(self, t):
        """Emit an ``sgqlc.types.Enum`` subclass, skipping builtin enums."""
        name = t['name']
        if name in self.builtin_enum_names:
            return
        self.writer('''\
class %(name)s(sgqlc.types.Enum):
    __schema__ = %(schema_name)s
    __choices__ = %(choices)r


''' % {
            'name': name,
            'schema_name': self.schema_name,
            'choices': tuple(v['name'] for v in t['enumValues']),
        })
        self.written_types.add(name)

    re_camel_case_words = re.compile('([^A-Z]+|[A-Z]+[^A-Z]*)')

    @classmethod
    def graphql_to_python(cls, name):
        """Convert a camelCase GraphQL name to a snake_case Python name.

        Python keywords get a trailing underscore appended.
        """
        words = cls.re_camel_case_words.findall(name)
        name = '_'.join(w.lower() for w in words)
        if keyword.iskeyword(name):
            return name + '_'
        return name

    def get_type_ref(self, t):
        """Return the Python expression referencing introspection type ``t``.

        ``NON_NULL`` and ``LIST`` wrappers are unwrapped recursively.
        Named types already written are referenced by identifier, the
        others by their quoted name.
        """
        kind = t['kind']
        if kind == 'NON_NULL':
            of_type = t['ofType']
            return 'sgqlc.types.non_null(%s)' % self.get_type_ref(of_type)
        elif kind == 'LIST':
            of_type = t['ofType']
            return 'sgqlc.types.list_of(%s)' % self.get_type_ref(of_type)

        name = t['name']
        if name in self.written_types:
            return name
        else:
            return repr(name)

    def write_field_input(self, field):
        """Emit one field declaration of an input object."""
        name = field['name']
        tref = self.get_type_ref(field['type'])
        self.writer('''\
    %(py_name)s = sgqlc.types.Field(%(type)s, graphql_name=%(gql_name)r)
''' % {
            'py_name': self.graphql_to_python(name),
            'gql_name': name,
            'type': tref,
        })

    def write_arg(self, arg):
        """Emit one ``(name, sgqlc.types.Arg(...))`` entry of a field.

        A default value starting with ``$`` becomes a
        ``sgqlc.types.Variable``; any other default is decoded as JSON
        when possible (kept as the raw string otherwise) and wrapped in
        a call to the argument type reference.
        """
        name = arg['name']
        tref = self.get_type_ref(arg['type'])
        # A missing 'defaultValue' entry is handled like a null one.
        defval = arg.get('defaultValue')
        if defval:
            if defval.startswith('$'):
                defval = 'sgqlc.types.Variable(%r)' % defval[1:]
            else:
                try:
                    defval = json.loads(defval)
                except json.JSONDecodeError:
                    # Not JSON (e.g. an enum value): keep the raw text.
                    pass
                defval = '%s(%r)' % (tref, defval)

        self.writer('''\
        (%(py_name)r, sgqlc.types.Arg(%(type)s, graphql_name=%(gql_name)r, \
default=%(default)s)),
''' % {
            'py_name': self.graphql_to_python(name),
            'gql_name': name,
            'type': tref,
            'default': defval,
        })

    def write_field_output(self, field):
        """Emit one field declaration of an object or interface.

        Fields with arguments additionally get an
        ``args=sgqlc.types.ArgDict(...)`` parameter.
        """
        name = field['name']
        tref = self.get_type_ref(field['type'])
        self.writer('''\
    %(py_name)s = sgqlc.types.Field(%(type)s, graphql_name=%(gql_name)r\
''' % {
            'py_name': self.graphql_to_python(name),
            'gql_name': name,
            'type': tref,
        })
        args = field['args']
        if args:
            self.writer(', args=sgqlc.types.ArgDict((\n')
            for a in args:
                self.write_arg(a)
            self.writer('))\n    ')

        self.writer(')\n')

    def write_type_input_object(self, t):
        """Emit an ``sgqlc.types.Input`` subclass with its fields."""
        name = t['name']
        self.writer('''\
class %(name)s(sgqlc.types.Input):
    __schema__ = %(schema_name)s
''' % {
            'name': name,
            'schema_name': self.schema_name,
        })
        for field in t['inputFields']:
            self.write_field_input(field)
        self.writer('\n\n')
        self.written_types.add(name)

    def write_type_interface(self, t):
        """Emit an ``sgqlc.types.Interface`` subclass with its fields."""
        name = t['name']
        self.writer('''\
class %(name)s(sgqlc.types.Interface):
    __schema__ = %(schema_name)s
''' % {
            'name': name,
            'schema_name': self.schema_name,
        })
        for field in t['fields']:
            self.write_field_output(field)
        self.writer('\n\n')
        self.written_types.add(name)

    builtin_object_names = (
        '__Schema',
        '__Type',
        '__Field',
        '__Directive',
        '__EnumValue',
        '__InputValue',
    )

    def write_type_object(self, t):
        """Emit an object type class, skipping builtin introspection types.

        The first base is ``sgqlc.types.relay.Connection`` when the
        schema uses relay types and the name ends with ``Connection``,
        ``sgqlc.types.Type`` otherwise; implemented interfaces follow
        and must have been written already.
        """
        name = t['name']
        if name in self.builtin_object_names:
            return
        if self.uses_relay and name.endswith('Connection'):
            bases = ['sgqlc.types.relay.Connection']
        else:
            bases = ['sgqlc.types.Type']

        for iface in (t.get('interfaces') or ()):
            assert iface['name'] in self.written_types, iface['name']
            bases.append(iface['name'])

        self.writer('''\
class %(name)s(%(bases)s):
    __schema__ = %(schema_name)s
''' % {
            'name': name,
            'schema_name': self.schema_name,
            'bases': ', '.join(bases),
        })
        for field in t['fields']:
            self.write_field_output(field)
        self.writer('\n\n')
        self.written_types.add(name)

    def write_type_scalar(self, t):
        """Emit a scalar: an alias for known ones, a new class otherwise."""
        name = t['name']
        if name in self.builtin_types:
            self.writer('%(name)s = sgqlc.types.%(name)s' % t)
        elif name in self.datetime_types:
            self.writer('%(name)s = sgqlc.types.datetime.%(name)s' % t)
        else:
            self.writer('''\
class %(name)s(sgqlc.types.Scalar):
    __schema__ = %(schema_name)s
''' % {
                'name': name,
                'schema_name': self.schema_name,
            })
        self.writer('\n\n')
        self.written_types.add(name)

    def write_type_union(self, t):
        """Emit an ``sgqlc.types.Union`` subclass listing its members."""
        name = t['name']
        members = [v['name'] for v in (t.get('possibleTypes') or ())]
        types = ', '.join(members)
        if len(members) == 1:
            # "(Foo)" is just a parenthesized name; a one-element tuple
            # needs the trailing comma.
            types += ','
        self.writer('''\
class %(name)s(sgqlc.types.Union):
    __schema__ = %(schema_name)s
    __types__ = (%(types)s)


''' % {
            'name': name,
            'schema_name': self.schema_name,
            'types': types,
        })
        self.written_types.add(name)

    def write_header(self):
        """Emit the imports, the schema instance and the relay fixup."""
        self.writer('import sgqlc.types\n')
        self.write_datetime_import()
        self.write_relay_import()
        self.writer('\n\n%s = sgqlc.types.Schema()\n\n\n' % self.schema_name)
        self.write_relay_fixup()

    def write_datetime_import(self):
        """Emit the datetime import if the schema uses datetime types."""
        if not self.uses_datetime:
            return
        self.writer('import sgqlc.types.datetime\n')

    def write_relay_import(self):
        """Emit the relay import if the schema uses relay types."""
        if not self.uses_relay:
            return
        self.writer('import sgqlc.types.relay\n')

    def write_relay_fixup(self):
        """Emit statements removing relay's Node/PageInfo from the schema.

        Only written when the schema uses relay types, so the generated
        module can declare its own versions of them.
        """
        if not self.uses_relay:
            return
        self.writer('''\
# Unexport Node/PageInfo, let schema re-declare them
%(schema_name)s -= sgqlc.types.relay.Node
%(schema_name)s -= sgqlc.types.relay.PageInfo


''' % {
            'schema_name': self.schema_name,
        })


def get_basename_noext(path):
    """Return the final component of ``path`` without its last extension."""
    filename = os.path.basename(path)
    stem, _extension = os.path.splitext(filename)
    return stem


# Command-line flow: parse arguments, pick a schema name, load and
# validate the introspection JSON, then generate the Python module.
ap = argparse.ArgumentParser(
    description='Generate sgqlc types using GraphQL introspection data',
)

# Positional input/output files (both optional) and naming option
ap.add_argument('schema.json', type=argparse.FileType('r'), nargs='?',
                help=('The input schema as JSON file. '
                      'Usually the output from introspection query.'),
                default=sys.stdin)

ap.add_argument('schema.py', type=argparse.FileType('w'), nargs='?',
                help=('The output schema as Python file using sgqlc.types. '
                      'Defaults to the input schema name with .py extension.'),
                default=None)

ap.add_argument('--schema-name', '-s',
                help=('The schema name to use. '
                      'Defaults to output (or input) basename without '
                      'extension and invalid python identifiers replaced '
                      ' with "_".'),
                default=None)

args = vars(ap.parse_args())  # vars: schema.json and schema.py

in_file = args['schema.json']
out_file = args['schema.py']

in_fname = args['schema.json'].name

# Schema name precedence: explicit option, output file name, input file
# name, then a fixed fallback when only standard streams are used.
schema_name = args['schema_name']
if not schema_name:
    if out_file and out_file.name != '<stdout>':
        schema_name = get_basename_noext(out_file.name)
    elif in_fname != '<stdin>':
        schema_name = get_basename_noext(in_fname)
    else:
        schema_name = 'generated_schema'


# Turn the name into a valid Python identifier.
schema_name = re.sub('[^A-Za-z0-9_]+', '_', schema_name)
schema_name = re.sub('(^_+|_+$)', '', schema_name)
if not schema_name:
    # Nothing usable was left (e.g. a file named "__.json").
    schema_name = 'generated_schema'
elif re.match('^[0-9]', schema_name):
    schema_name = '_' + schema_name
elif keyword.iskeyword(schema_name):
    # "class = sgqlc.types.Schema()" would not be valid Python.
    schema_name += '_'

# Load and validate the input before creating the default output file,
# so invalid input does not leave an empty module behind.
try:
    schema = json.load(in_file)
finally:
    if in_file is not sys.stdin:
        in_file.close()

if not isinstance(schema, dict):
    raise SystemExit('schema must be a JSON object')

if not schema.get('types'):
    # "data" may be null or absent (e.g. an error response), so it is
    # checked before being used as a mapping.
    data = schema.get('data')
    if isinstance(data, dict) and data.get('__schema'):
        schema = data['__schema']  # plain HTTP endpoint result
    elif schema.get('__schema'):
        schema = schema['__schema']  # introspection field
    else:
        raise SystemExit('schema must be introspection object or query result')

    if not isinstance(schema, dict) or 'types' not in schema:
        raise SystemExit('schema must be introspection object or query result')

if not out_file:
    if in_fname == '<stdin>':
        out_file = sys.stdout
    else:
        wd = os.path.dirname(in_fname)
        out_fname = os.path.join(wd, schema_name + '.py')
        out_file = open(out_fname, 'w')
        sys.stdout.write('Writing to: %s\n' % (out_fname,))

try:
    gen = CodeGen(schema_name, schema, out_file.write)
    gen.write()
finally:
    # Close (and flush) the output even if generation fails midway.
    out_file.close()
