#!/usr/bin/env python3

"""Quick CSV stats for specified files."""

import argparse
import collections
import csv
import glob
import os.path as pth
import re
import sqlite3

from datasimple import cli

# Upper-case SQL keywords, one whitespace-separated table split into a set.
# reader_to_table() compares the upper-cased form of each generated table or
# column name against this set and appends an underscore on a match.
SQL_RESERVED = {*"""
    ABORT	            DEFAULT	    INNER	    REGEXP
    ACTION	            DEFERRABLE	INSERT	    REINDEX
    ADD	                DEFERRED	INSTEAD	    RELEASE
    AFTER	            DELETE	    INTERSECT	RENAME
    ALL	                DESC	    INTO	    REPLACE
    ALTER	            DETACH	    IS	        RESTRICT
    ANALYZE	            DISTINCT	ISNULL	    RIGHT
    AND	                DROP	    JOIN	    ROLLBACK
    AS	                EACH	    KEY	        ROW
    ASC	                ELSE	    LEFT	    SAVEPOINT
    ATTACH	            END	        LIKE	    SELECT
    AUTOINCREMENT	    ESCAPE	    LIMIT	    SET
    BEFORE	            EXCEPT	    MATCH	    TABLE
    BEGIN	            EXCLUSIVE	NATURAL	    TEMP
    BETWEEN	            EXISTS	    NO	        TEMPORARY
    BY	                EXPLAIN	    NOT	        THEN
    CASCADE	            FAIL	    NOTNULL	    TO
    CASE	            FOR	        NULL	    TRANSACTION
    CAST	            FOREIGN	    OF	        TRIGGER
    CHECK	            FROM	    OFFSET	    UNION
    COLLATE	            FULL	    ON	        UNIQUE
    COLUMN	            GLOB	    OR	        UPDATE
    COMMIT	            GROUP	    ORDER	    USING
    CONFLICT	        HAVING	    OUTER	    VACUUM
    CONSTRAINT	        IF	        PLAN	    VALUES
    CREATE	            IGNORE	    PRAGMA	    VIEW
    CROSS	            IMMEDIATE	PRIMARY	    VIRTUAL
    CURRENT_DATE	    IN	        QUERY	    WHEN
    CURRENT_TIME	    INDEX	    RAISE	    WHERE
    CURRENT_TIMESTAMP	INDEXED	    RECURSIVE	WITH
    DATABASE	        INITIALLY	REFERENCES	WITHOUT
""".split()}


def log(msg, *args, **kwrds):
    """Send one line to ``cli.log``, indented four spaces per level.

    ``msg`` is run through ``str.format`` only when positional ``args`` are
    supplied; otherwise it is emitted verbatim, so braces in a bare message
    are left alone. The only keyword consulted is ``level`` (default 0),
    which sets the indentation depth.
    """
    text = msg.format(*args) if args else msg
    indent = '    ' * kwrds.get('level', 0)
    cli.log(indent + text)


class Field(object):
    """Running statistics for a single named CSV column.

    Counts how many values were seen, how often each distinct value occurred,
    and tracks the smallest and largest value. Counting and comparison both
    use the value returned by ``valfilt``, not the raw input.
    """

    def __init__(self, name, valfilt=str):
        """Create an empty accumulator.

        Args:
            name: Column name; stored as ``str(name)``.
            valfilt: Callable applied to each raw value before it is counted
                and compared. Defaults to ``str``.
        """
        self.name = str(name)
        self.val_counts = collections.defaultdict(int)
        self.mnval = None
        self.mxval = None
        self.valfilt = valfilt
        self.count = 0

    def one(self, val):
        """Record a single observed value (call once per value seen)."""
        self.count += 1

        val = self.valfilt(val)
        self.val_counts[val] += 1

        if self.mnval is None or val < self.mnval:
            self.mnval = val
        if self.mxval is None or val > self.mxval:
            self.mxval = val

    def report(self, topn=5):
        """Log a summary: unique/total counts, min, max and top values.

        Args:
            topn: How many of the most frequent values to list.
        """
        log(
            'Field "{:s}" => {:,d} unique values (out of {:,d} values seen)',
            self.name,
            len(self.val_counts),
            self.count
        )
        if self.count < 1:
            log('Nothing shown - no data found for field!', level=1)
            return

        # ``!s`` converts to str before the width is applied, so values
        # produced by a non-string ``valfilt`` (e.g. int) format cleanly
        # instead of raising on the ``s`` presentation type.
        log('Min val: {!s:20s}', self.mnval, level=1)
        log('Max val: {!s:20s}', self.mxval, level=1)

        log('{:d} most common vals:', topn, level=1)
        log('{:20s} {:s}', 'Count', 'Value', level=2)
        log('{:20s} {:s}', '-'*20, '-'*20, level=2)
        for cnt, val in self.most_common(topn):
            log('{:20,d} {!s}', cnt, val, level=2)

    def most_common(self, topn=1):
        """Return up to ``topn`` ``(count, value)`` pairs, most frequent first.

        Pairs are ordered by count descending; equal counts are ordered by
        value descending (plain tuple comparison). A negative ``topn`` yields
        an empty list rather than slicing from the end.
        """
        if topn is not None and topn < 0:
            topn = 0
        return sorted(
            ((cnt, val) for val, cnt in self.val_counts.items()),
            reverse=True
        )[0:topn]


def reader_to_table(conn, tabname, csvreader):
    """Create a table from a CSV reader and insert every data row into it.

    The first non-empty row supplies the column names; every later non-empty
    row is inserted as a record. Table and column names are sanitised into
    plain ``[0-9a-zA-Z_]`` identifiers before being placed in the SQL text,
    while row values are always passed as bound parameters. Work is committed
    after the table is created, after every 50,000 inserted rows, and once
    more at the end.

    Args:
        conn: Open database connection providing ``cursor()`` and
            ``commit()`` (the caller in this file passes a ``sqlite3``
            connection).
        tabname: Desired table name; sanitised like a column name.
        csvreader: Iterable yielding one sequence of string values per row.
    """
    headers = None
    insert_sql = None
    count = 0
    cur = conn.cursor()

    blank_counter = 1
    # Lower-cased names handed out so far (the table name included). Stored
    # lower-cased because SQLite compares column names case-insensitively, so
    # "Name" and "name" in one table would make CREATE TABLE fail.
    previous_cols = set()

    def safename(s):
        """Turn arbitrary header text into a unique, safe SQL identifier."""
        nonlocal blank_counter
        n = s.strip()
        # "#12" -> "Num12". The digits are captured as one group; a repeated
        # single-digit group would keep only the last digit.
        n = re.sub(r'#([0-9]*)', r'Num\1', n)
        # NOTE: " -/" is a character range (space through slash), so most
        # ASCII punctuation is turned into an underscore here.
        n = re.sub(r'[ -/#]+', '_', n)
        n = re.sub(r'[^0-9a-zA-Z_]+', '', n)
        # Dropping characters above can leave runs of underscores; collapse
        # them all, then trim the ends.
        n = re.sub(r'_+', '_', n).strip('_')

        if re.match(r'^[0-9]+', n):
            n = 'n' + n  # Name can't start with number

        if not n:
            n = 'BlankName{:d}'.format(blank_counter)
            blank_counter += 1

        if n.upper() in SQL_RESERVED:
            n = n + '_'

        if n.lower() in previous_cols:
            num = 0
            while True:
                tryn = n + ('_dup%03d' % num)
                if tryn.lower() not in previous_cols:
                    n = tryn
                    break
                num += 1

        previous_cols.add(n.lower())
        return n

    def val_filter(v):
        """Map the literal text ``\\N`` to NULL; pass everything else through."""
        if v == '\\N':
            v = None  # Like MySQL import, convert \N to null
        return v

    def rec_filter(rec):
        """Apply ``val_filter`` to every value of a row."""
        return tuple(val_filter(val) for val in rec)

    tabname = safename(tabname)

    for rec in csvreader:
        if not rec:
            # csv.reader yields an empty list for a blank line; there is
            # nothing to name or insert, and binding zero values would fail.
            continue

        if headers is None:
            headers = ','.join([safename(c) for c in rec])
            sql = 'create table {} ({})'.format(tabname, headers)
            log('CREATING: {}', sql)
            cur.execute(sql)
            conn.commit()
            insert_sql = 'insert into {} ({}) values ({})'.format(
                tabname,
                headers,
                ','.join(['?'] * len(rec))
            )
            continue

        cur.execute(insert_sql, rec_filter(rec))
        count += 1
        if count % 50000 == 0:
            conn.commit()
            log('    Inserted {:,d} records', count)

    conn.commit()
    log('BASE TABLE DONE: {:,d} records', count)


def one(encoding, filename, fld_check, topn, sqlite_file):
    """Report on one CSV file and optionally load it into SQLite.

    Logs the header row, gathers statistics for each requested column that is
    present in the file, logs a report per requested column, and, when
    ``sqlite_file`` is non-empty, copies the file into a table named after
    the file's base name (extension removed).

    Args:
        encoding: Text encoding used to read the file.
        filename: Path of the CSV file to process.
        fld_check: Iterable of column names to gather statistics for.
        topn: Number of most frequent values to show per column.
        sqlite_file: Path of the SQLite database to write to, or a falsy
            value to skip the import.
    """
    # Set up
    log(filename)
    fields = [Field(f) for f in fld_check]

    # Get header list. Files handed to the csv module are opened with
    # newline='' so line breaks inside quoted fields reach the parser intact.
    with open(filename, encoding=encoding, newline='') as fh:
        headers = next(csv.reader(fh), [])
    if not headers:
        log(' no header line - skipping this file')
        return

    log('Found {:,d} headers: {:s}', len(headers), ','.join(headers))

    # Process all records
    count = 0
    with open(filename, encoding=encoding, newline='') as fh:
        for rec in csv.DictReader(fh):
            count += 1
            for f in fields:
                if f.name in rec:
                    f.one(rec[f.name])

    log('Found {:d} recs', count)
    for f in fields:
        f.report(topn)

    # Handle sqlite
    if sqlite_file:
        tabname = pth.splitext(pth.split(filename)[-1])[0]
        log('Creating table {:s} in {:s} from {:s}', tabname, sqlite_file, filename)
        conn = sqlite3.connect(sqlite_file)
        try:
            with open(filename, encoding=encoding, newline='') as fh:
                reader_to_table(conn, tabname, csv.reader(fh))
        finally:
            # Always release the database handle, even if the import fails.
            conn.close()


def main():
    """Parse the command line and process every matching file.

    Each positional argument is expanded with ``glob``; only regular files
    are kept, duplicates are dropped, and the remaining paths are handled in
    sorted order by ``one``, with a separator line logged after each.
    """
    cmdline = argparse.ArgumentParser()

    # (flags, keyword settings) in registration order.
    argument_specs = (
        (
            ('file_patterns',),
            dict(
                metavar='FILE_NAME_PATTERN',
                type=str,
                nargs='+',
                help='input CSV files: multiple and globbing allowed',
            ),
        ),
        (
            ('-f', '--fields'),
            dict(
                action='append',
                help='Column to report on (col may be missing)',
                required=False,
            ),
        ),
        (
            ('-n', '--topn'),
            dict(
                help='For all fields specified by -f/--fields, number of most frequent values shown',
                required=False,
                type=int,
                default=5,
            ),
        ),
        (
            ('-s', '--sqlite'),
            dict(
                help='Name of sqlite file to create from CSV data',
                required=False,
                type=str,
                default='',
            ),
        ),
        (
            ('-e', '--encoding'),
            dict(
                help='Encoding to use reading the CSV file (default utf-8). Problems with Windows files? Try latin-1',
                required=False,
                type=str,
                default='utf-8',
            ),
        ),
    )
    for flags, settings in argument_specs:
        cmdline.add_argument(*flags, **settings)

    opts = cmdline.parse_args()

    # Expand every pattern, keeping each regular file exactly once.
    matched = set()
    for pattern in opts.file_patterns:
        for candidate in glob.glob(pattern):
            if pth.isfile(candidate):
                matched.add(candidate)

    for path in sorted(matched):
        one(opts.encoding, path, opts.fields or [], opts.topn, opts.sqlite)
        log('*' * 70)


# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    main()
