#!/usr/bin/env python
#
#   sqlchaind - daemon to update sql blockchain db
#
# pylint:disable=no-member
from Queue import Queue, Empty
from datetime import datetime
from struct import pack, unpack

import os, sys, socket, getopt, time, signal, threading, daemon
import MySQLdb as db

from sqlchain.version import version, coincfg, BLKDAT_NEAR_SYNC, BLKDAT_MAGIC, MAX_IO_TX, MAX_TX_BLK
from sqlchain.util import dotdict, sqlchain_overlay, loadcfg, savecfg, drop2user, rpcPool, blockwork, int2bin32, log, logts
from sqlchain.util import encodeVarInt, decodeBlock, decodeTx, findTx, insertAddress, mkBlobHdr, insertBlob, puthdr, gethdr
from sqlchain.blkdat import BlkDatHandler

# Container for "super globals". It is attached to the builtins namespace so the
# bare name `sqc` resolves everywhere in this script
# (presumably also inside the imported sqlchain modules - TODO confirm).
__builtins__.sqc = dotdict()

# Built-in configuration defaults; loadcfg() and options() are applied to this dict at startup.
sqc.cfg = {
    'log': sys.argv[0]+'.log',
    'queue': 8,
    'no-sigs': False,
    'db': '',
    'rpc': '',
    'path': '/var/data/sqlchain',
    'cointype': 'bitcoin',
}
sqc.bestblk = 120000  # initial value only; getBlocks/chkPruning overwrite it with the node's header count
sqc.zmq = True        # cleared by getBlocksZMQ when pyzmq cannot be imported

blksecs = []          # recent per-block processing times, maintained by insertBlock
memPool = set()       # 8-byte txid prefixes of mempool txs seen by checkMemPool

def _nodeChainwork(height):
    """Return the cumulative chainwork (int) the node reports for block `height`.

    Returns 0 for heights below genesis, and None when the node cannot supply
    the block hash or header.
    """
    if height < 0:
        return 0
    blkhash = sqc.rpc.getblockhash(height)
    if blkhash is None:
        return None
    hdr = sqc.rpc.getblockheader(blkhash)
    if hdr is None:
        return None
    return int(hdr['chainwork'],16)

def getBlocks(blk):
    """Fetch blocks from the node and queue them on blockQ until sqc.done is set.

    blk -- height to start at; 0 means continue after the highest block stored
           in the `blocks` table (or, on an empty table with 'max_blks'
           configured, start that many blocks below the node's tip).

    Uses direct blk*.dat access for catch up when 'blkdat' is configured and the
    node is far ahead, then polls over rpc, and hands over to ZMQ when that is
    configured and the rpc loop has caught up.

    `chainwork` always holds the cumulative work through block blk-1; each
    queued block adds its own work before being stamped.

    Returns the number of blocks advanced in this call (0 when start up is aborted).
    """
    sql = db.connect(*sqc.cfg['db'].split(':'))
    sql.autocommit(True) # only mempool, most data inserted in handlers
    cur = sql.cursor()
    cur.execute("show tables like 'bech32';") # test if old db version and abort with log msg
    if cur.rowcount == 0:
        log("sqlChain Database upgrade required for this daemon version.\nCannot continue. Run sqlchain-upgrade-db.")
        sqc.done.set()
        return 0
    blkinfo = sqc.rpc.getblockchaininfo() # wait for node to be ready
    if blkinfo is None:
        return 0
    if blk == 0:
        cur.execute('select ifnull(max(id), -1) from blocks')
        blk = int(cur.fetchone()[0] + 1)
        cur.execute("select hex(chainwork) from blocks where id=%s;", (blk-1,))
        row = cur.fetchone()
        chainwork = int(row[0],16) if row is not None else 0
        if blk == 0 and 'max_blks' in sqc.cfg:
            blk = max(blkinfo['blocks'] - sqc.cfg['max_blks'], 0) # never start below genesis
            log("Using block limit %d" % sqc.cfg['max_blks'])
            chainwork = _nodeChainwork(blk-1) # work through the block before the first one fetched
    else:
        chainwork = _nodeChainwork(blk-1) # seeding from blk itself would count blk's work twice
    if chainwork is None:
        log("Cannot get chainwork for block %d from node. Cannot continue." % (blk-1))
        return 0
    startblk = blk

    sqc.bestblk = blkinfo['headers'] if 'headers' in blkinfo else 0
    if 'blkdat' in sqc.cfg and (sqc.bestblk - startblk) > coincfg(BLKDAT_NEAR_SYNC):
        blk = getBlocksDirect(cur, blk, chainwork) # use direct file access for catch up
        if blk > startblk and not sqc.done.isSet():
            # direct mode accumulates work in its own local copy, so refresh ours before continuing
            chainwork = _nodeChainwork(blk-1)
            if chainwork is None:
                log("Cannot get chainwork for block %d from node. Cannot continue." % (blk-1))
                return blk - startblk
    if not sqc.done.isSet():
        log("Using rpc mode. Monitoring blocks / mempool on " + sqc.cfg['cointype'])
    poll_delay = 0.05
    while not sqc.done.isSet():
        blkinfo = sqc.rpc.getblockchaininfo()
        if blkinfo is None or blk > blkinfo['blocks']:
            if sqc.zmq and 'zmq' in sqc.cfg:
                blk = getBlocksZMQ(cur, blk, chainwork) # try to upgrade to ZMQ, more efficient
            else:
                checkMemPool(cur)
                time.sleep(5)
            continue
        if blockQ.qsize() >= sqc.cfg['queue']:
            time.sleep(min(poll_delay,5)) # back off (capped at 5s) while the handler drains the queue
            poll_delay *= 2
            continue
        poll_delay = 0.05
        if 'pruned' in blkinfo and blkinfo['pruned']:
            chkPruning(blk - blockQ.qsize())
        rpcstart = time.time()
        blkhash = sqc.rpc.getblockhash(blk)
        if blkhash is None:
            continue
        rawblk = sqc.rpc.getblock(blkhash, False)
        if rawblk is None: # node could not supply the block; try again on the next pass
            continue
        data = decodeBlock(rawblk.decode('hex'))
        data['height'] = blk
        chainwork += blockwork(data['bits'])
        data['chainwork'] = int2bin32(chainwork)
        data['rpc'] = time.time()-rpcstart
        blockQ.put(data)
        blk += 1
    return blk - startblk

def getBlocksDirect(cur, blk, chainwork):
    """Queue blocks read straight from the node's blk*.dat files.

    cur       -- db cursor used to look up (filenum,filepos) in the `blkdat` table
    blk       -- height of the next block to read
    chainwork -- cumulative work through block blk-1 (int)

    A BlkDatHandler thread is started alongside. The loop ends when sqc.done is
    set (the thread is then joined), or returns early (without joining) when
    near sync, after 60 idle 3-second waits, or on a bad file position.

    Returns the height of the next block still to be fetched. The chainwork
    accumulated here is local and is not returned.
    """
    blkscan = threading.Thread(target = BlkDatHandler, args=(True,))
    blkscan.start()
    idle_count = 0
    log("Using blkdat mode: %s" % sqc.cfg['blkdat'])
    while not sqc.done.isSet():
        if blockQ.qsize() >= sqc.cfg['queue']:
            time.sleep(0.01)
            continue
        chkPruning(blk - blockQ.qsize())
        if (sqc.bestblk - blk) < coincfg(BLKDAT_NEAR_SYNC):
            logts("Near sync %d. Aborting direct mode" % blk)
            return blk
        if idle_count >= 60:
            logts("No blkdat activity, 3 minutes. Aborting direct mode")
            return blk
        cur.execute("select filenum,filepos from blkdat where id=%s limit 1;", (blk,))
        row = cur.fetchone()
        if row:
            filenum,pos = row
            started = time.time()
            with open(sqc.cfg['blkdat'] + "/blocks/blk%05d.dat" % filenum, 'rb') as fd:
                fd.seek(pos)
                rechdr = fd.read(8)
                # a short read (position at/after end of file) is treated like a bad magic value
                magic,blksize = unpack('<II', rechdr) if len(rechdr) == 8 else (None, 0)
                if magic != coincfg(BLKDAT_MAGIC):
                    logts("Error reading block %d - bad blkdat index. Aborting direct mode" % blk)
                    return blk
                data = decodeBlock(fd.read(blksize))
                data['height'] = blk
                chainwork += blockwork(data['bits'])
                data['chainwork'] = int2bin32(chainwork)
                data['rpc'] = time.time()-started
                blockQ.put(data)
                idle_count = 0
                blk += 1
        else:
            time.sleep(3) # position not in the blkdat table yet; wait and count the idle period
            idle_count += 1
    log("Waiting for blkdat handler to finish")
    blkscan.join()
    return blk

def getBlocksZMQ(cur, blk, chainwork):
    """Receive raw blocks and raw txs over a ZMQ subscription until sqc.done is set.

    cur       -- db cursor used for mempool inserts
    blk       -- height assigned to the next "rawblock" message
    chainwork -- cumulative work through block blk-1 (int)

    When pyzmq is not importable, sqc.zmq is cleared and blk is returned unchanged.
    Returns the height of the next block still to be fetched.
    """
    try:
        import zmq#.green as zmq
    except ImportError:
        log("Warning: Could not find pyzmq; ZMQ disabled")
        sqc.zmq = False # prevent repeated inits
        return blk

    context = zmq.Context()
    subscriber = context.socket(zmq.SUB)
    socket_options = (
        (zmq.SUBSCRIBE, "rawblock"),
        (zmq.SUBSCRIBE, "rawtx"),
        (zmq.RCVTIMEO, 1000),   # receive timeout, so the done flag is polled regularly
        (zmq.LINGER, 0),
    )
    for option, value in socket_options:
        subscriber.setsockopt(option, value)
    subscriber.connect(sqc.cfg['zmq'])
    log("Upgrade to ZMQ on %s." % sqc.cfg['zmq'])

    while not sqc.done.isSet():
        try:
            parts = subscriber.recv_multipart()
        except zmq.Again:
            continue # receive timed out, nothing published
        topic = parts[0]
        body = parts[1]
        if topic == "rawtx":
            insertTxMemPool(cur, decodeTx(body), 0) # 0 sync_id, no mempool
        elif topic == "rawblock":
            started = time.time()
            data = decodeBlock(body)
            data['height'] = blk
            chainwork += blockwork(data['bits'])
            data['chainwork'] = int2bin32(chainwork)
            data['rpc'] = time.time()-started
            blockQ.put(data)
            blk += 1

    context.destroy()
    return blk

def chkPruning(blk):
    """Every 100th block above 120000: prune the node and refresh sqc.bestblk.

    blk -- height of the oldest block not yet handed to the handlers (as computed by callers)

    If the node reports itself as pruned, it is asked to prune up to blk-keep,
    where keep is cfg 'prune' but never less than 20. sqc.bestblk is then set
    to the node's header count. Nothing happens when the node gives no info;
    sqc.bestblk is left alone when the info has no 'headers' entry.
    """
    if blk > 120000 and blk % 100 == 0:
        blkinfo = sqc.rpc.getblockchaininfo()
        if blkinfo is None: # node not answering; keep the previous bestblk
            return
        if blkinfo.get('pruned'):
            keep = max(sqc.cfg.get('prune', 20), 20)
            sqc.rpc.pruneblockchain(blk-keep) # keep at least 20 blocks for reorgs but can now config higher
        if 'headers' in blkinfo:
            sqc.bestblk = blkinfo['headers']

def limitBlocks(cur, max_blks):
    """Trim stored data so only the newest max_blks blocks remain.

    cur      -- db cursor (the caller commits)
    max_blks -- number of most recent blocks to keep

    Finds the newest block beyond the limit, then deletes every tx at or below
    it together with that tx's outputs, and finally the block rows themselves.
    """
    cur.execute("select id from blocks order by id desc limit %s, 1", (max_blks,))
    row = cur.fetchone()
    if row:
        blkid = row[0]
        # NOTE(review): this also matches block_id=-1 rows (txs without a block) - confirm that is intended
        cur.execute("select id from trxs where block_id < %s", ((blkid+1)*MAX_TX_BLK,))
        txids = [ txid for txid, in cur ]
        del_outputs = "delete from outputs where id >= %s*{0} and id < %s*{0}".format(MAX_IO_TX)
        for txid in txids:
            # output ids of a tx span [txid*MAX_IO_TX, (txid+1)*MAX_IO_TX); the upper bound must use txid+1
            cur.execute(del_outputs, (txid,txid+1))
            cur.execute("delete from trxs where id=%s", (txid,))
        cur.execute("delete from blocks where id<=%s", (blkid,))
    
def BlockHandler():
    """Worker thread: take decoded blocks off blockQ and store them, until sqc.done is set.

    Each block is committed on its own; when 'max_blks' is configured the
    older blocks are trimmed and committed right after.
    """
    conn = db.connect(*sqc.cfg['db'].split(':'))
    cur = conn.cursor()
    while not sqc.done.isSet():
        try:
            data = blockQ.get(True, 5)
        except Empty:
            continue # nothing queued within 5s; re-check the done flag
        insertBlock(cur, data)
        conn.commit()
        if 'max_blks' not in sqc.cfg:
            continue
        limitBlocks(cur, sqc.cfg['max_blks'])
        conn.commit()

def OutputHandler():
    """Worker thread: batch the (outputs, inputs) lists from outQ into the outputs table.

    Rows are buffered and written once more than 8192 are pending, or when the
    queue has been idle for 3 seconds. sqc.flushed is cleared whenever new work
    arrives and set after an idle pass. The thread returns on the first idle
    pass after sqc.alldone is set.
    """
    conn = db.connect(*sqc.cfg['db'].split(':'))
    cur = conn.cursor()

    def store(new_outs, spends):
        # one batch: create the output rows, then record the spending tx on spent outputs
        cur.executemany("insert ignore into outputs (id,addr_id,value) values(%s,%s,%s)", new_outs)
        cur.executemany("update outputs set tx_id=%s where id=%s limit 1", spends)
        conn.commit()

    ins,outs = [],[]
    while True:
        try:
            xo,xi = outQ.get(True, 3)
        except Empty:
            if outs or ins:
                store(outs, ins)
            sqc.flushed.set()
            if sqc.alldone.isSet():
                print("Flushed outQ - outs %d - ins %d" % (len(outs), len(ins)))
                return
            ins,outs = [],[]
            continue
        sqc.flushed.clear()
        outs.extend(xo)
        ins.extend(xi)
        if len(outs) + len(ins) > 8192:
            store(outs, ins)
            ins,outs = [],[]

def checkMemPool(cur):
    """Add txs from the node's mempool that are not yet tracked in memPool.

    cur -- db cursor

    New txs are stored with sync_id one above the current maximum in the
    mempool table. The table is emptied first when the in-memory set is empty.
    """
    cur.execute("select ifnull(max(sync_id),0) from mempool;")
    sync_id = cur.fetchone()[0]
    if not memPool:
        cur.execute("delete from mempool;")
    txhashes = sqc.rpc.getrawmempool()
    if txhashes is None:
        return
    for txhash in txhashes:
        short = txhash.decode('hex')[::-1][:8] # uses 1/4 space, only for detecting changes in mempool
        if short in memPool:
            continue
        rawtx = sqc.rpc.getrawtransaction(txhash,0)
        if rawtx is None:
            continue
        insertTxMemPool(cur, decodeTx(rawtx.decode('hex')), sync_id+1)
        memPool.add(short)

def addOrphan(cur, height):
    """Copy the stored block at `height` into the orphans table.

    cur    -- db cursor
    height -- block id of the block being replaced

    The orphan row gets the current maximum mempool sync_id, the raw header
    read via gethdr, and the hash/coinbase from the blocks table. Nothing is
    inserted when no block row exists for that height.
    """
    cur.execute("select ifnull(max(sync_id),0) from mempool;")
    sync_id = cur.fetchone()[0]
    hdr = gethdr(height, sqc.cfg, 'raw')
    cur.execute("select hash,coinbase from blocks where id=%s limit 1;", (height,))
    row = cur.fetchone() # fetch before re-using the cursor; iterating it while executing is fragile
    if row is not None:
        blkhash,coinbase = row
        cur.execute("insert into orphans (sync_id,block_id,hash,hdr,coinbase) values(%s,%s,%s,%s,%s);", (sync_id,height,blkhash,hdr,coinbase))

def checkReOrg(cur, data):
    """Detect a chain re-organisation before a block is stored, and repair it.

    cur  -- db cursor
    data -- decoded block with 'height' and 'previousblockhash'

    Walks back from the parent of `data`, comparing the hash expected at each
    height with the id stored for that hash, until they agree. If any stored
    blocks above that point differ, their txs are marked unconfirmed
    (block_id=-1) and doReOrg replaces the blocks. The check ends early, with a
    log message, when a hash is not in the blocks table or the node cannot
    supply a block hash.
    """
    if data['height'] == 0:
        return
    blkhash,height = data['previousblockhash'],data['height']-1
    while True:
        cur.execute("select id from blocks where hash=%s limit 1;", (blkhash,))
        row = cur.fetchone()
        if row is None:
            log("No previous block %d - ok if first run, ReOrg aborted" % height)
            return
        if row[0] == height: # rewind until block in good chain
            break
        height -= 1
        hexhash = sqc.rpc.getblockhash(height)
        if hexhash is None: # must be tested before decoding, None has no decode()
            log("Rewind failure during ReOrg - Check manually.")
            return
        blkhash = hexhash.decode('hex')[::-1]
    height += 1
    if height < data['height']:
        sqc.flushed.wait() # make sure all outputs committed before we start re-org
        cur.execute("update trxs set block_id=-1 where block_id >= %s;", (height*MAX_TX_BLK,)) # set bad chain txs uncfmd
        logts("Block %d *** ReOrg: %d orphan(s), %d txs affected" % (height, data['height']-height, cur.rowcount))
        doReOrg(cur, data, height)

def doReOrg(cur, data, height):
    """Replace stored blocks from `height` up to (not including) data['height'].

    cur    -- db cursor
    data   -- the new block that triggered the re-org; only its 'height' is used
    height -- first height whose stored block is no longer on the good chain

    For each height the good-chain block is fetched from the node, its txs are
    re-attached or inserted, the old block is saved as an orphan, and the block
    row and header file are overwritten. Stops with a log message when the
    node cannot supply a block.
    """
    tip = data['height'] # fixed loop bound; the fetched blocks below must not overwrite `data`
    while height < tip:
        blkhash = sqc.rpc.getblockhash(height)
        rawblk = sqc.rpc.getblock(blkhash, False) if blkhash is not None else None
        if rawblk is None:
            log("Block failure during ReOrg - Check manually.")
            return
        good = decodeBlock(rawblk.decode('hex')) # get good chain blocks
        if good:
            for n,tx in enumerate(good['tx']):
                tx_id,found = findTx(cur, tx['txid'], mkNew=True)
                if found:
                    cur.execute("update trxs set block_id=%s where id=%s limit 1;", (height*MAX_TX_BLK+n, tx_id))
                else:
                    insertTx(cur, tx, tx_id, height*MAX_TX_BLK + n) # occurs if tx wasn't in our mempool or orphan block
            blkhdr = sqc.rpc.getblockheader(good['hash'][::-1].encode('hex'))
            if blkhdr is None:
                log("Block failure during ReOrg - Check manually.")
                return
            addOrphan(cur, height) # must run before the block row and header file are overwritten
            good['chainwork'] = blkhdr['chainwork'].decode('hex')
            cur.execute("update blocks set hash=%s,coinbase=%s,chainwork=%s,blksize=%s where id=%s;", (good['hash'], good['coinbase'], good['chainwork'], good['size'], height))
            puthdr(height, sqc.cfg, good['hdr']) # the decoded block is not given a 'height' here, use the loop height
        height += 1

def insertTxMemPool(cur, tx, sync_id):
    """Record a decoded tx in the mempool table, storing the tx itself first if it is new.

    cur     -- db cursor
    tx      -- decoded transaction
    sync_id -- value stored in the mempool row's sync_id column
    """
    tx_id, already_stored = findTx(cur, tx['txid'], mkNew=True)
    if not already_stored:
        no_block = -1 # -1 means trx has no block
        insertTx(cur, tx, tx_id, no_block)
    mempool_row = (tx_id, sync_id)
    cur.execute("insert ignore into mempool (id,sync_id) values(%s,%s);", mempool_row)

def insertTx(cur, tx, tx_id, blk_id): # pylint:disable=too-many-locals
    """Store one decoded transaction and queue its inputs/outputs for OutputHandler.

    cur    -- db cursor
    tx     -- decoded transaction; gets a 'stdSeq' flag added
    tx_id  -- id to store the tx under
    blk_id -- block_id value for the trxs row (-1 when the tx has no block)

    Inputs are recorded only when the previous tx is found by findTx and the
    output index is below MAX_IO_TX; outputs only when their index is below
    MAX_IO_TX.
    """
    keep_sigs = not sqc.cfg['no-sigs']
    # one non-final sequence makes every recorded input carry its sequence number
    tx['stdSeq'] = all(vin['sequence'] == 0xffffffff for vin in tx['vin'])

    spends = []
    in_ids,txdata = '',''
    for vin in tx['vin']:
        if 'coinbase' in vin:
            continue
        prev_id = findTx(cur, vin['txid'])
        if not (prev_id and vin['vout'] < MAX_IO_TX):
            continue
        out_id = (prev_id*MAX_IO_TX) + vin['vout']
        spends.append(( tx_id, out_id ))
        in_ids += pack('<Q', out_id)[:7]
        if keep_sigs:
            txdata += encodeVarInt(len(vin['scriptSig'])) + vin['scriptSig']
        if not tx['stdSeq']:
            txdata += pack('<I', vin['sequence'])

    created = []
    for vout in tx['vout']:
        script = vout['scriptPubKey']
        if 'addr' in script:
            addr_id = insertAddress(cur, script['addr'])
        else:
            addr_id = 0
        if vout['n'] < MAX_IO_TX:
            created.append(( tx_id*MAX_IO_TX + vout['n'], addr_id, vout['value'] ))
            txdata += encodeVarInt(len(script['data'])) + script['data']

    outQ.put((created,spends))
    ins,outs,sz,txhdr = mkBlobHdr(tx, len(spends), len(created), sqc.cfg['no-sigs'])
    if keep_sigs and 'segwit' in tx and tx['segwit']:
        txdata += tx['witness']
    blob = insertBlob(txhdr + in_ids + txdata, sqc.cfg)
    cur.execute("insert ignore into trxs (id,hash,ins,outs,txsize,txdata,block_id) values(%s,%s,%s,%s,%s,%s,%s)",
        ( tx_id, tx['txid'], ins, outs, sz, blob, blk_id ))

def insertBlock(cur, data):
    """Store one decoded block: its txs, the block row and the header file entry.

    cur  -- db cursor (the caller commits)
    data -- decoded block with 'height', 'chainwork' and 'rpc' filled in by the fetcher

    Txs already stored are attached to the block by updating block_id; any
    other tx is inserted. Afterwards a log line is written and the moving
    average of block processing time is saved to the info table.
    """
    checkReOrg(cur, data)
    blkstart = time.time()
    height = data['height']
    first_id = height*MAX_TX_BLK
    for n,tx in enumerate(data['tx']):
        tx_id,found = findTx(cur, tx['txid'], mkNew=True)
        if found:
            short = tx['txid'][:8]
            if short in memPool:
                memPool.remove(short)
                cur.execute("delete from mempool where id=%s limit 1;", (tx_id,))
            cur.execute("update trxs set block_id=%s where id=%s limit 1;", (first_id + n, tx_id))
            if cur.rowcount > 0:
                continue
        # not stored yet, or the update above changed no row
        insertTx(cur, tx, tx_id, first_id + n)

    block_row = (height, data['hash'], data['coinbase'], data['chainwork'], data['size'])
    cur.execute("insert ignore into blocks (id,hash,coinbase,chainwork,blksize) values (%s,%s,%s,%s,%s);", block_row)
    puthdr(height, sqc.cfg, data['hdr'])

    blktime = time.time() - blkstart
    txcount = len(data['tx'])
    blkdate = datetime.fromtimestamp(data['time']).strftime('%d-%m-%Y')
    log("Block %d [ Q:%d %4d txs - %s - %3.0fms %2.1fs %3.0f tx/s]" % (
        height, blockQ.qsize(), txcount, blkdate, data['rpc']*1000, blktime, txcount/blktime) )

    blksecs.append(blktime)
    if len(blksecs) > 18: # ~3 hour moving avg
        blksecs.pop(0)
    average = "%2.1f" % (sum(blksecs)/len(blksecs))
    cur.execute("replace into info (class,`key`,value) value('info','avg-block-sync',%s);", (average, ))

def options(cfg): # pylint:disable=too-many-branches
    """Apply command line options from sys.argv to the cfg dict, in the order given.

    cfg -- configuration dict, updated in place

    Exits via usage() on -h/--help or an unknown option, via sys.exit on
    -v/--version, and after saving cfg on --defaults (so only options that
    appear before --defaults are saved). -w is accepted with an argument but
    has no effect here.
    """
    try:
        opts,_ = getopt.getopt(sys.argv[1:], "hvd:l:r:w:p:q:u:b:f:",
            ["help", "version", "debug", "db=", "log=", "rpc=", "path=", "queue=", "user=", "block=", "blkdat=", "no-sigs", "defaults" ])
    except getopt.GetoptError:
        usage()
    for opt,arg in opts:
        if opt in ("-h", "--help"):
            usage()
        elif opt in ("-v", "--version"):
            sys.exit(sys.argv[0]+': '+version)
        elif opt in ("-d", "--db"):
            cfg['db'] = arg
        elif opt in ("-l", "--log"):
            cfg['log'] = arg
        elif opt in ("-r", "--rpc"):
            cfg['rpc'] = arg
        elif opt in ("-p", "--path"):
            cfg['path'] = arg
        elif opt in ("-q", "--queue"):
            cfg['queue'] = int(arg)
        elif opt in ("-u", "--user"):
            cfg['user'] = arg
        # long-only options: compare for equality, `opt in "--name"` is a substring test
        elif opt == "--no-sigs":
            cfg['no-sigs'] = True
        elif opt == "--defaults":
            savecfg(cfg)
            sys.exit("%s updated" % (sys.argv[0]+'.cfg'))
        elif opt in ("-b", "--block"):
            cfg['block'] = int(arg)
        elif opt == "--debug":
            cfg['debug'] = True
        elif opt in ("-f", "--blkdat"):
            cfg['blkdat'] = arg

def usage():
    """Print the command line help to stdout and exit with status 2."""
    # The escapes and the literal line breaks are both part of the printed text;
    # {0} is replaced by the program name.
    help_text = """Usage: {0} [options...][cfg file]\nCommand options are:\n-h,--help\tShow this help info\n-v,--version\tShow version info
-b,--block\tStart at block number (instead of from last block done)
-f,--blkdat\tSet path to block data and use direct file access (no mempool/re-org)
--debug\t\tRun in foreground with logging to console
--defaults\tUpdate cfg and exit\nDefault files are {0}.cfg, {0}.log
\nThese options get saved in cfg file as defaults.
-p,--path\tSet path for blob and block header data file (/var/data/sqlchain)
-q,--queue\tSet block queue size (8)\n-u,--user\tSet user to run as\n-d,--db  \tSet mysql db connection, "host:user:pwd:dbname"
-l,--log\tSet log file path\n-r,--rpc\tSet rpc connection, "http://user:pwd@host:port"
--no-sigs\tDo not store input sigScript data """
    program = sys.argv[0]
    print(help_text.format(program))
    sys.exit(2)

def sigterm_handler(_signo, _stack_frame):
    """Signal handler: request shutdown by setting the shared sqc.done event."""
    stop_event = sqc.done
    stop_event.set()
def sighup_handler(_signo, _stack_frame):
    """Signal handler: close and reopen stdout and stderr on the log file (append mode)."""
    logfile = sqc.cfg.get('log', sys.argv[0]+'.log')
    for stream in ('stdout', 'stderr'):
        getattr(sys, stream).close()
        setattr(sys, stream, open(logfile,'a'))
    logts("SIGHUP Log reopened")

def run():
    """Start the worker threads, feed blocks until shutdown, then stop the workers in order.

    getBlocks is retried every 5 seconds while it fails with socket.error.
    BlockHandler is joined before sqc.alldone is set, so OutputHandler makes
    its final flush only after block processing has stopped.
    """
    sqc.done = threading.Event()
    sqc.alldone = threading.Event()
    sqc.flushed = threading.Event()

    sqlchain_overlay(sqc.cfg['cointype'])

    blkwrk = threading.Thread(target = BlockHandler)
    outwrk = threading.Thread(target = OutputHandler)
    for worker in (blkwrk, outwrk):
        worker.start()

    workstart = time.time()
    blksdone = None
    while not sqc.done.isSet():
        try:
            blksdone = getBlocks(sqc.cfg.get('block', 0))
        except socket.error:
            log("Cannot connect to rpc")
            time.sleep(5)
        else:
            break

    # shutdown: stop the block handler first, then let the output handler drain
    sqc.done.set()
    blkwrk.join()
    sqc.alldone.set()
    outwrk.join()
    if blksdone:
        elapsed = time.time() - workstart
        log("Session %d blocks, %.2f blocks/s" % (blksdone, float(blksdone / elapsed)) )

if __name__ == '__main__':
    # Script entry point: load config, apply command line options, drop privileges,
    # then run in the foreground (--debug) or as a daemon with a pid file.

    loadcfg(sqc.cfg)
    options(sqc.cfg)
    drop2user(sqc.cfg)

    sqc.rpc = rpcPool(sqc.cfg)
    blockQ = Queue()    # decoded blocks waiting for BlockHandler
    outQ = Queue(64)    # (outputs, inputs) batches waiting for OutputHandler

    if sqc.cfg.get('debug'): # 'debug' is not among the built-in defaults, so it may be absent
        signal.signal(signal.SIGINT, sigterm_handler)
        run()
    else:
        logpath = sqc.cfg.get('log', sys.argv[0]+'.log')
        pidpath = sqc.cfg.get('pid', sys.argv[0]+'.pid')
        with daemon.DaemonContext(working_directory='.', umask=0o002, stdout=open(logpath,'a'), stderr=open(logpath,'a'),
                signal_map={signal.SIGTERM:sigterm_handler, signal.SIGHUP:sighup_handler} ):
            with open(pidpath,'w') as f:
                f.write(str(os.getpid()))
            try:
                run()
            finally:
                os.unlink(pidpath) # remove the pid file even when run() fails
