#!/usr/bin/env python3

"""Simple SQL to XLSX utility.

Run a query against the database and use the types in the query result to
style cells. Missing workbooks are created. Worksheets with the same name are
overwritten.
"""

import os.path as pth
import sys

from datasimple.sqlite import connect
from datasimple.xl import XlsxImporter
from datasimple.cli import log


class SqliteImporter(XlsxImporter):
    """XLSX importer whose data comes from a query against a SQLite database.

    Implements the argument, validation and data hooks on top of
    ``XlsxImporter``. NOTE(review): the base class is defined elsewhere and is
    presumed to invoke these hooks from ``main()`` -- confirm against
    ``datasimple.xl``.
    """

    def add_args(self, argparser):
        """Set the parser description/epilog and add the ``--db``/``--sql`` options.

        Both options are required strings.
        """
        argparser.description = 'Create an XLSX sheet from a query vs a SQLite database'
        argparser.epilog = 'NOTE: You can use the sqlite functions added by datasimple.sqlite'

        argparser.add_argument(
            '-d', '--db',
            help='Name of sqlite database file to read',
            required=True,
            type=str
        )
        argparser.add_argument(
            '-q', '--sql',
            help='Name of SQL file to execute (use - for STDIN)',
            required=True,
            type=str
        )

    def validate_args(self, args):
        """Check that the database file and the SQL file exist.

        An ``args.sql`` of ``-`` means STDIN and is not checked on disk.

        Raises:
            ValueError: if ``args.db`` is not an existing file, or ``args.sql``
                is neither ``-`` nor an existing file.
        """
        if not pth.isfile(args.db):
            raise ValueError('SQLite file {} does not exist'.format(args.db))
        if args.sql != '-' and not pth.isfile(args.sql):
            raise ValueError('SQL query file {} does not exist'.format(args.sql))

    def customize_val_mapper(self, value_mapper):
        """Leave the value mapper untouched."""
        pass  # No val mapper customizations

    def before_save(self, args, wb, sheet):
        """Do nothing before the workbook is saved."""
        pass  # No pre-save customizations

    def get_data(self, args):
        """Run the SQL named by ``args.sql`` against ``args.db``.

        The SQL text is split on ``;``. Every statement except the last is
        executed first and then committed; the last statement is the query
        whose result is returned.

        NOTE: the split is purely textual, so a ``;`` inside a string literal
        or a comment also splits the statement.

        Returns:
            A ``(column_names, rows)`` tuple where ``rows`` is a generator over
            the fetched result rows. The connection is closed once that
            generator finishes (or is closed/garbage-collected after starting).

        Raises:
            ValueError: if the SQL contains no statements, or the final
                statement produces no result columns.
        """
        from_stdin = args.sql == '-'
        log('Reading query [!y]{:s}[!/y]', args.sql if not from_stdin else '<STDIN>')
        if from_stdin:
            sql = sys.stdin.read()
        else:
            with open(args.sql) as inp:
                sql = inp.read()

        # Drop blank fragments (e.g. the one after a trailing ';').
        stmts = [s.strip() for s in sql.split(';') if s.strip()]
        if not stmts:
            # Explicit check: indexing an empty list would raise IndexError,
            # and an assert would be stripped when running with -O.
            raise ValueError('Missing a SQL query!')
        *pre_stmts, query_stmt = stmts

        log('Opening [!y]{:s}[!/y]', args.db)
        conn = connect(args.db)
        cur = conn.cursor()

        for s in pre_stmts:
            log(s)
            cur.execute(s)
            log('Rows Affected: [!y]{:,d}[!/y]', cur.rowcount)
        conn.commit()

        cur.execute(query_stmt)
        if cur.description is None:
            # The final statement was not a query, so there are no columns.
            conn.close()
            raise ValueError('Final SQL statement did not return a result set')
        query_cols = [d[0] for d in cur.description]

        def rows():
            # Rows are fetched lazily on first iteration; release the
            # connection once the consumer is done with them.
            try:
                yield from cur.fetchall()
            finally:
                conn.close()

        return query_cols, rows()


if __name__ == '__main__':
    # Script entry point: build the importer and call its main() method.
    importer = SqliteImporter()
    importer.main()
