#!/usr/bin/env python
#
#   sqlchaind - daemon to update sql blockchain db
#
import os, sys, socket, getopt, time, signal, threading, daemon
import MySQLdb as db

from Queue import Queue, Empty
from datetime import datetime
from bitcoinrpc.authproxy import AuthServiceProxy
from struct import pack, unpack, unpack_from
from hashlib import sha256

from sqlchain.version import *
from sqlchain.util import *
from sqlchain.blkdat import BlkDatHandler

# Built-in configuration defaults; this dict is handed to loadcfg() and
# options() at startup, which may change or add entries.
cfg = {
    'log': sys.argv[0] + '.log',
    'queue': 8,
    'no-sigs': False,
    'db': '',
    'rpc': '',
    'path': '/var/data/sqlchain',
}
blksecs = []       # recent per-block processing times, maintained by insertBlock
memPool = set()    # 8-byte txid prefixes of transactions already seen in the mempool

def getBlocks(blk):
    """Fetch blocks starting at height blk and queue them on blockQ until `done` is set.

    When blk is 0 the start height is one past the highest id in the blocks
    table. If cfg has a 'blkdat' entry, getBlocksDirect() runs first and its
    return value becomes the height to continue from over rpc.

    Returns the number of heights advanced during this call.
    """
    rpc = AuthServiceProxy(cfg['rpc'])
    conn = db.connect(*cfg['db'].split(':'))
    cur = conn.cursor()
    if blk == 0:
        cur.execute('select ifnull(max(id), -1) from blocks')
        blk = int(cur.fetchone()[0] + 1)
    first = blk
    if 'blkdat' in cfg:
        blk = getBlocksDirect(cur, blk) # use direct file access for catch up
    if not done.isSet():
        logts("Using rpc mode. Monitoring blocks / mempool.")
    while not done.isSet():
        chain = rpc.getblockchaininfo()
        if blk > chain['blocks']:
            # Nothing new to fetch: use the idle time to sync the mempool.
            checkMemPool(cur, rpc)
            time.sleep(10)
        elif blockQ.qsize() >= cfg['queue']:
            # Block queue is full: wait for BlockHandler to drain it.
            time.sleep(10)
        else:
            if chain['pruned']:
                chkPruning(blk - blockQ.qsize())
            t0 = time.time()
            blkhash = rpc.getblockhash(blk)
            if blkhash:
                raw = rpc.getblock(blkhash, False).decode('hex')
                block = decodeBlock(raw)
                block['height'] = blk
                block['rpc'] = time.time() - t0
                blockQ.put(block)
                blk += 1
    return blk - first

def getBlocksDirect(cur, blk):
    """Catch up by reading raw blocks straight from the bitcoin blk*.dat files.

    Starts a BlkDatHandler thread, waits for the blkdat table to exist, then
    looks up each block's file number and position in that table, decodes the
    block and queues it on blockQ.

    Returns the next height that still needs fetching. Returns early when
    chkPruning() reports fewer than 288 blocks left, or when the indexed file
    position does not start with the expected magic value.
    """
    scanner = threading.Thread(target = BlkDatHandler, args=(cfg,done,True))
    scanner.start()
    logts("Using blkdat mode: %s" % cfg['blkdat'])

    # Poll until the blkdat table is listed in information_schema.
    while not done.isSet():
        cur.execute("select count(1) from information_schema.tables where table_name='blkdat';")
        if cur.fetchone()[0]:
            break
        time.sleep(5)
        logts("Waiting for blkdat to be ready")

    while not done.isSet():
        if blockQ.qsize() >= cfg['queue']:
            time.sleep(0.01)
            continue
        if chkPruning(blk - blockQ.qsize()) - blk < 288:
            logts("Near sync %d. Aborting direct mode" % blk)
            return blk
        cur.execute("select filenum,filepos from blkdat where id=%s limit 1;", (blk,))
        location = cur.fetchone()
        if not location:
            # Block not indexed yet: wait before asking again.
            time.sleep(3)
            continue
        filenum, filepos = location
        t0 = time.time()
        with open(cfg['blkdat'] + "/blocks/blk%05d.dat" % filenum, 'rb') as fd:
            fd.seek(filepos)
            magic, blksize = unpack('<II', fd.read(8))
            raw = fd.read(blksize) if magic == 0xD9B4BEF9 else None
        if raw is None:
            logts("Error reading block %d - bad blkdat index. Aborting direct mode" % blk)
            return blk
        block = decodeBlock(raw)
        block['height'] = blk
        block['rpc'] = time.time() - t0
        blockQ.put(block)
        blk += 1

    log("Waiting for blkscan to finish")
    scanner.join()
    return blk
    
def chkPruning(blk):
    """Occasionally ask the node to prune up to blk and report its header count.

    The rpc call is only made when blk is above 120000 and a multiple of 100.
    In that case the node's 'headers' value is returned. Otherwise, or when
    the rpc call fails, blk + 500 is returned.
    """
    if blk <= 120000 or blk % 100 != 0:
        return blk + 500 # faked to prevent aborting early
    try:
        rpc = AuthServiceProxy(cfg['rpc'])
        info = rpc.getblockchaininfo()
        if info['pruned']:
            rpc.pruneblockchain(blk)
        return info['headers']
    except Exception as e:
        # Best effort: a failed prune request is logged and otherwise ignored.
        log( 'Prune rpc ' + str(e) + ' (ignoring)' )
    return blk + 500 # faked to prevent aborting early
    
def BlockHandler():
    """Worker thread: take decoded blocks off blockQ and insert them until `done` is set."""
    conn = db.connect(*cfg['db'].split(':'))
    cur = conn.cursor()
    while not done.isSet():
        try:
            block = blockQ.get(True, 5)
        except Empty:
            # Timed out waiting; loop round to re-check the done flag.
            continue
        insertBlock(cur, block)

def OutputHandler():
    """Worker thread: batch output inserts and spent-output updates taken from outQ.

    Rows are accumulated and written once either list exceeds 8192 entries.
    When the queue stays empty for 3 seconds any pending rows are written,
    and the thread returns if `flush` is set.
    """
    conn = db.connect(*cfg['db'].split(':'))
    cur = conn.cursor()
    pending_ins, pending_outs = [], []

    def write_batch():
        # Write everything collected so far, then empty both lists.
        cur.executemany("insert ignore into outputs (id,addr_id,value) values(%s,%s,%s)", pending_outs)
        cur.executemany("update outputs set tx_id=%s where id=%s limit 1", pending_ins)
        del pending_outs[:]
        del pending_ins[:]

    while True:
        try:
            new_outs, new_ins = outQ.get(True, 3)
        except Empty:
            if pending_outs or pending_ins:
                print("Flushed outQ - outs %d - ins %d" % (len(pending_outs), len(pending_ins)))
                write_batch()
            if flush.isSet():
                return
            continue
        pending_outs.extend(new_outs)
        pending_ins.extend(new_ins)
        if len(pending_outs) > 8192 or len(pending_ins) > 8192:
            write_batch()
    
def checkMemPool(cur, rpc):
    """Insert transactions from the node's mempool that have not been seen before.

    New transactions are recorded with a sync_id one above the current
    maximum in the mempool table. When the in-memory memPool set is empty
    the mempool table is cleared first.
    """
    cur.execute("select ifnull(max(sync_id),0) from mempool;")
    sync_id = cur.fetchone()[0]
    if not memPool:
        cur.execute("delete from mempool;")
    for txhex in rpc.getrawmempool():
        key = txhex.decode('hex')[::-1][:8] # uses 1/4 space, only for detecting changes in mempool
        if key in memPool:
            continue
        memPool.add(key)
        raw = rpc.getrawtransaction(txhex, 0).decode('hex')
        insertTxMemPool(cur, decodeTx(raw), sync_id+1)

def addOrphan(cur, height):
    """Record the block currently stored at `height` in the orphans table.

    Copies the stored hash and coinbase from the blocks table, together with
    the raw header returned by gethdr(), tagged with the current maximum
    mempool sync_id. Does nothing if no block row exists for `height`.
    """
    cur.execute("select ifnull(max(sync_id),0) from mempool;")
    sync_id = cur.fetchone()[0]
    # Use the module-level cfg, as insertBlock does for puthdr; `sqc` is not defined in this file.
    hdr = gethdr(height, 'raw', cfg['path'])
    cur.execute("select hash,coinbase from blocks where id=%s limit 1;", (height,))
    # Fetch the row before reusing the cursor, rather than iterating a cursor that is re-executed.
    row = cur.fetchone()
    if row:
        blkhash,coinbase = row
        cur.execute("insert into orphans (sync_id,block_id,hash,hdr,coinbase) values(%s,%s,%s,%s,%s);", (sync_id,height,blkhash,hdr,coinbase))
    
def decodeBlock(data):
    """Decode a raw serialized block into a dict.

    The result holds the six header fields ('version', 'previousblockhash',
    'merkleroot', 'time', 'bits' as an 8-digit hex string, 'nonce'), the raw
    80-byte header in 'hdr', its double-SHA256 digest in 'hash', the decoded
    transactions in 'tx', the first input's coinbase script in 'coinbase',
    and 'height'.

    'height' is taken from the coinbase script for version 2+ blocks whose
    script starts with a 3-byte push and encodes a height of at least 227836;
    otherwise it is 0. Callers that know the height overwrite it.
    """
    hdr = ['version','previousblockhash','merkleroot', 'time', 'bits', 'nonce']
    hv = unpack_from('<I32s32s3I', data)
    block = dict(zip(hdr,hv))
    block['hdr'] = data[:80]
    block['hash'] = sha256(sha256(block['hdr']).digest()).digest()
    block['bits'] = '%08x' % block['bits']
    txcnt,off = decodeVarInt(data[80:89])
    off += 80
    block['tx'] = []
    while txcnt > 0:
        tx = decodeTx(data[off:])
        block['tx'].append(tx)
        off += tx['size']
        txcnt -= 1
    block['coinbase'] = block['tx'][0]['vin'][0]['coinbase']
    block['height'] = 0
    coinbase = block['coinbase']
    # The original compared the just-zeroed height against 227836, so this
    # branch could never run. Decode the candidate height first, then apply
    # the threshold to it.
    if block['version'] > 1 and len(coinbase) >= 4 and coinbase[:1] == '\x03':
        height = unpack('<I', coinbase[1:4]+'\0')[0]
        if height >= 227836:
            block['height'] = height
    return block

def decodeTx(data):
    """Decode one serialized transaction from the start of `data`.

    Returns a dict with 'version', 'vin', 'vout', 'locktime', 'size' (bytes
    consumed) and 'txid' (double-SHA256 digest of those bytes). Each input
    holds either 'coinbase' or 'txid'/'vout'/'scriptSig', plus 'sequence'.
    Each output holds 'value', 'n' and the decoded 'scriptPubKey'.
    `data` may extend past the end of the transaction.
    """
    version = unpack_from('<I', data)[0]
    inputs, outputs = [], []
    tx = { 'version': version, 'vin': inputs, 'vout': outputs }

    # Inputs: 32-byte previous txid, 4-byte output index, varint-prefixed script, 4-byte sequence.
    remaining, width = decodeVarInt(data[4:13])
    pos = 4 + width
    while remaining > 0:
        prev_hash, prev_index = unpack_from('<32sI', data, pos)
        script_len, width = decodeVarInt(data[pos+36:pos+45])
        script_start = pos + 36 + width
        script_end = script_start + script_len
        sequence = unpack_from('<I', data, script_end)[0]
        script = data[script_start:script_end]
        if prev_hash == '\0'*32 and prev_index == 0xffffffff:
            # All-zero txid with index 0xffffffff marks a coinbase input.
            inputs.append({ 'coinbase': script, 'sequence': sequence })
        else:
            inputs.append({ 'txid': prev_hash, 'vout': prev_index, 'scriptSig': script, 'sequence': sequence })
        pos = script_end + 4
        remaining -= 1

    # Outputs: 8-byte value followed by a varint-prefixed script.
    out_count, width = decodeVarInt(data[pos:pos+9])
    pos += width
    index = 0
    while index < out_count:
        amount = unpack_from('<Q', data, pos)[0]
        script_len, width = decodeVarInt(data[pos+8:pos+17])
        script_start = pos + 8 + width
        pos = script_start + script_len
        outputs.append({ 'value': amount, 'n': index, 'scriptPubKey': decodeScriptPK(data[script_start:pos]) })
        index += 1

    tx['locktime'] = unpack_from('<I', data, pos)[0]
    size = pos + 4
    tx['size'] = size
    tx['txid'] = sha256(sha256(data[:size]).digest()).digest()
    return tx

def checkReOrg(cur, data):
    """Detect a chain re-organisation before inserting block `data`, and repair the db.

    Walks back from the new block's parent until a stored block whose hash
    and height both match the node's chain is found. If any stored blocks
    above that point are on a stale chain, their transactions are marked
    unconfirmed (block_id -1). Then, for each height up to but excluding the
    new block, the good-chain block is fetched over rpc, its transactions
    are re-linked or inserted, the stale block is recorded via addOrphan(),
    and the stored hash, coinbase and header are replaced.

    `data` is not modified. Nothing is done for height 0.
    """
    if data['height'] == 0:
        return
    rpc = AuthServiceProxy(cfg['rpc'])
    blkhash,height = data['previousblockhash'],data['height']-1
    while True:
        cur.execute("select id from blocks where hash=%s limit 1;", (blkhash,))
        row = cur.fetchone()
        if row and row[0] == height: # rewind until block in good chain
            break
        height -= 1
        blkhash = rpc.getblockhash(height).decode('hex')[::-1]
    height += 1
    if height < data['height']:
        cur.execute("update trxs set block_id=-1 where block_id >= %s;", (height*MAX_TX_BLK,)) # set bad chain txs uncfmd
        logts("Block %d *** ReOrg: %d orphan(s), %d txs affected" % (height, data['height']-height, cur.rowcount))
        while height < data['height']:
            blkhash = rpc.getblockhash(height)
            if blkhash:
                # Keep the replacement block in its own name. Rebinding `data`
                # here would change the loop bound above and the height
                # passed to puthdr.
                good = decodeBlock(rpc.getblock(blkhash, False).decode('hex')) # get good chain blocks
                if good:
                    for n,tx in enumerate(good['tx']):
                        tx_id,found = findTx(cur, tx['txid'], mkNew=True)
                        if found:
                            cur.execute("update trxs set block_id=%s where id=%s limit 1;", (height*MAX_TX_BLK+n, tx_id))
                        else:
                            insertTx(cur, tx, tx_id, height*MAX_TX_BLK + n) # occurs if tx wasn't in our mempool or orphan block
                    # Save the stale block before its row and header are overwritten.
                    addOrphan(cur, height)
                    cur.execute("update blocks set hash=%s,coinbase=%s where id=%s;", (good['hash'], good['coinbase'], height))
                    # Write the header at the loop height using the module-level cfg, as insertBlock does.
                    puthdr(height, good['hdr'], cfg['path'])
            height += 1
       
def insertTxMemPool(cur, tx, sync_id):
    """Store an unconfirmed transaction and record it in the mempool table under sync_id."""
    tx_id, known = findTx(cur, tx['txid'], mkNew=True)
    if not known:
        # -1 means trx has no block
        insertTx(cur, tx, tx_id, -1)
    cur.execute("insert ignore into mempool (id,sync_id) values(%s,%s);", (tx_id, sync_id))

def insertTx(cur, tx, tx_id, blk_id):
    """Insert a decoded transaction as row tx_id of trxs, attached to blk_id.

    Output rows and spent-output updates are queued on outQ for
    OutputHandler. The script data is packed into a blob via insertBlob().
    Inputs whose previous transaction is not found by findTx(), and any
    input or output with index >= MAX_IO_TX, are left out.
    """
    # Sequence numbers are only stored when at least one input is non-default.
    stdSeq = all(vin['sequence'] == 0xffffffff for vin in tx['vin'])
    inlist, outlist = [], []
    id_parts, data_parts = [], []

    for vin in tx['vin']:
        if 'coinbase' in vin:
            continue
        in_id = findTx(cur, vin['txid'])
        if not in_id or vin['vout'] >= MAX_IO_TX:
            continue
        in_id = (in_id*MAX_IO_TX) + vin['vout']
        inlist.append(( tx_id, in_id ))
        id_parts.append(pack('<Q', in_id)[:7])
        if not cfg['no-sigs']:
            data_parts.append(encodeVarInt(len(vin['scriptSig'])) + vin['scriptSig'])
        if not stdSeq:
            data_parts.append(pack('<I', vin['sequence']))

    for vout in tx['vout']:
        spk = vout['scriptPubKey']
        addr_id = insertAddress(cur, spk['addr']) if 'addr' in spk else 0
        if vout['n'] < MAX_IO_TX:
            outlist.append(( tx_id*MAX_IO_TX + vout['n'], addr_id, vout['value'] ))
            data_parts.append(encodeVarInt(len(spk['data'])) + spk['data'])

    outQ.put((outlist,inlist))
    ins,outs,sz,txhdr = mkBlobHdr(len(inlist), len(outlist), tx, stdSeq, cfg['no-sigs'])
    blob = insertBlob(txhdr + ''.join(id_parts) + ''.join(data_parts), cfg['path'])
    cur.execute("insert ignore into trxs (id,hash,ins,outs,txsize,txdata,block_id) values(%s,%s,%s,%s,%s,%s,%s)",
        ( tx_id, tx['txid'], ins, outs, sz, blob, blk_id ))
            
def insertBlock(cur, data):
    """Insert a decoded block: its transactions, its blocks row and its header.

    Runs checkReOrg() first. A transaction already in the db is attached to
    this block and removed from the mempool tracking; otherwise it is
    inserted in full. Finally the timing is logged and the 'avg-block-sync'
    info value is updated.
    """
    checkReOrg(cur, data)
    started = time.time()
    height = data['height']
    txs = data['tx']
    base_id = height*MAX_TX_BLK
    for n,tx in enumerate(txs):
        tx_id,found = findTx(cur, tx['txid'], mkNew=True)
        if found:
            key = tx['txid'][:8]
            if key in memPool:
                memPool.remove(key)
                cur.execute("delete from mempool where id=%s limit 1;", (tx_id,))
            cur.execute("update trxs set block_id=%s where id=%s limit 1;", (base_id + n, tx_id))
            if cur.rowcount > 0:
                continue
        # Unknown tx, or the update touched no row: insert it in full.
        insertTx(cur, tx, tx_id, base_id + n)

    cur.execute("insert ignore into blocks (id,hash,coinbase) values (%s,%s,%s);", (height, data['hash'], data['coinbase']))
    puthdr(height, data['hdr'], cfg['path'])

    elapsed = time.time() - started
    blkdate = datetime.fromtimestamp(data['time']).strftime('%d-%m-%Y')
    log("Block %d [ Q:%d %4d txs - %s - %3.0fms %2.1fs %3.0f tx/s]" % ( height, blockQ.qsize(),
        len(txs), blkdate, data['rpc']*1000, elapsed, len(txs)/elapsed) )

    blksecs.append(elapsed)
    if len(blksecs) > 18: # ~3 hour moving avg
        del blksecs[0]
    average = "%2.1f" % (sum(blksecs)/len(blksecs))
    cur.execute("replace into info (class,`key`,value) value('info','avg-block-sync',%s);", (average, ))

def options(cfg):
    """Parse command line options from sys.argv into cfg.

    Options are applied in the order given. --help and any parse error print
    the usage text and exit; --version exits with the version string;
    --defaults saves cfg as it stands at that point and exits.

    Note: the short option string also accepts "-w <arg>", which is parsed
    but has no effect.
    """
    try:
        opts,args = getopt.getopt(sys.argv[1:], "hvd:l:r:w:p:q:u:b:f:",
            ["help", "version", "debug", "db=", "log=", "rpc=", "path=", "queue=", "user=", "block=", "blkdat=", "no-sigs", "defaults" ])
    except getopt.GetoptError:
        usage()
    for opt,arg in opts:
        if opt in ("-h", "--help"):
            usage()
        elif opt in ("-v", "--version"):
            sys.exit(sys.argv[0]+': '+version)
        elif opt in ("-d", "--db"):
            cfg['db'] = arg
        elif opt in ("-l", "--log"):
            cfg['log'] = arg
        elif opt in ("-r", "--rpc"):
            cfg['rpc'] = arg
        elif opt in ("-p", "--path"):
            cfg['path'] = arg
        elif opt in ("-q", "--queue"):
            cfg['queue'] = int(arg)
        elif opt in ("-u", "--user"):
            cfg['user'] = arg
        # Long-only options are compared by equality: `opt in ("--x")` is a
        # substring test on a string, not tuple membership.
        elif opt == "--no-sigs":
            cfg['no-sigs'] = True
        elif opt == "--defaults":
            savecfg(cfg)
            sys.exit("%s updated" % (sys.argv[0]+'.cfg'))
        elif opt in ("-b", "--block"):
            cfg['block'] = int(arg)
        elif opt == "--debug":
            cfg['debug'] = True
        elif opt in ("-f", "--blkdat"):
            cfg['blkdat'] = arg
        
def usage():
    """Print the command line help text and exit with status 2."""
    text = """Usage: {0} [options...][cfg file]\nCommand options are:\n-h,--help\tShow this help info\n-v,--version\tShow version info
-b,--block\tStart at block number (instead of from last block done)
-f,--blkdat\tSet path to bitcoin data and use direct file access (no mempool/re-org)
--debug\t\tRun in foreground with logging to console
--defaults\tUpdate cfg and exit\nDefault files are {0}.cfg, {0}.log
\nThese options get saved in cfg file as defaults.
-p,--path\tSet path for blob and block header data file (/var/data)
-q,--queue\tSet block queue size (8)\n-u,--user\tSet user to run as\n-d,--db  \tSet mysql db connection, "host:user:pwd:dbname"
-l,--log\tSet log file path\n-r,--rpc\tSet rpc connection, "http://user:pwd@host:port"
--no-sigs\tDo not store input sigScript data """.format(sys.argv[0])
    print(text)
    sys.exit(2)
    
def sigterm_handler(_signo, _stack_frame):
    """Signal handler: set the `done` event so the worker loops stop."""
    done.set()
def sighup_handler(_signo, _stack_frame):
    """Signal handler: close stdout and stderr and reopen both, appending to the log file."""
    logpath = cfg.get('log', sys.argv[0]+'.log')
    for stream in ('stdout', 'stderr'):
        getattr(sys, stream).close()
        setattr(sys, stream, open(logpath,'a'))
    logts("SIGHUP Log reopened")
    
def run():
    """Create the queues and events, start the worker threads and fetch blocks until stopped.

    getBlocks() is retried every 5 seconds while it raises socket.error.
    When it returns, the block worker is stopped and joined, then the output
    worker is told to flush and joined. The session rate is logged if any
    blocks were advanced.
    """
    global blockQ,outQ,done,flush

    blockQ, outQ = Queue(), Queue(64)
    done, flush = threading.Event(), threading.Event()

    block_worker = threading.Thread(target = BlockHandler)
    block_worker.start()
    output_worker = threading.Thread(target = OutputHandler)
    output_worker.start()

    blksdone = None
    started = time.time()
    start_blk = cfg['block'] if 'block' in cfg else 0
    while blksdone is None:
        try:
            blksdone = getBlocks(start_blk)
        except socket.error:
            log("Cannot connect to rpc")
            time.sleep(5)

    # Stop the block worker first so no new output rows are queued, then drain the output worker.
    done.set()
    block_worker.join()
    flush.set()
    output_worker.join()
    if blksdone:
        log("Session %.2f blks/s" % float(blksdone / (time.time() - started)) )
    
if __name__ == '__main__':

    # Load saved settings, apply command line overrides, then call drop2user()
    # with the resulting config.
    loadcfg(cfg)
    options(cfg)
    drop2user(cfg)

    # 'debug' is not among the built-in defaults, so look it up tolerantly
    # instead of raising KeyError when it was never set.
    if cfg.get('debug'):
        # Foreground mode: SIGINT triggers the same shutdown as SIGTERM.
        signal.signal(signal.SIGINT, sigterm_handler)
        run()
    else:
        logpath = cfg.get('log', sys.argv[0]+'.log')
        pidpath = cfg.get('pid', sys.argv[0]+'.pid')
        with daemon.DaemonContext(working_directory='.', umask=0o002, stdout=open(logpath,'a'), stderr=open(logpath,'a'),
                signal_map={signal.SIGTERM:sigterm_handler, signal.SIGHUP:sighup_handler} ):
            with open(pidpath,'w') as f:
                f.write(str(os.getpid()))
            try:
                run()
            finally:
                # Remove the pid file even when run() exits with an exception.
                os.unlink(pidpath)
    

    


    

