#!/usr/bin/env python
#
#   sqlchaind - daemon to update sql blockchain db
#
import os, sys, socket, getopt, time, signal, threading, daemon
import MySQLdb as db

from Queue import Queue, Empty
from datetime import datetime
from struct import pack, unpack, unpack_from
from hashlib import sha256

from sqlchain.version import *
from sqlchain.util import *
from sqlchain.blkdat import BlkDatHandler

# Container for "super globals": attached to builtins so the name `sqc`
# resolves everywhere without being passed around.
__builtins__.sqc = dotdict()

# Built-in default settings; options() overrides entries from the command line.
cfg = {
    'log': sys.argv[0]+'.log',
    'queue': 8,
    'no-sigs': False,
    'db': '',
    'rpc': '',
    'path': '/var/data/sqlchain',
}
blksecs = []      # recent per-block insert durations, used for the moving average
memPool = set()   # 8-byte txid prefixes of mempool txs already recorded
maxblk = 120000   # last known header height; refreshed from getblockchaininfo

def getBlocks(blk):
    """Fetch blocks from height `blk` onward and queue them on blockQ.

    blk -- first height to fetch; 0 means "continue after the highest id
           currently stored in the blocks table".

    If cfg has a 'blkdat' entry and the node reports more than 500 headers
    beyond the start height, getBlocksDirect() is used first for catch up.
    After that the loop polls the node over rpc until sqc.done is set,
    checking the mempool whenever there is no new block to fetch.

    Returns the number of heights advanced during this call.

    The database connection is closed on every exit path. run() calls this
    function again after each socket.error, so leaving it open would leak
    one connection per retry.
    """
    global maxblk
    sql = db.connect(*cfg['db'].split(':'))
    try:
        cur = sql.cursor()
        if blk == 0:
            cur.execute('select ifnull(max(id), -1) from blocks')
            blk = int(cur.fetchone()[0] + 1)
        startblk = blk
        blkinfo = rpc.getblockchaininfo()
        maxblk = blkinfo['headers']
        if 'blkdat' in cfg and (maxblk - startblk) > 500:
            blk = getBlocksDirect(cur, blk) # use direct file access for catch up
        if not sqc.done.isSet():
            logts("Using rpc mode. Monitoring blocks / mempool.")
        while not sqc.done.isSet():
            blkinfo = rpc.getblockchaininfo()
            if blk > blkinfo['blocks']:
                # nothing new to fetch yet - use the idle time for the mempool
                checkMemPool(cur)
                time.sleep(5)
                continue
            if blockQ.qsize() >= cfg['queue']:
                # insert thread is behind; wait for the queue to drain
                time.sleep(5)
                continue
            if blkinfo['pruned']:
                # subtract queued blocks so only heights already taken off the queue are pruned
                chkPruning(blk - blockQ.qsize())
            rpcstart = time.time()
            blkhash = rpc.getblockhash(blk)
            if blkhash is None:
                # hash lookup failed; back off instead of spinning on the rpc server
                time.sleep(1)
                continue
            data = decodeBlock(rpc.getblock(blkhash, False).decode('hex'))
            data['height'] = blk
            data['rpc'] = time.time()-rpcstart
            blockQ.put(data)
            blk += 1
    finally:
        sql.close()
    return blk - startblk

def getBlocksDirect(cur, blk):
    """Catch up by reading raw blocks straight from the node's blk*.dat files.

    cur -- database cursor used to look up block file positions
    blk -- first height to read

    A BlkDatHandler thread is started (presumably it maintains the blkdat
    table - confirm in sqlchain.blkdat). Blocks are then located through the
    blkdat table (filenum, filepos), decoded and put on blockQ.

    Returns the next height still to be fetched. The function returns early
    when fewer than 500 blocks remain or when the data at the indexed
    position does not start with the expected magic value.
    """
    global maxblk
    scanner = threading.Thread(target = BlkDatHandler, args=(cfg,True))
    scanner.start()
    logts("Using blkdat mode: %s" % cfg['blkdat'])

    # Hold off until the blkdat table exists in the schema.
    while not sqc.done.isSet():
        cur.execute("select count(1) from information_schema.tables where table_name='blkdat';")
        table_count = cur.fetchone()[0]
        if table_count:
            break
        time.sleep(5)
        logts("Waiting for blkdat to be ready")

    while not sqc.done.isSet():
        if blockQ.qsize() >= cfg['queue']:
            time.sleep(0.01)
            continue
        chkPruning(blk - blockQ.qsize())
        if (maxblk - blk) < 500:
            logts("Near sync %d. Aborting direct mode" % blk)
            return blk
        cur.execute("select filenum,filepos from blkdat where id=%s limit 1;", (blk,))
        entry = cur.fetchone()
        if not entry:
            # this height is not indexed yet; give the scanner some time
            time.sleep(3)
            continue
        filenum,filepos = entry
        readstart = time.time()
        with open(cfg['blkdat'] + "/blocks/blk%05d.dat" % filenum, 'rb') as blkfile:
            blkfile.seek(filepos)
            # each record starts with a 4 byte magic value and a 4 byte length
            magic,blksize = unpack('<II', blkfile.read(8))
            if magic != 0xD9B4BEF9:
                logts("Error reading block %d - bad blkdat index. Aborting direct mode" % blk)
                return blk
            block = decodeBlock(blkfile.read(blksize))
        block['height'] = blk
        block['rpc'] = time.time()-readstart
        blockQ.put(block)
        blk += 1

    log("Waiting for blkdat handler to finish")
    scanner.join()
    return blk
    
def chkPruning(blk):
    """Every 100th height above 120000: prune the node and refresh maxblk.

    blk -- height passed to rpc.pruneblockchain when the node reports
           that it runs in pruned mode
    """
    global maxblk
    if blk <= 120000 or blk % 100 != 0:
        return
    info = rpc.getblockchaininfo()
    if info['pruned']:
        rpc.pruneblockchain(blk)
    maxblk = info['headers']
    
def BlockHandler():
    """Worker thread body: take decoded blocks off blockQ and insert them.

    Uses its own database connection and runs until sqc.done is set. The
    5 second queue timeout bounds how long shutdown can go unnoticed.
    """
    conn = db.connect(*cfg['db'].split(':'))
    cursor = conn.cursor()
    while not sqc.done.isSet():
        try:
            block = blockQ.get(True, 5)
        except Empty:
            continue
        insertBlock(cursor, block)

def OutputHandler():
    sql = db.connect(*cfg['db'].split(':'))
    cur = sql.cursor()
    ins,outs = [],[]
    while True:
        try: 
            xo,xi = outQ.get(True, 3)
            outs.extend(xo)
            ins.extend(xi)
            if len(outs) + len(ins) > 8192:
                cur.executemany("insert ignore into outputs (id,addr_id,value) values(%s,%s,%s)", outs)
                cur.executemany("update outputs set tx_id=%s where id=%s limit 1", ins)
                ins,outs = [],[]
        except Empty: 
            if len(outs) > 0 or len(ins) > 0:
                cur.executemany("insert ignore into outputs (id,addr_id,value) values(%s,%s,%s)", outs)
                cur.executemany("update outputs set tx_id=%s where id=%s limit 1", ins)
            if sqc.flush.isSet():
                print "Flushed outQ - outs %d - ins %d" % (len(outs), len(ins))
                return
            ins,outs = [],[]
            pass
    
def checkMemPool(cur):
    """Record mempool transactions that have not been seen before.

    cur -- database cursor

    New entries are stored with a sync_id one above the current maximum in
    the mempool table. If the in-memory memPool set is empty, the mempool
    table is cleared first.
    """
    cur.execute("select ifnull(max(sync_id),0) from mempool;")
    sync_id = cur.fetchone()[0]
    if not memPool:
        cur.execute("delete from mempool;")
    for txhex in rpc.getrawmempool():
        # byte-reversed id, first 8 bytes only: enough to detect mempool changes
        shortid = txhex.decode('hex')[::-1][:8]
        if shortid in memPool:
            continue
        rawtx = rpc.getrawtransaction(txhex,0)
        if rawtx is None:
            continue
        insertTxMemPool(cur, decodeTx(rawtx.decode('hex')), sync_id+1)
        memPool.add(shortid)

def addOrphan(cur, height):
    """Copy the stored block at `height` into the orphans table.

    cur    -- database cursor
    height -- id of the block that is about to be replaced

    The orphan row gets the current maximum mempool sync_id, the block's
    stored hash and coinbase, and its raw header as returned by gethdr().
    Nothing is inserted when no block is stored at that height.
    """
    cur.execute("select ifnull(max(sync_id),0) from mempool;")
    sync_id = cur.fetchone()[0]
    hdr = gethdr(height, 'raw', cfg['path'])
    cur.execute("select hash,coinbase from blocks where id=%s limit 1;", (height,))
    # Fetch the row before reusing the cursor: executing a new statement
    # discards the pending result set, so iterating the cursor while
    # inserting through it only worked by accident.
    row = cur.fetchone()
    if row is None:
        return
    blkhash,coinbase = row
    cur.execute("insert into orphans (sync_id,block_id,hash,hdr,coinbase) values(%s,%s,%s,%s,%s);", (sync_id,height,blkhash,hdr,coinbase))
    
def decodeBlock(data):
    """Decode a raw serialized block into a dict.

    data -- raw block bytes: 80 byte header, tx count varint, transactions

    Returns a dict with the header fields ('version', 'previousblockhash',
    'merkleroot', 'time', 'bits' as 8 hex digits, 'nonce'), the raw header
    under 'hdr', its double-sha256 under 'hash', the decoded transactions
    under 'tx', the first input's coinbase script under 'coinbase', and
    'height'.
    """
    fields = ('version','previousblockhash','merkleroot', 'time', 'bits', 'nonce')
    block = dict(zip(fields, unpack_from('<I32s32s3I', data)))
    rawhdr = data[:80]
    block['hdr'] = rawhdr
    block['hash'] = sha256(sha256(rawhdr).digest()).digest()
    block['bits'] = '%08x' % block['bits']
    txcnt,skip = decodeVarInt(data[80:89])
    pos = 80 + skip
    txs = block['tx'] = []
    while len(txs) < txcnt:
        tx = decodeTx(data[pos:])
        txs.append(tx)
        pos += tx['size']
    block['height'] = 0
    block['coinbase'] = txs[0]['vin'][0]['coinbase']
    # NOTE(review): height was just set to 0, so the `>= 227836` test below is
    # always false and the coinbase height is never extracted. Callers that
    # need the height assign it themselves after decoding.
    if block['version'] > 1 and block['height'] >= 227836 and block['coinbase'][0] == '\x03':
        block['height'] = unpack('<I', block['coinbase'][1:4]+'\0')[0]
    return block

def decodeTx(data):
    """Decode one raw transaction from the start of `data`.

    data -- bytes beginning with a serialized transaction; trailing bytes
            after the transaction are ignored

    Layout read here: 4 byte version, input count varint, inputs, output
    count varint, outputs, 4 byte locktime.

    Returns a dict with 'version', 'vin', 'vout', 'locktime', 'size'
    (number of bytes consumed) and 'txid' (double-sha256 of those bytes).
    An input whose previous txid is all zero bytes with index 0xffffffff is
    stored as a coinbase input.
    """
    version, = unpack_from('<I', data)
    tx = { 'version': version, 'vin':[], 'vout':[] }

    incount,skip = decodeVarInt(data[4:13])
    pos = 4 + skip
    done = 0
    while done < incount:
        prevhash,previdx = unpack_from('<32sI', data, pos)
        scriptlen,skip = decodeVarInt(data[pos+36:pos+36+9])
        pos += 36 + skip                    # now at the first script byte
        script = data[pos:pos+scriptlen]
        seq, = unpack_from('<I', data, pos+scriptlen)
        if prevhash == '\0'*32 and previdx == 0xffffffff:
            vin = {'coinbase':script, 'sequence':seq }
        else:
            vin = {'txid':prevhash, 'vout':previdx, 'scriptSig':script, 'sequence':seq }
        tx['vin'].append(vin)
        pos += scriptlen + 4                # skip script and sequence
        done += 1

    outcount,skip = decodeVarInt(data[pos:pos+9])
    pos += skip
    n = 0
    while n < outcount:
        value, = unpack_from('<Q', data, pos)
        pklen,skip = decodeVarInt(data[pos+8:pos+8+9])
        pos += 8 + skip
        pkscript = decodeScriptPK( data[pos:pos+pklen] )
        tx['vout'].append({'value':value, 'n':n, 'scriptPubKey':pkscript })
        pos += pklen
        n += 1

    tx['locktime'], = unpack_from('<I', data, pos)
    tx['size'] = pos + 4
    tx['txid'] = sha256(sha256(data[:tx['size']]).digest()).digest()
    return tx

def checkReOrg(cur, data):
    """Detect a chain re-organisation and repair the stored blocks.

    cur  -- database cursor
    data -- decoded block about to be inserted; 'height' and
            'previousblockhash' are read, the dict is not modified

    Walks back from the block's parent until a stored block is found whose
    hash matches the node's chain at the same height. Every stored block
    above that point is treated as orphaned: its transactions are marked
    unconfirmed, the replacement block is fetched over rpc, its
    transactions are re-linked or inserted, the old block is copied to the
    orphans table, and the blocks row and header file are updated.
    """
    if data['height'] == 0:
        return
    blkhash,height = data['previousblockhash'],data['height']-1
    while True:
        cur.execute("select id from blocks where hash=%s limit 1;", (blkhash,))
        row = cur.fetchone()
        if row and row[0] == height: # rewind until block in good chain
            break
        height -= 1
        blkhash = rpc.getblockhash(height).decode('hex')[::-1]
    height += 1
    # Keep the incoming height in its own name. The replacement blocks come
    # back from decodeBlock() with height 0, so reusing `data` for them ended
    # the repair loop after one block and wrote the header at index 0.
    newheight = data['height']
    if height < newheight:
        cur.execute("update trxs set block_id=-1 where block_id >= %s;", (height*MAX_TX_BLK,)) # set bad chain txs uncfmd
        logts("Block %d *** ReOrg: %d orphan(s), %d txs affected" % (height, newheight-height, cur.rowcount))
        while height < newheight:
            blkhash = rpc.getblockhash(height)
            good = decodeBlock(rpc.getblock(blkhash, False).decode('hex')) # get good chain blocks
            for n,tx in enumerate(good['tx']):
                tx_id,found = findTx(cur, tx['txid'], mkNew=True)
                if found:
                    cur.execute("update trxs set block_id=%s where id=%s limit 1;", (height*MAX_TX_BLK+n, tx_id))
                else:
                    insertTx(cur, tx, tx_id, height*MAX_TX_BLK + n) # occurs if tx wasn't in our mempool or orphan block
            # save the old block as an orphan before its row and header are replaced
            addOrphan(cur, height)
            cur.execute("update blocks set hash=%s,coinbase=%s where id=%s;", (good['hash'], good['coinbase'], height))
            puthdr(height, good['hdr'], cfg['path'])
            height += 1
       
def insertTxMemPool(cur, tx, sync_id):
    """Store an unconfirmed transaction and list it in the mempool table.

    cur     -- database cursor
    tx      -- decoded transaction dict
    sync_id -- sync id recorded with the mempool row
    """
    tx_id,known = findTx(cur, tx['txid'], mkNew=True)
    if not known:
        insertTx(cur, tx, tx_id, -1) # block id -1 marks a tx without a block
    mempool_row = (tx_id, sync_id)
    cur.execute("insert ignore into mempool (id,sync_id) values(%s,%s);", mempool_row)

def insertTx(cur, tx, tx_id, blk_id):
    """Store one transaction: queue its output rows and write the trxs row.

    cur    -- database cursor
    tx     -- decoded transaction dict
    tx_id  -- database id assigned to this transaction
    blk_id -- block position id, or -1 for a tx without a block

    Input/output rows are handed to the output thread through outQ. The
    scripts (and sequences, when any is non-default) are packed into a blob
    stored through insertBlob().
    """
    inlist,outlist = [],[]
    in_ids,txdata = '',''
    # sequences are only stored when at least one input has a non-default value
    stdSeq = all(vin['sequence'] == 0xffffffff for vin in tx['vin'])

    for vin in tx['vin']:
        if 'coinbase' in vin:
            continue
        prev_id = findTx(cur, vin['txid'])
        if not prev_id or vin['vout'] >= MAX_IO_TX:
            continue
        out_id = (prev_id*MAX_IO_TX) + vin['vout']
        inlist.append(( tx_id, out_id ))
        in_ids += pack('<Q', out_id)[:7]    # keep the low 7 bytes only
        if not cfg['no-sigs']:
            txdata += encodeVarInt(len(vin['scriptSig'])) + vin['scriptSig']
        if not stdSeq:
            txdata += pack('<I', vin['sequence'])

    for vout in tx['vout']:
        pkscript = vout['scriptPubKey']
        # the address is registered even for outputs beyond the MAX_IO_TX limit
        addr_id = insertAddress(cur, pkscript['addr']) if 'addr' in pkscript else 0
        if vout['n'] >= MAX_IO_TX:
            continue
        outlist.append(( tx_id*MAX_IO_TX + vout['n'], addr_id, vout['value'] ))
        txdata += encodeVarInt(len(pkscript['data'])) + pkscript['data']

    outQ.put((outlist,inlist))
    ins,outs,sz,txhdr = mkBlobHdr(len(inlist), len(outlist), tx, stdSeq, cfg['no-sigs'])
    blobref = insertBlob(txhdr + in_ids + txdata, cfg['path'])
    cur.execute("insert ignore into trxs (id,hash,ins,outs,txsize,txdata,block_id) values(%s,%s,%s,%s,%s,%s,%s)",
        ( tx_id, tx['txid'], ins, outs, sz, blobref, blk_id ))
            
def insertBlock(cur, data):
    """Insert one decoded block and its transactions, then log timing.

    cur  -- database cursor
    data -- decoded block dict with 'height' and 'rpc' (fetch time) set

    Transactions already stored are re-linked to this block and removed
    from the mempool tracking; others are inserted. The block row and
    header are written afterwards, and the moving average of insert times
    is saved in the info table.
    """
    checkReOrg(cur, data)
    started = time.time()
    height = data['height']
    txs = data['tx']
    base_id = height*MAX_TX_BLK
    for n,tx in enumerate(txs):
        tx_id,found = findTx(cur, tx['txid'], mkNew=True)
        if found:
            shortid = tx['txid'][:8]
            if shortid in memPool:
                memPool.remove(shortid)
                cur.execute("delete from mempool where id=%s limit 1;", (tx_id,))
            cur.execute("update trxs set block_id=%s where id=%s limit 1;", (base_id + n, tx_id))
            if cur.rowcount > 0:
                continue
        # unknown tx, or the update changed no row: fall back to insert
        insertTx(cur, tx, tx_id, base_id + n)

    cur.execute("insert ignore into blocks (id,hash,coinbase) values (%s,%s,%s);", (height, data['hash'], data['coinbase']))
    puthdr(height, data['hdr'], cfg['path'])

    elapsed = time.time() - started
    blkdate = datetime.fromtimestamp(data['time']).strftime('%d-%m-%Y')
    log("Block %d [ Q:%d %4d txs - %s - %3.0fms %2.1fs %3.0f tx/s]" % ( height, blockQ.qsize(),
        len(txs), blkdate, data['rpc']*1000, elapsed, len(txs)/elapsed) )

    blksecs.append(elapsed)
    if len(blksecs) > 18: # ~3 hour moving avg
        blksecs.pop(0)
    average = sum(blksecs)/len(blksecs)
    cur.execute("replace into info (class,`key`,value) value('info','avg-block-sync',%s);", ("%2.1f"%average, ))

def options(cfg):
    """Apply command line options from sys.argv to the cfg dict in place.

    cfg -- settings dict to update

    Unknown options, and -h/--help, print the usage text and exit through
    usage(). --version and --defaults also exit the process. A non-numeric
    value for --queue or --block raises ValueError.
    """
    try:
        opts,args = getopt.getopt(sys.argv[1:], "hvd:l:r:w:p:q:u:b:f:",
            ["help", "version", "debug", "db=", "log=", "rpc=", "path=", "queue=", "user=", "block=", "blkdat=", "no-sigs", "defaults" ])
    except getopt.GetoptError:
        usage()
    for opt,arg in opts:
        if opt in ("-h", "--help"):
            usage()
        elif opt in ("-v", "--version"):
            sys.exit(sys.argv[0]+': '+version)
        elif opt in ("-d", "--db"):
            cfg['db'] = arg
        elif opt in ("-l", "--log"):
            cfg['log'] = arg
        elif opt in ("-r", "--rpc"):
            cfg['rpc'] = arg
        elif opt in ("-p", "--path"):
            cfg['path'] = arg
        elif opt in ("-q", "--queue"):
            cfg['queue'] = int(arg)
        elif opt in ("-u", "--user"):
            cfg['user'] = arg
        # Long-only options are compared for equality: `opt in ("--x")` is a
        # substring test against a plain string, not tuple membership.
        elif opt == "--no-sigs":
            cfg['no-sigs'] = True
        elif opt == "--defaults":
            savecfg(cfg)
            sys.exit("%s updated" % (sys.argv[0]+'.cfg'))
        elif opt in ("-b", "--block"):
            cfg['block'] = int(arg)
        elif opt == "--debug":
            cfg['debug'] = True
        elif opt in ("-f", "--blkdat"):
            cfg['blkdat'] = arg
        
def usage():
    print """Usage: {0} [options...][cfg file]\nCommand options are:\n-h,--help\tShow this help info\n-v,--version\tShow version info
-b,--block\tStart at block number (instead of from last block done)
-f,--blkdat\tSet path to bitcoin data and use direct file access (no mempool/re-org)
--debug\t\tRun in foreground with logging to console
--defaults\tUpdate cfg and exit\nDefault files are {0}.cfg, {0}.log
\nThese options get saved in cfg file as defaults.
-p,--path\tSet path for blob and block header data file (/var/data)
-q,--queue\tSet block queue size (8)\n-u,--user\tSet user to run as\n-d,--db  \tSet mysql db connection, "host:user:pwd:dbname"
-l,--log\tSet log file path\n-r,--rpc\tSet rpc connection, "http://user:pwd@host:port"
--no-sigs\tDo not store input sigScript data """.format(sys.argv[0])
    sys.exit(2) 
    
def sigterm_handler(_signo, _stack_frame):
    """Signal handler: set the shared `done` event to request shutdown.

    Installed for SIGTERM in daemon mode and for SIGINT in debug mode.
    Both arguments are ignored.
    """
    sqc.done.set()
def sighup_handler(_signo, _stack_frame):
    """Signal handler: close stdout and stderr and reopen them on the log file.

    The log path comes from cfg['log'], falling back to <argv[0]>.log.
    Both arguments are ignored.
    """
    logpath = cfg.get('log', sys.argv[0]+'.log')
    for stream in ('stdout', 'stderr'):
        getattr(sys, stream).close()
        setattr(sys, stream, open(logpath,'a'))
    logts("SIGHUP Log reopened")
    
def run():
    """Start the worker threads and fetch blocks until shutdown.

    Creates the block and output queues and the `done` / `flush` events,
    starts the BlockHandler and OutputHandler threads, then calls
    getBlocks(), retrying every 5 seconds while the rpc connection fails
    with socket.error.

    The workers are always stopped and joined, even when getBlocks() raises
    something other than socket.error. Without that, the non-daemon worker
    threads would keep the process alive forever. Such an exception is
    re-raised after the workers have stopped.
    """
    global blockQ,outQ

    blockQ = Queue()
    outQ = Queue(64)
    sqc.done = threading.Event()
    sqc.flush = threading.Event()

    blkwrk = threading.Thread(target = BlockHandler)
    blkwrk.start()
    outwrk = threading.Thread(target = OutputHandler)
    outwrk.start()

    blksdone = None
    workstart = time.time()
    try:
        while not sqc.done.isSet():
            try:
                blksdone = getBlocks(cfg['block'] if 'block' in cfg else 0)
                break
            except socket.error:
                log("Cannot connect to rpc")
                time.sleep(5)
    finally:
        # stop the block inserter first, then let the output thread flush
        # whatever the inserter queued before it exits
        sqc.done.set()
        blkwrk.join()
        sqc.flush.set()
        outwrk.join()
    if blksdone:
        log("Session %d blocks, %.2f blocks/s" % (blksdone, float(blksdone / (time.time() - workstart))) )
    
if __name__ == '__main__':

    # Script entry point: load settings, apply command line options, switch
    # user through drop2user(), then run in the foreground (debug) or as a
    # daemon.
    loadcfg(cfg)
    options(cfg)
    drop2user(cfg)

    rpc = rpcPool(cfg)

    # .get tolerates a cfg that has no 'debug' entry
    if cfg.get('debug'):
        signal.signal(signal.SIGINT, sigterm_handler)
        run()
    else:
        logpath = cfg['log'] if 'log' in cfg else sys.argv[0]+'.log'
        pidpath = cfg['pid'] if 'pid' in cfg else sys.argv[0]+'.pid'
        with daemon.DaemonContext(working_directory='.', umask=0o002, stdout=open(logpath,'a'), stderr=open(logpath,'a'),
                signal_map={signal.SIGTERM:sigterm_handler, signal.SIGHUP:sighup_handler} ):
            with open(pidpath,'w') as f:
                f.write(str(os.getpid()))
            try:
                run()
            finally:
                # remove the pid file even when run() fails, so no stale pid is left
                os.unlink(pidpath)
    

    


    

