#!/usr/bin/env python
#
#   sqlchaind - daemon to update sql blockchain db
#
import os, sys, socket, getopt, time, signal, threading
import MySQLdb as db

from Queue import Queue, Empty
from datetime import datetime
from bitcoinrpc.authproxy import AuthServiceProxy
from struct import pack, unpack, unpack_from
from hashlib import sha256

from sqlchain.version import *
from sqlchain.util import *

# Default settings; loadcfg() and options() are both called on this dict at startup.
# 'db' is split on ':' and passed to MySQLdb connect(); 'rpc' is handed to AuthServiceProxy.
# 'queue' is the blockQ size at which the block fetchers pause.
cfg = { 'log':sys.argv[0]+'.log', 'queue':8, 'no-sigs':False, 'db':'', 'rpc':'' } 
blksecs = []    # insert times (secs) of the most recent blocks, averaged by insertBlock
memPool = set() # 8 byte txid prefixes of mempool txs already seen by checkMemPool

def getBlocks(blk):
    """Fetch blocks from height blk onward and queue them on blockQ.

    blk -- first block height to fetch; 0 means continue after the highest
           block id already present in the blocks table.

    When cfg has a 'blkdat' entry the direct file reader is used first, then
    blocks are pulled over rpc. While caught up with the node the mempool is
    polled instead. Runs until the done event is set.

    Returns the number of heights advanced during this call. rpc connection
    failures are not handled here and propagate to the caller.
    """
    rpc = AuthServiceProxy(cfg['rpc'])
    sql = db.connect(*cfg['db'].split(':'))
    try:
        cur = sql.cursor()
        if blk == 0:
            cur.execute('select ifnull(max(id), -1) from blocks')
            blk = int(cur.fetchone()[0] + 1)
        startblk = blk
        if 'blkdat' in cfg:
            blk = getBlocksDirect(cur, blk) # use direct file access for catch up
        logts("Using rpc mode. Monitoring blocks / mempool.")
        while not done.isSet():
            blkinfo = rpc.getblockchaininfo()
            if blk > blkinfo['blocks']:
                # nothing new to fetch, so use the idle time to sync the mempool
                checkMemPool(cur, rpc)
                time.sleep(10)
                continue
            if blockQ.qsize() >= cfg['queue']:
                # consumer is behind; wait rather than buffer more blocks
                time.sleep(10)
                continue
            rpcstart = time.time()
            blkhash = rpc.getblockhash(blk)
            if blkhash:
                data = decodeBlock(rpc.getblock(blkhash, False).decode('hex'))
                data['height'] = blk
                data['rpc'] = time.time()-rpcstart
                blockQ.put(data)
                blk += 1
        return blk - startblk
    finally:
        # the caller re-invokes this function after rpc errors, so release
        # the connection instead of leaving one open per attempt
        sql.close()

def getBlocksDirect(cur, blk):
    """Queue blocks read straight from the node's blk*.dat files.

    cur -- db cursor used to look up block file positions in the blkdat table
    blk -- first block height to read

    Starts BlkDatHandler in a background thread, then for each height looks
    up (filenum, filepos), reads the block from disk and puts the decoded
    block on blockQ. Returns the next height not yet queued, either when the
    done event is set or when a file position does not point at a block
    magic marker.
    """
    from sqlchain.blkdat import BlkDatHandler, chkPruning
    scanner = threading.Thread(target = BlkDatHandler, args=(cfg,done))
    scanner.start()
    logts("Using blkdat mode: %s" % cfg['blkdat'])
    while not done.isSet():
        if blockQ.qsize() >= cfg['queue']:
            time.sleep(5)
            continue
        cur.execute("select filenum,filepos from blkdat where id=%s limit 1;", (blk,))
        entry = cur.fetchone()
        if not entry:
            # no position known for this height yet; wait and look again
            time.sleep(5)
            continue
        filenum,pos = entry
        chkPruning(cfg, int(filenum))
        t0 = time.time()
        path = cfg['blkdat'] + "/blocks/blk%05d.dat" % filenum
        with open(path, 'rb') as fd:
            fd.seek(pos)
            magic,blksize = unpack('<II', fd.read(8))
            if magic != 0xD9B4BEF9:
                logts("Error reading block %d - bad blkdat index. Aborting direct mode" % blk)
                return blk
            block = decodeBlock(fd.read(blksize))
            block['height'] = blk
            block['rpc'] = time.time()-t0
            blockQ.put(block)
        blk += 1
    log("Waiting for blkscan to finish")
    scanner.join()
    return blk
            
def BlockHandler():
    """Worker loop: take decoded blocks off blockQ and insert them in the db.

    Uses its own db connection. Waits up to 5 seconds for a block so the
    done event is re-checked regularly; returns once done is set.
    """
    conn = db.connect(*cfg['db'].split(':'))
    cur = conn.cursor()
    while not done.isSet():
        try:
            block = blockQ.get(True, 5)
        except Empty:
            continue
        insertBlock(cur, block)

def checkMemPool(cur, rpc):
    """Record node mempool txs that have not been seen before.

    cur -- db cursor
    rpc -- rpc proxy used for getrawmempool / getrawtransaction

    New txs are stored with a sync_id one above the current maximum in the
    mempool table. When the in-memory memPool set is empty the mempool table
    is cleared first.
    """
    cur.execute("select ifnull(max(sync_id),0) from mempool;")
    next_sync = cur.fetchone()[0] + 1
    if not memPool:
        cur.execute("delete from mempool;")
    for txhex in rpc.getrawmempool():
        # 8 bytes of the byte-reversed txid: uses 1/4 space, only for detecting changes in mempool
        key = txhex.decode('hex')[::-1][:8]
        if key in memPool:
            continue
        memPool.add(key)
        raw = rpc.getrawtransaction(txhex,0).decode('hex')
        insertTxMemPool(cur, decodeTx(raw), next_sync)

def addOrphan(cur, height):
    """Copy the block currently stored at height into the orphans table.

    cur    -- db cursor
    height -- block id of the block being replaced

    The orphan row gets the current maximum mempool sync_id, the stored
    hash and coinbase of the block, and its raw header as returned by gethdr.
    Nothing is inserted when no block row exists for height.
    """
    cur.execute("select ifnull(max(sync_id),0) from mempool;")
    sync_id = cur.fetchone()[0]
    hdr = gethdr(height, 'raw')
    cur.execute("select hash,coinbase from blocks where id=%s limit 1;", (height,))
    # fetch the row before issuing the insert: running another statement on
    # a cursor that is still being iterated throws away its result set
    row = cur.fetchone()
    if row is not None:
        blkhash,coinbase = row
        cur.execute("insert into orphans (sync_id,block_id,hash,hdr,coinbase) values(%s,%s,%s,%s,%s);", (sync_id,height,blkhash,hdr,coinbase))
    
def decodeBlock(data):
    """Decode a raw serialized block into a dict.

    data -- raw block bytes: 80 byte header, tx count varint, then the txs.

    Returns a dict holding the unpacked header fields ('version',
    'previousblockhash', 'merkleroot', 'time', 'nonce' and 'bits' as 8 hex
    digits), the raw 'hdr', the double sha256 'hash' of the header, the list
    of decoded 'tx', the 'coinbase' script of the first tx and 'height'.

    'height' is read from the coinbase script when the block version is
    above 1 and the script starts with a 3 byte push holding a value of at
    least 227836; otherwise it is 0. Callers that already know the height
    overwrite it.
    """
    hdr = ['version','previousblockhash','merkleroot', 'time', 'bits', 'nonce']
    hv = unpack_from('<I32s32s3I', data)
    block = dict(zip(hdr,hv))
    block['hdr'] = data[:80]
    block['hash'] = sha256(sha256(block['hdr']).digest()).digest()
    block['bits'] = '%08x' % block['bits']
    txcnt,off = decodeVarInt(data[80:89])
    off += 80
    block['tx'] = []
    while txcnt > 0:
        tx = decodeTx(data[off:])
        block['tx'].append(tx)
        off += tx['size']
        txcnt -= 1
    block['height'] = 0
    coinbase = block['coinbase'] = block['tx'][0]['vin'][0]['coinbase']
    # The height check used to test the placeholder 0 set above, so this
    # branch could never run. Decode the candidate first, then apply the
    # threshold to the decoded value.
    if block['version'] > 1 and len(coinbase) >= 4 and coinbase[:1] == b'\x03':
        height = unpack('<I', coinbase[1:4] + b'\0')[0]
        # presumably the first height from which coinbase heights are trusted (BIP34) - TODO confirm
        if height >= 227836:
            block['height'] = height
    return block

def decodeTx(data):
    """Decode one raw transaction from the start of data.

    data -- raw bytes beginning with a serialized tx; trailing bytes are ignored.

    Returns a dict with 'version', 'vin', 'vout', 'locktime', 'size' (number
    of bytes consumed) and 'txid' (double sha256 of those bytes). An input
    whose previous txid is all zero bytes with index 0xffffffff is stored as
    {'coinbase', 'sequence'}; any other input as {'txid', 'vout',
    'scriptSig', 'sequence'}. Outputs hold 'value', 'n' and the decoded
    'scriptPubKey'.

    NOTE(review): the input count is read directly at offset 4, so a
    serialization with an extra marker/flag after the version would not be
    parsed as intended - confirm callers never pass one.
    """
    vers, = unpack_from('<I', data)
    tx = { 'version': vers, 'vin':[], 'vout':[] }
    pos = 4     # single running cursor into data

    incnt,used = decodeVarInt(data[pos:pos+9])
    pos += used
    vins = tx['vin']
    while len(vins) < incnt:
        prevtx,previdx = unpack_from('<32sI', data, pos)
        pos += 36
        scriptlen,used = decodeVarInt(data[pos:pos+9])
        pos += used
        script = data[pos:pos+scriptlen]
        pos += scriptlen
        seq, = unpack_from('<I', data, pos)
        pos += 4
        if prevtx == '\0'*32 and previdx == 0xffffffff:
            vins.append({'coinbase':script, 'sequence':seq })
        else:
            vins.append({'txid':prevtx, 'vout':previdx, 'scriptSig':script, 'sequence':seq })

    outcnt,used = decodeVarInt(data[pos:pos+9])
    pos += used
    vouts = tx['vout']
    while len(vouts) < outcnt:
        amount, = unpack_from('<Q', data, pos)
        pos += 8
        pklen,used = decodeVarInt(data[pos:pos+9])
        pos += used
        # len(vouts) is the index of the output about to be appended
        vouts.append({'value':amount, 'n':len(vouts), 'scriptPubKey':decodeScriptPK( data[pos:pos+pklen] ) })
        pos += pklen

    tx['locktime'], = unpack_from('<I', data, pos)
    tx['size'] = pos+4
    tx['txid'] = sha256(sha256(data[:tx['size']]).digest()).digest()
    return tx

def checkReOrg(cur, data):
    """Detect a chain re-org ahead of inserting block data and repair the db.

    cur  -- db cursor
    data -- decoded block about to be inserted; 'previousblockhash' and
            'height' are read.

    Walks back from the new block's parent until a stored block matches the
    node's hash at the same height. Every height above that point and below
    the new block is then replaced: its txs are marked unconfirmed, the old
    block is copied to orphans, and the node's block at that height is
    stored in its place.

    NOTE(review): the walk back has no lower bound; if no stored block
    matches, getblockhash ends up being called with a negative height.
    """
    rpc = AuthServiceProxy(cfg['rpc'])
    tip = data['height']
    blkhash,height = data['previousblockhash'],tip-1
    while True:
        cur.execute("select id from blocks where hash=%s limit 1;", (blkhash,))
        row = cur.fetchone()
        if row and row[0] == height: # rewind until block in good chain
            break
        height -= 1
        blkhash = rpc.getblockhash(height).decode('hex')[::-1]
    height += 1
    if height >= tip:
        return # parent is already our best block, nothing to repair

    cur.execute("update trxs set block_id=-1 where block_id >= %s;", (height*MAX_TX_BLK,)) # set bad chain txs uncfmd
    logts("Block %d *** ReOrg: %d orphan(s), %d txs affected" % (height, tip-height, cur.rowcount))
    while height < tip:
        blkhash = rpc.getblockhash(height)
        if blkhash:
            # Keep the replacement block in its own name. Rebinding `data`
            # here changed the loop bound to the decoded block's height and
            # ended the repair after the first block.
            good = decodeBlock(rpc.getblock(blkhash, False).decode('hex')) # get good chain blocks
            if good:
                ins,outs = [],[]
                for n,tx in enumerate(good['tx']):
                    tx_id = findTx(cur, tx['txid'], mkNew=True)
                    cur.execute("update trxs set block_id=%s where id=%s limit 1;", (height*MAX_TX_BLK+n, tx_id))
                    if cur.rowcount < 1:
                        # occurs if tx wasn't in our mempool or orphan block
                        xi,xo = insertTx(cur, tx, tx_id, height*MAX_TX_BLK + n)
                        ins.extend(xi)
                        outs.extend(xo)
                # record outputs / spends of newly inserted txs, same as insertBlock does
                cur.executemany("insert ignore into outputs (id,addr_id,value) values(%s,%s,%s)", outs)
                cur.executemany("update outputs set tx_id=%s where id=%s limit 1", ins)
                addOrphan(cur, height) # must run before the blocks row is overwritten
                cur.execute("update blocks set hash=%s,coinbase=%s where id=%s;", (good['hash'], good['coinbase'], height))
                # use the loop height; the decoded block does not reliably carry its own
                puthdr(height, good['hdr'])
        height += 1
       
def insertTxMemPool(cur, tx, sync_id):
    """Store an unconfirmed tx and list it in the mempool table.

    cur     -- db cursor
    tx      -- decoded tx dict
    sync_id -- sync id to record with the mempool row

    The tx is inserted with block id -1, its outputs are added and the
    outputs it spends are pointed at it. Nothing is stored when findTx
    yields no id for the tx.
    """
    tx_id = findTx(cur, tx['txid'], mkNew=True)
    if tx_id is None:
        # no usable id: do not add a mempool row with a NULL id either
        return
    ins,outs = insertTx(cur, tx, tx_id, -1) # -1 means trx has no block
    cur.executemany("insert ignore into outputs (id,addr_id,value) values(%s,%s,%s)", outs)
    cur.executemany("update outputs set tx_id=%s where id=%s limit 1", ins)
    cur.execute("insert ignore into mempool (id,sync_id) values(%s,%s);", (tx_id, sync_id))

def insertTx(cur, tx, tx_id, blk_id):
    """Insert one tx row and build its input / output lists.

    cur    -- db cursor
    tx     -- decoded tx dict
    tx_id  -- db id for this tx
    blk_id -- value for the block_id column

    Returns (inlist, outlist): inlist holds (tx_id, spent output id) pairs,
    outlist holds (output id, addr_id, value) tuples. Writing them to the
    outputs table is left to the caller. Inputs whose previous tx is not
    found, and inputs / outputs with an index of MAX_IO_TX or more, are
    left out.
    """
    # sequence numbers are only stored when at least one input is non-default
    stdSeq = all(vin['sequence'] == 0xffffffff for vin in tx['vin'])
    inlist,outlist = [],[]
    id_parts,blob_parts = [],[]

    for vin in tx['vin']:
        if 'coinbase' in vin:
            continue
        in_id = findTx(cur, vin['txid'])
        if not in_id or vin['vout'] >= MAX_IO_TX:
            continue
        in_id = (in_id*MAX_IO_TX) + vin['vout']
        inlist.append(( tx_id, in_id ))
        id_parts.append(pack('<Q', in_id)[:7]) # low 7 bytes of the output id
        if not cfg['no-sigs']:
            blob_parts.append(encodeVarInt(len(vin['scriptSig'])) + vin['scriptSig'])
        if not stdSeq:
            blob_parts.append(pack('<I', vin['sequence']))

    for vout in tx['vout']:
        spk = vout['scriptPubKey']
        addr_id = insertAddress(cur, spk['addr']) if 'addr' in spk else 0
        if vout['n'] < MAX_IO_TX:
            outlist.append(( tx_id*MAX_IO_TX + vout['n'], addr_id, vout['value'] ))
            blob_parts.append(encodeVarInt(len(spk['data'])) + spk['data'])

    ins,outs,sz,txhdr = mkBlobHdr(len(inlist), len(outlist), tx, stdSeq, cfg['no-sigs'])
    blob = txhdr + ''.join(id_parts) + ''.join(blob_parts)
    cur.execute("insert ignore into trxs (id,hash,ins,outs,txsize,txdata,block_id) values(%s,%s,%s,%s,%s,%s,%s)",
        ( tx_id, tx['txid'], ins, outs, sz, insertBlob(blob), blk_id ))
    return inlist,outlist
            
def insertBlock(cur, data):
    """Insert a decoded block and its txs into the db.

    cur  -- db cursor
    data -- decoded block dict with 'height' and 'rpc' (fetch time) filled in

    Runs the re-org check first. A tx that already has a trxs row is only
    assigned its block position; otherwise the tx is inserted. Txs known
    from the mempool are removed from memPool and the mempool table. The
    block row and header are then stored, a progress line is logged and the
    moving average of insert times in the info table is updated.
    """
    checkReOrg(cur, data)
    blkstart = time.time()
    blk_id = data['height']*MAX_TX_BLK
    ins,outs = [],[]
    for n,tx in enumerate(data['tx']):
        tx_id = findTx(cur, tx['txid'], mkNew=True)
        if tx['txid'][:8] in memPool:
            memPool.remove(tx['txid'][:8])
            cur.execute("delete from mempool where id=%s limit 1;", (tx_id,))
        cur.execute("update trxs set block_id=%s where id=%s limit 1;", (blk_id + n, tx_id))
        if cur.rowcount > 0:
            continue # tx row existed and is now confirmed in place
        xi,xo = insertTx(cur, tx, tx_id, blk_id + n)
        ins.extend(xi)
        outs.extend(xo)
    cur.executemany("insert ignore into outputs (id,addr_id,value) values(%s,%s,%s)", outs)
    cur.executemany("update outputs set tx_id=%s where id=%s limit 1", ins)
    cur.execute("insert ignore into blocks (id,hash,coinbase) values (%s,%s,%s);", (data['height'], data['hash'], data['coinbase']))

    puthdr(data['height'], data['hdr'])

    blktime = time.time() - blkstart
    # a coarse clock can report zero elapsed time; avoid dividing by it
    txrate = len(data['tx'])/blktime if blktime > 0 else 0
    log("Block %d [ Q:%d %4d txs - %s - %3.0fms %2.1fs %3.0f tx/s]" % ( data['height'], blockQ.qsize(),
        len(data['tx']), datetime.fromtimestamp(data['time']).strftime('%d-%m-%Y'), data['rpc']*1000, blktime, txrate) )

    blksecs.append(blktime)
    if len(blksecs) > 18: # ~3 hour moving avg
        del blksecs[0]
    cur.execute("replace into info (class,`key`,value) value('info','avg-block-sync',%s);", ("%2.1f"%(sum(blksecs)/len(blksecs)), ))

def options(cfg):
    """Apply command line options from sys.argv to cfg.

    cfg -- settings dict, updated in place.

    Exits via usage() on -h/--help or on an option getopt rejects, exits
    after printing the version on -v/--version, and exits after saving cfg
    on --defaults. Options are applied in command line order, so --defaults
    saves only what precedes it.
    """
    try:
        opts,args = getopt.getopt(sys.argv[1:], "hvd:l:r:w:p:q:u:b:f:",
            ["help", "version", "debug", "db=", "log=", "rpc=", "prunegap=", "queue=", "user=", "block=", "blkdat=", "no-sigs", "defaults" ])
    except getopt.GetoptError:
        usage()
    for opt,arg in opts:
        if opt in ("-h", "--help"):
            usage()
        elif opt in ("-v", "--version"):
            sys.exit(sys.argv[0]+': '+version)
        elif opt in ("-d", "--db"):
            cfg['db'] = arg
        elif opt in ("-l", "--log"):
            cfg['log'] = arg
        elif opt in ("-r", "--rpc"):
            cfg['rpc'] = arg
        elif opt in ("-p", "--prunegap"):
            cfg['prunegap'] = int(arg)
        elif opt in ("-q", "--queue"):
            cfg['queue'] = int(arg)
        elif opt in ("-u", "--user"):
            cfg['user'] = arg
        # Long-only options are compared with ==. Writing `opt in ("--x")`
        # tests against a plain string, which is a substring match.
        elif opt == "--no-sigs":
            cfg['no-sigs'] = True
        elif opt == "--defaults":
            savecfg(cfg)
            sys.exit("%s updated" % (sys.argv[0]+'.cfg'))
        elif opt in ("-b", "--block"):
            cfg['block'] = int(arg)
        elif opt == "--debug":
            cfg['debug'] = True
        elif opt in ("-f", "--blkdat"):
            cfg['blkdat'] = arg
        
def usage():
    """Print the command line help text and exit with status 2."""
    helptext = """Usage: {0} [options...][cfg file]\nCommand options are:\n-h,--help\tShow this help info\n-v,--version\tShow version info
-b,--block\tStart at block number (instead of from last block done)
-f,--blkdat\tSet path to bitcoin data and use direct file access (no mempool/re-org)
--debug\t\tRun in foreground with logging to console
--defaults\tUpdate cfg and exit\nDefault files are {0}.cfg, {0}.log
\nThese options get saved in cfg file as defaults.
-p,--prunegap\tSet min number blocks pruneable before stoping bitcoind (10)
-q,--queue\tSet block queue size (8)\n-u,--user\tSet user to run as\n-d,--db  \tSet mysql db connection, "host:user:pwd:dbname"
-l,--log\tSet log file path\n-r,--rpc\tSet rpc connection, "http://user:pwd@host:port"
--no-sigs\tDo not store input sigScript data """
    print(helptext.format(sys.argv[0]))
    sys.exit(2)
    
def sigterm_handler(_signo, _stack_frame):
    """Signal handler: request shutdown and remove the pid file.

    Sets the done event so the worker loops exit. Outside debug mode the
    pid file (cfg['pid'] if present, else <argv0>.pid) is deleted.
    """
    import errno
    done.set()
    if not cfg['debug']:
        pidfile = cfg['pid'] if 'pid' in cfg else sys.argv[0]+'.pid'
        try:
            os.unlink(pidfile)
        except OSError as e:
            # a repeated signal finds the file already gone; anything else is a real error
            if e.errno != errno.ENOENT:
                raise
    
def run():
    """Run the daemon main loop until shutdown is requested.

    Creates the shared block queue and done event, starts the BlockHandler
    worker thread and calls getBlocks, retrying every 5 seconds while the
    rpc connection fails with socket.error. On the way out the worker is
    always told to stop and joined, and the session rate is logged if any
    blocks were processed.
    """
    global blockQ,done

    blockQ = Queue()
    done = threading.Event()

    blkwrk = threading.Thread(target = BlockHandler)
    blkwrk.start()

    blksdone = None
    workstart = time.time()
    try:
        while True:
            try:
                blksdone = getBlocks(cfg['block'] if 'block' in cfg else 0)
                break
            except socket.error:
                log("Cannot connect to rpc")
                time.sleep(5)
    finally:
        # stop the worker on any exit path; an unexpected error used to
        # leave the non-daemon thread running and the process hanging
        done.set()
        blkwrk.join()
    if blksdone:
        log("Session %.2f blks/s" % float(blksdone / (time.time() - workstart)) )
    
if __name__ == '__main__':

    # settings: cfg file first, then command line overrides, then drop privileges
    loadcfg(cfg)
    options(cfg)
    drop2user(cfg)

    if not cfg['debug']:
        # daemon mode: stdout/stderr are appended to <argv0>.log, SIGTERM requests shutdown
        logpath = sys.argv[0]+'.log'
        pidpath = cfg['pid'] if 'pid' in cfg else sys.argv[0]+'.pid'
        context = daemon.DaemonContext(working_directory='.', stdout=open(logpath,'a'), stderr=open(logpath,'a'),
                signal_map={signal.SIGTERM:sigterm_handler } )
        with context:
            with open(pidpath,'w') as f:
                f.write(str(os.getpid()))
            run()
    else:
        # debug mode: stay in the foreground, Ctrl-C requests shutdown
        signal.signal(signal.SIGINT, sigterm_handler)
        run()

    

    


    

