#!/usr/bin/env python


from collections import OrderedDict
from datetime import datetime
from time import time, sleep
import threading, mimetypes, json
import os, sys, getopt, cgi, signal, daemon

from geventwebsocket import WebSocketServer, WebSocketApplication, Resource

from sqlchain.version import version, P2SH_FLAG, BECH32_FLAG
from sqlchain.rpc import do_RPC
from sqlchain.bci import isTxAddrs
from sqlchain.dbpool import DBPool
from sqlchain.util import dotdict, loadcfg, savecfg, drop2user, getssl, log, logts

# Container for "super globals": it is attached to __builtins__ so the bare
# name `sqc` resolves from every module of the running script.
__builtins__.sqc = dotdict()

# Built-in configuration defaults, keyed by option name.
sqc.cfg = {
    'log': sys.argv[0]+'.log',
    'listen': 'localhost:8085',
    'www': 'www',
    'block': 0,
    'pool': 4,
    'dbinfo-ts': datetime.now().strftime('%s'),
    'dbinfo': -1,
    'path': '/var/data/sqlchain',
    'cointype': 'bitcoin',
}

sqc.server = {}
sqc.clients = {}    # active websockets we publish to

# current sync data shared for every sync/subscription
sqc.syncTxs = []
sqc.lastBlk = {}
sqc.sync = threading.Condition()
sqc.sync_id = 0

def _read_static(www, path):
    """Return the raw contents of `path` below directory `www`, or None.

    None is returned when the normalized path falls outside `www` (for
    example through '..' segments), is not a regular file, or cannot be read.
    """
    if '\0' in path:
        return None
    root = os.path.abspath(www)
    target = os.path.normpath(os.path.join(root, path.lstrip('/')))
    # join(root, '') yields root with exactly one trailing separator
    if not target.startswith(os.path.join(root, '')) or not os.path.isfile(target):
        return None
    try:
        with open(target, 'rb') as fd:
            return fd.read()
    except IOError:
        return None

def do_Root(env, send_resp):
    """WSGI handler for '/': forward form POSTs to the rpc api, else serve static files.

    env       -- WSGI environ dict; PATH_INFO and REQUEST_METHOD are read, and
                 PATH_INFO is rewritten before a POST is handed to do_RPC.
    send_resp -- WSGI start_response callable.

    Returns a list of body strings. Files are only served from beneath
    sqc.cfg['www']; every other request gets the 404 response, with the
    site's 404.html as body when it can be read.
    """
    try:
        path = env['PATH_INFO']
        if env['REQUEST_METHOD'] == 'POST': # POST
            if path == '/': # the /rpc api is mirrored here as form params
                form = cgi.FieldStorage(fp=env['wsgi.input'], environ=env, keep_blank_values=True)
                env['PATH_INFO'] = "/rpc/%s/%s" % (form['method'].value, "/".join(form.getlist('params')))
                return do_RPC(env, send_resp)
        elif sqc.cfg['www']: # GET static website files, if path configured
            path = '/main.html' if path in ['', '/'] else path
            body = _read_static(sqc.cfg['www'], path)
            if body is not None:
                _,ext = os.path.splitext(path)
                # unknown extensions fall back to a generic type instead of raising KeyError
                ctype = mimetypes.types_map.get(ext, 'application/octet-stream')
                send_resp('200 OK', [('Content-Type', ctype), ('Content-Length', str(len(body))),
                                     ('Expires', datetime.utcfromtimestamp(time()+3600).strftime("%a, %d %b %Y %H:%M:%S %ZGMT"))])
                return [ body ]
        send_resp('404 - File Not Found: %s' % path, [("Content-Type", "text/html")], sys.exc_info())
        page = _read_static(sqc.cfg['www'], '/404.html') if sqc.cfg['www'] else None
        return [ page ] if page is not None else []
    except IOError as err:
        # e.g. the request body could not be read; always hand back an iterable
        logts("Request failed: %s" % err)
        return []

class BCIWebSocket(WebSocketApplication):
    """Websocket application handling block / transaction subscriptions.

    Per-connection state lives in sqc.clients, keyed by the handler's
    active_client: {'subs': [op, ...], 'addrs': set(addr, ...)} plus an
    optional 'lasttx' entry that is read back by the 'ping_tx' op.
    """
    remote = None

    def on_open(self, *args, **kwargs):
        """Register the new connection with empty subscriptions."""
        self.remote = self.ws.environ['REMOTE_ADDR']
        logts("WS Client connected from %s" % self.remote)
        sqc.clients[self.ws.handler.active_client] = { 'subs':[], 'addrs':set() }

    def on_message(self, msg, *args, **kwargs): # pylint:disable=arguments-differ
        """Handle one JSON message of the form {'op': ..., ['addr': ...]}.

        Empty messages, invalid JSON, non-object payloads and messages from
        a connection with no registered state are ignored.
        """
        if not msg:
            return
        try:
            msg = json.loads(msg)
        except ValueError:
            return
        if not isinstance(msg, dict):
            return
        state = sqc.clients.get(self.ws.handler.active_client)
        if state is None:
            return
        op = msg.get('op')
        if op in ('blocks_sub', 'unconfirmed_sub'):
            if op not in state['subs']: # repeated subscribes need no second entry
                state['subs'].append(op)
        elif op == 'addr_sub' and 'addr' in msg:
            state['addrs'].add(msg['addr'])
        elif op == 'ping_block':
            # serialized like every other outgoing message
            self.ws.send(json.dumps({ 'op': 'block', 'x': sqc.lastBlk }))
        elif op == 'ping_tx' and 'lasttx' in state:
            self.ws.send(json.dumps({ 'op': 'utx', 'x': state['lasttx'] }))

    def on_close(self, *args, **kwargs):
        """Log the disconnect and drop the connection's state."""
        # a None reason is skipped so it cannot break the join
        reason = ''.join('%s' % (arg,) for arg in args if arg is not None)
        logts("WS Client disconnected %s %s" % (self.remote, reason))
        sqc.clients.pop(self.ws.handler.active_client, None)

# monitor mempool, block, orphan changes - publish to websocket subscriptions, notify waiting sync connections
def syncMonitor():
    """Poll the database for new mempool txs, orphans and blocks until sqc.done is set.

    On start the current sync id and highest block id are loaded (and a db
    info update is started when cfg 'dbinfo' is 0). Each cycle then collects
    new mempool transactions, checks for new orphans and the highest block,
    records client counts in the info table, calls do_Sync when anything
    changed, and starts a periodic db info update when one is due. The poll
    interval is cfg 'sync' seconds (5 when unset).
    """
    with sqc.dbpool.get().cursor() as cur:
        cur.execute("select greatest(ifnull(m,0),ifnull(o,0)) from (select max(sync_id) as m from mempool) m,(select max(sync_id) as o from orphans) o;")
        sqc.sync_id = cur.fetchone()[0]
        cur.execute("select ifnull(max(id),0) from blocks;")
        sqc.cfg['block'] = cur.fetchone()[0]
        if sqc.cfg['dbinfo'] == 0:
            sqc.dbwrk = threading.Thread(target = mkDBInfo)
            sqc.dbwrk.start()
    while not sqc.done.is_set():
        with sqc.dbpool.get().cursor() as cur:
            cur.execute("select hash from mempool m, trxs t where m.sync_id > %s and t.id=m.id;", (sqc.sync_id,))
            # bciTxWS is handed this same cursor, so the rows are fetched up
            # front; any query it runs cannot then cut this result set short.
            txs = [ sqlchain.bci.bciTxWS(cur, txhash[::-1].encode('hex')) for txhash, in cur.fetchall() ]
            if txs:
                sqc.syncTxs = txs
            cur.execute("select count(*) from orphans where sync_id > %s;", (sqc.sync_id,))
            new_orphans = cur.fetchone()[0] > 0
            # ifnull keeps an empty blocks table at 0, matching the start-up query
            cur.execute("select ifnull(max(id),0) from blocks;")
            block = cur.fetchone()[0]
            cur.execute("replace into info (class,`key`,value) values('info','ws-clients',%s),('info','connections',%s);", (len(sqc.clients), len(sqc.server.pool) if sqc.server.pool else 0))
        if block != sqc.cfg['block'] or new_orphans or txs:
            do_Sync(block)
        if sqc.cfg['dbinfo'] > 0 and (datetime.now() - datetime.fromtimestamp(int(sqc.cfg['dbinfo-ts']))).total_seconds() > sqc.cfg['dbinfo']*60:
            worker = sqc.dbwrk
            # never run two db info updates at the same time
            if worker is None or not worker.is_alive():
                sqc.dbwrk = threading.Thread(target = mkDBInfo)
                sqc.dbwrk.start()
        sleep(sqc.cfg.get('sync', 5))
    # local reference: the worker clears sqc.dbwrk itself when it finishes
    worker = sqc.dbwrk
    if worker:
        worker.join()

def _ws_send(client, payload):
    """Send payload to one websocket client; log and return False on failure."""
    try:
        client.ws.send(payload)
        return True
    except Exception as err: # pylint:disable=broad-except
        # a dead or half-closed peer must not stop delivery to the others
        # nor take down the calling monitor thread
        logts("WS send failed: %s" % err)
        return False

def do_Sync(block):
    """Publish a block / transaction change and wake waiting sync requests.

    block -- highest block id currently in the database.

    When block differs from sqc.cfg['block'] the configured height moves at
    most one step towards it, sqc.lastBlk is refreshed and sent to
    'blocks_sub' subscribers. sqc.sync_id is then incremented, waiters on
    sqc.sync are notified, and every tx in sqc.syncTxs is sent to clients
    subscribed to 'unconfirmed_sub' or to an address matched by isTxAddrs.
    """
    if block != sqc.cfg['block']:
        sqc.cfg['block'] = min(block, sqc.cfg['block']+1)
        with sqc.dbpool.get().cursor() as cur:
            # NOTE(review): this publishes `block` (the db tip) while cfg 'block'
            # advances one step at a time - confirm that is intended.
            sqc.lastBlk = sqlchain.bci.bciBlockWS(cur, block)
        payload = json.dumps({ 'op': 'block', 'x': sqc.lastBlk })
        for client in list(sqc.server.clients.values()):
            state = sqc.clients.get(client) # None if the client has no registered state
            if state is not None and 'blocks_sub' in state['subs']:
                _ws_send(client, payload)
    sqc.sync_id += 1
    with sqc.sync:
        sqc.sync.notify_all()
    if sqc.syncTxs:
        # serialize each tx once instead of once per client
        messages = [ (tx, json.dumps({ 'op': 'utx', 'x': tx })) for tx in sqc.syncTxs ]
        for client in list(sqc.server.clients.values()):
            state = sqc.clients.get(client)
            if state is None:
                continue
            for tx,payload in messages:
                if 'unconfirmed_sub' in state['subs'] or (state['addrs'] and isTxAddrs(tx, state['addrs'])):
                    if not _ws_send(client, payload):
                        break # give up on this client, carry on with the next
                    state['lasttx'] = tx

def mkDBInfo():
    """Refresh the 'db' class rows of the info table (run as a worker thread).

    Stamps and saves cfg 'dbinfo-ts', calls sqlchain.insight.apiStatus for
    'db', then stores the result of each count query below under its info
    key. sqc.dbwrk is cleared when the function ends, whether or not the
    update succeeded.
    """
    # (count query, query args or None, statement storing the single count)
    counts = (
        ("select count(*) from address where (id & %s = %s);", (P2SH_FLAG,P2SH_FLAG),
         "replace into info (class,`key`,value) values('db','address:p2sh',%s);"),
        ("select count(*) from address where (id & %s = %s);", (BECH32_FLAG,BECH32_FLAG),
         "replace into info (class,`key`,value) values('db','address:p2wpkh',%s);"),
        ("select count(*) from bech32 where 1;", None,
         "replace into info (class,`key`,value) values('db','address:p2wsh',%s);"),
        ("select count(*) from address where  cast(conv(hex(reverse(unhex(substr(sha2(addr,0),1,10)))),16,10) as unsigned) != floor(id);", None,
         "replace into info (class,`key`,value) values('db','address:id-collisions',%s);"),
        ("select count(*) from trxs where  strcmp(reverse(unhex(hex(id*8))), left(hash,5)) > 0;", None,
         "replace into info (class,`key`,value) values('db','trxs:id-collisions',%s);"),
        ("select count(*) from outputs where addr_id=0;", None,
         "replace into info (class,`key`,value) values('db','outputs:non-std',%s);"),
        ("select count(*) from outputs where tx_id is null;", None,
         "replace into info (class,`key`,value) values('db','outputs:unspent',%s);"),
    )
    try:
        with sqc.dbpool.get().cursor() as cur:
            logts("Updating server db info")
            # stamped before the slow queries so the monitor sees this run as current
            sqc.cfg['dbinfo-ts'] = datetime.now().strftime('%s')
            savecfg(sqc.cfg)
            sqlchain.insight.apiStatus(cur, 'db')
            for select,args,store in counts:
                cur.execute(select, args)
                cur.execute(store, (cur.fetchone()[0], ))
            cur.execute("replace into info (class,`key`,value) values('db','all:updated',now());")
            logts("DB info update complete")
    finally:
        # cleared even on failure, so a dead worker is never left registered
        sqc.dbwrk = None

def options(cfg): # pylint:disable=too-many-branches
    """Apply the command line options in sys.argv[1:] to cfg (in place).

    cfg -- configuration dict to update.

    Invalid options, or a non-integer --dbinfo value, show the usage text.
    --help, --version and --defaults end the process via sys.exit. Note that
    -h always means help, so the listen address can only be set with --listen.
    """
    try:
        # -b and -c are accepted (with a value) but have no effect here
        opts,_ = getopt.getopt(sys.argv[1:], "hvb:p:c:d:l:w:r:u:i:",
            ["help", "version", "debug", "db=", "log=", "www=", "listen=", "path=", "rpc=", "user=", "dbinfo=", "defaults" ])
    except getopt.GetoptError:
        usage()
    # options that simply store their string argument under a cfg key
    stored = {
        "-d": 'db',   "--db": 'db',
        "-l": 'log',  "--log": 'log',
        "-w": 'www',  "--www": 'www',
        "-p": 'path', "--path": 'path',
        "--listen": 'listen',
        "-r": 'rpc',  "--rpc": 'rpc',
        "-u": 'user', "--user": 'user',
    }
    for opt,arg in opts:
        if opt in ("-h", "--help"):
            usage()
        elif opt in ("-v", "--version"):
            sys.exit(sys.argv[0]+': '+version)
        elif opt in stored:
            cfg[stored[opt]] = arg
        elif opt in ("-i","--dbinfo"):
            try:
                cfg['dbinfo'] = int(arg)
            except ValueError:
                usage()
        elif opt == "--defaults":
            savecfg(cfg)
            sys.exit("%s updated" % (sys.argv[0]+'.cfg'))
        elif opt == "--debug":
            cfg['debug'] = True

def usage():
    """Print the command line help to stdout and exit with status 2."""
    # --listen is listed without a short form: -h is always parsed as --help.
    # NOTE(review): the built-in cfg default for dbinfo is -1, not the 180
    # stated below - confirm which one the help text should show.
    lines = [
        "Usage: {0} [options...][cfg file]",
        "Command options are:",
        "-h,--help\tShow this help info",
        "-v,--version\tShow version info",
        "--debug\t\tRun in foreground with logging to console",
        "--defaults\tUpdate cfg and exit",
        "Default files are {0}.cfg, {0}.log",
        "",
        "These options get saved in cfg file as defaults.",
        "-p,--path\tSet path for blob and header data files (/var/data/sqlchain)",
        "--listen\tSet host:port for web server",
        "-w,--www\tWeb server root directory",
        "-u,--user\tSet user to run as",
        "-d,--db  \tSet mysql db connection, \"host:user:pwd:dbname\"",
        "-l,--log\tSet log file path",
        "-r,--rpc\tSet rpc connection, \"http://user:pwd@host:port\"",
        "-i,--dbinfo\tSet db info update period in minutes (default=180, 0=at start, -1=never) ",
    ]
    print("\n".join(lines).format(sys.argv[0]))
    sys.exit(2)

def sigterm_handler(_signo, _stack_frame):
    """Signal handler: stop the sync monitor, remove the pid file and exit 0.

    Sets sqc.done and waits for the monitor thread if one was started. The
    pid file (cfg 'pid', default <argv0>.pid) is removed unless running in
    debug mode; a failure to remove it is logged and does not stop the exit.
    """
    logts("Shutting down.")
    sqc.done.set()
    if sqc.syncd:
        sqc.syncd.join()
    if not sqc.cfg.get('debug'):
        try:
            os.unlink(sqc.cfg.get('pid', sys.argv[0]+'.pid'))
        except OSError as err:
            logts("Could not remove pid file: %s" % err)
    sys.exit(0)

def sighup_handler(_signo, _stack_frame):
    """Signal handler: close sys.stdout / sys.stderr and reopen them on the log file.

    The log path is cfg 'log', or <argv0>.log when unset.
    """
    logfile = sqc.cfg.get('log', sys.argv[0]+'.log')
    # Each stream is closed before its replacement is opened, stdout first.
    # NOTE(review): that order presumably lets the new file reuse the
    # descriptor just released - keep it unless confirmed otherwise.
    for stream in ('stdout', 'stderr'):
        getattr(sys, stream).close()
        setattr(sys, stream, open(logfile,'a'))
    logts("SIGHUP Log reopened")

def run():
    """Create the db pool, start the web/websocket server and sync monitor, then serve.

    Reads cfg 'db' ("host:user:pwd:dbname"), 'pool', 'listen' ("host:port")
    and the optional 'sync' interval; a 'sync' value of 0 or less disables
    the monitor thread. Does not return while the server is running.
    """
    sqc.done = threading.Event()
    sqc.dbpool = DBPool(sqc.cfg['db'].split(':'), sqc.cfg['pool'], 'MySQLdb')

    mimetypes.init()
    font_types = (
        ('application/x-font-woff', '.woff'),
        ('application/x-font-woff2', '.woff2'),
        ('application/x-font-ttf', '.ttf'), # was registered as woff
    )
    for mime,ext in font_types:
        mimetypes.add_type(mime, ext)

    logts("Starting on %s" % sqc.cfg['listen'])
    # split on the last colon only, so a host that itself contains colons still parses
    host,port = sqc.cfg['listen'].rsplit(':', 1)
    sqc.server = WebSocketServer((host, int(port)), APIs, spawn=10000, **getssl(sqc.cfg))
    sqc.server.start()

    interval = sqc.cfg.get('sync', 5)
    if interval > 0:
        log("Sync monitoring at %d second intervals" % (interval,))
        sqc.syncd = threading.Thread(target = syncMonitor)
        sqc.syncd.daemon = True
        sqc.syncd.start()
    else:
        log("Sync monitor disabled")

    drop2user(sqc.cfg, chown=True)

    sqc.server.serve_forever()

if __name__ == '__main__':

    # cfg file values first, then command line options on top of them
    loadcfg(sqc.cfg)
    options(sqc.cfg)

    # Binds the `sqlchain` name that the monitor functions reference.
    # NOTE(review): imported only after the cfg is loaded - presumably these
    # modules read sqc.cfg at import time; confirm before moving.
    import sqlchain.insight, sqlchain.bci, sqlchain.rpc

    APIs = Resource(OrderedDict((
        ('/api', sqlchain.insight.do_API),
        ('/bci', sqlchain.bci.do_BCI),
        ('/rpc', sqlchain.rpc.do_RPC),
        ('/ws',  BCIWebSocket),
        ('/',    do_Root) ))) # order important

    if sqc.cfg.get('debug'):
        # foreground mode: Ctrl-C goes through the normal shutdown path
        signal.signal(signal.SIGINT, sigterm_handler)
        run()
    else:
        logpath = sqc.cfg.get('log', sys.argv[0]+'.log')
        pidpath = sqc.cfg.get('pid', sys.argv[0]+'.pid')
        # 0o002 / open() are the spellings valid on Python 2.6+ and 3
        with daemon.DaemonContext(working_directory='.', umask=0o002, stdout=open(logpath,'a'), stderr=open(logpath,'a'),
                signal_map={signal.SIGTERM:sigterm_handler, signal.SIGHUP:sighup_handler } ):
            with open(pidpath,'w') as f:
                f.write(str(os.getpid()))
            run()
