#!/usr/bin/env python2.7
"""A script for sending files to VOSpace via multiple connection streams."""


import sys
import threading, Queue
import datetime, optparse
import errno
import logging
from vos import vos, __version__

# Usage text handed to optparse.OptionParser in the __main__ block (printed by
# parser.print_help() and on parser.error()); %s is filled with the vos version.
usage = """
  vsync [options] files vos:Destination/
     
      Version: %s """ % (__version__.version)


def computeMD5(filename, block_size=8192):
    """Return the hexadecimal MD5 digest of the local file ``filename``.

    The file is read in ``block_size``-byte chunks so large files are never
    held in memory all at once.
    """
    import hashlib
    md5 = hashlib.md5()
    # Binary mode: text mode may translate line endings on some platforms
    # (and yields str instead of bytes on Python 3), corrupting the digest.
    # The with-block also closes the file if a read fails part way through.
    with open(filename, 'rb') as r:
        for buf in iter(lambda: r.read(block_size), b''):
            md5.update(buf)
    return md5.hexdigest()

def fileMD5(filename):
    """Return the hex MD5 of a local file, reusing the node cache when enabled.

    With --cache_nodes, the cached entry (md5, size, mtime) for ``filename``
    is reused unless the file's current mtime is newer than the cached one;
    otherwise the checksum is recomputed with computeMD5 and, when caching,
    stored back together with the file's size and mtime.
    """
    import os
    # Stat once, *before* hashing: if the file changes while it is being
    # read, its mtime will then be newer than the one recorded here, so the
    # next lookup recomputes instead of trusting a checksum of stale content.
    stat = os.stat(filename)
    cached = None
    if opt.cache_nodes:
        cached = md5Cache.get(filename)
    if cached is not None and cached[2] >= stat.st_mtime:
        return cached[0]
    md5 = computeMD5(filename)
    if opt.cache_nodes:
        md5Cache.update(filename, md5, stat.st_size, stat.st_mtime)
    return md5

class ThreadCopy(threading.Thread):
    """Worker thread that uploads (src, dest) pairs taken from a shared queue.

    Each worker owns its own vos.Client and keeps per-thread counters
    (filesSent, bytesSent, filesSkipped, bytesSkipped, filesErrored) that the
    main program sums into its transfer report.  Unless --overwrite is given,
    a file is skipped when the destination already matches it: by MD5, or
    with --ignore-checksum by equal size and a destination ctime no older
    than the local mtime.

    Relies on the module-level globals opt, logger, nodeDict, fileMD5 and
    (with --cache_nodes) md5Cache.
    """

    # MD5 of zero bytes; assumed when the destination node has no 'MD5' property.
    EMPTY_MD5 = 'd41d8cd98f00b204e9800998ecf8427e'

    def __init__(self, queue):
        super(ThreadCopy, self).__init__()

        try:
            # Use the same certificate as the main client so that --certfile
            # is honoured by the connections that actually move the data.
            self.client = vos.Client(certFile=opt.certfile)
        except Exception as e:
            # Not every exception has an errno, and a None errno would make
            # sys.exit report success; fall back to a generic failure code.
            sys.exit(getattr(e, 'errno', None) or 1)
        self.queue = queue
        self.filesSent = 0
        self.filesSkipped = 0
        self.bytesSent = 0
        self.bytesSkipped = 0
        self.filesErrored = 0

    def run(self):
        """Process queue items forever, acknowledging each one exactly once."""
        while True:
            (src, dest) = self.queue.get()
            try:
                self._transfer(src, dest)
            except Exception as e:
                # An uncaught error would kill this worker without calling
                # task_done(), leaving the main thread blocked in queue.join().
                logger.error("Error processing %s -> %s, skipping" % (src, dest))
                logger.error(str(e))
                self.filesErrored += 1
            finally:
                self.queue.task_done()

    def _transfer(self, src, dest):
        """Upload src to dest unless the destination already matches it."""
        stat = os.stat(src)
        srcMD5 = None
        if not opt.ignore_checksum and not opt.overwrite:
            srcMD5 = fileMD5(src)
        logger.debug("Source MD5: %s " % (srcMD5))
        if not opt.overwrite and self._destinationMatches(src, dest, srcMD5, stat):
            return
        self._send(src, dest, stat)

    def _destinationMatches(self, src, dest, srcMD5, stat):
        """Return True (and record a skip) if dest already matches src."""
        try:
            destMD5, destLength, destTime = self._destinationInfo(dest)
        except (IOError, OSError) as e:
            # Destination could not be looked up (e.g. it does not exist
            # yet); fall through and copy.
            logger.debug(str(e))
            return False
        logger.debug("Dest MD5: %s " % (destMD5))
        if opt.ignore_checksum:
            same = destTime >= stat.st_mtime and destLength == stat.st_size
        else:
            same = srcMD5 == destMD5
        if same:
            logger.info("skipping: %s  matches %s" % (src, dest))
            self.filesSkipped += 1
            self.bytesSkipped += destLength
        return same

    def _destinationInfo(self, dest):
        """Return (md5, size, ctime) for dest, from the node cache if possible."""
        nodeInfo = None
        if opt.cache_nodes:
            nodeInfo = md5Cache.get(dest)
        if nodeInfo is not None:
            return nodeInfo[0], nodeInfo[1], nodeInfo[2]
        logger.debug("Getting node info from VOSpace")
        logger.debug(str(nodeDict.keys()))
        logger.debug(str(dest))
        return self._refreshNodeInfo(dest)

    def _refreshNodeInfo(self, dest):
        """Fetch (md5, size, ctime) for dest from VOSpace and update the cache."""
        node = self.client.getNode(dest)
        destMD5 = node.props.get('MD5', self.EMPTY_MD5)
        destLength = node.attr['st_size']
        destTime = node.attr['st_ctime']
        if opt.cache_nodes:
            md5Cache.update(dest, destMD5, destLength, destTime)
        return destMD5, destLength, destTime

    def _send(self, src, dest, stat):
        """Copy src to dest and update counters; requeue on connection reset."""
        logger.info("%s -> %s" % (src, dest))
        try:
            self.client.copy(src, dest, sendMD5=True)
            self._refreshNodeInfo(dest)
        except (IOError, OSError) as e:
            logger.error("Error writing %s to server, skipping" % (src))
            logger.error(str(e))
            if 'NodeLocked' in str(e):
                logger.error("Use vlock to unlock the node before syncing to %s." % (dest))
            if getattr(e, 'errno', None) == errno.ECONNRESET and self._requeue(src, dest):
                # A retry is pending; the file only counts as failed if that
                # retry fails too.
                return
            self.filesErrored += 1
            return
        self.filesSent += 1
        self.bytesSent += stat.st_size

    def _requeue(self, src, dest):
        """Put (src, dest) back on the queue for another attempt; True on success."""
        try:
            # NOTE(review): blocks while the (bounded) queue is full.
            self.queue.put((src, dest))
        except Exception as e2:
            logger.error("Error during requeue")
            logger.error(str(e2))
            return False
        return True


def mkdirs(dirs):
    """Make sure the VOSpace container ``dirs`` exists, creating it if needed.

    Paths already confirmed (created here, or found to exist already) are
    remembered in the module-level ``goodDirs`` list, so repeated calls for
    the same path make no further server requests.  Only ``dirs`` itself is
    passed to ``c.mkdir``; parent containers are not created here.

    Errors from ``c.mkdir`` other than an IOError with errno EEXIST propagate.
    """
    # Lazy %-args: goodDirs can grow large, so only format it when debugging.
    logger.debug("%s %s", dirs, goodDirs)
    if dirs in goodDirs:
        return

    try:
        c.mkdir(dirs)
        logger.info("Made directory %s " % (dirs))
    except IOError as e:
        if e.errno != errno.EEXIST:
            # Bare raise keeps the original traceback (raise e resets it on Python 2).
            raise

    # Either just created or already present: no need to ask again.
    goodDirs.append(dirs)


import os
def copy(source, dest):
    """Validate a local file and queue it for upload to the VOSpace path ``dest``.

    ``dest`` includes the destination filename.  The following are skipped:
    symbolic links, unreadable files, and paths containing characters outside
    the allowed set (each with an error message), plus files not matching
    --include when that option is given (silently).  Otherwise the
    destination's parent container is ensured via mkdirs and (source, dest)
    is put on the shared transfer queue.
    """
    import re
    if os.path.islink(source):
        logger.error("%s is a link, skipping" % (source))
        return
    if not os.access(source, os.R_OK):
        logger.error("Failed to open file %s, skipping" % (source))
        return
    # \Z rather than $: "$" also matches just before a trailing newline, which
    # would let a path ending in a newline slip past the character check.
    if re.match(r'^[A-Za-z0-9\._\-\(\);:&\*\$@!+=\/]*\Z', source) is None:
        logger.error("filename %s contains illegal characters, skipping" % (source))
        return
    # Apply the local --include filter before touching the server, so files
    # that will not be sent do not trigger remote container creation.
    if opt.include is not None and not re.search(opt.include, source):
        return

    mkdirs(os.path.dirname(dest))
    queue.put((source, dest), timeout=3600)

def startStreams(nstreams):
    """Launch ``nstreams`` daemon ThreadCopy workers on the shared queue.

    Returns the started threads, in launch order, so their counters can be
    summed once the queue has drained.
    """
    def _launch(index):
        # Each ThreadCopy opens its own vos.Client, i.e. one connection stream.
        logger.info("Launching vospace connection stream %d" % (index))
        worker = ThreadCopy(queue)
        worker.daemon = True
        worker.start()
        return worker

    return [_launch(index) for index in range(nstreams)]


def _excluded(path, patterns):
    """Return True (logging it once) if any exclude substring occurs in path."""
    for pattern in patterns:
        if pattern in path:
            logger.info ("excluding: %s " % (path))
            return True
    return False


def buildFileList(basePath, destRoot='', recursive=False, ignore=None):
    """Walk basePath, mirror its directories in VOSpace and hand files to copy().

    basePath: local directory to scan.
    destRoot: VOSpace path that basePath maps onto.
    recursive: descend into subdirectories; otherwise only the files directly
        inside basePath are considered.
    ignore: comma-separated substrings; any file or directory whose path
        contains one of them is skipped (empty entries are ignored).
    """
    spinner = ['-', '\\', '|', '/', '-', '\\', '|', '/']
    # Hoisted out of the walk.  Empty entries (e.g. from a trailing comma)
    # are dropped because "" occurs in every path and would exclude it all.
    patterns = [p for p in ignore.split(',') if p] if ignore is not None else []
    count = 0
    for (root, dirs, filenames) in os.walk(basePath):
        if recursive:
            kept = []
            for thisDirname in dirs:
                fullDirname = os.path.join(root, thisDirname)
                if _excluded(fullDirname, patterns):
                    continue
                kept.append(thisDirname)
                cprefix = os.path.commonprefix((basePath, fullDirname))
                mkdirs(os.path.normpath(destRoot + "/" + fullDirname[len(cprefix):]))
            # Prune excluded directories in place so os.walk does not descend
            # into them: every path below contains the excluded directory's
            # path, so all of it would be excluded anyway.
            dirs[:] = kept
        for thisfilename in filenames:
            srcfilename = os.path.normpath(os.path.join(root, thisfilename))
            if _excluded(srcfilename, patterns):
                continue
            cprefix = os.path.commonprefix((basePath, srcfilename))
            destfilename = os.path.normpath(destRoot + "/" + srcfilename[len(cprefix):])
            mkdirs(os.path.dirname(destfilename))

            count += 1
            if opt.verbose:
                sys.stderr.write("Building list of files to transfer %s\r" % (spinner[count % len(spinner)]))
            copy(srcfilename, destfilename)
        if not recursive:
            # Only the top level of basePath was requested.
            return
    return



def signal_handler(signal, frame):
    """SIGINT handler: log the interruption at CRITICAL level and exit with status -1."""
    logger.critical("Interrupt\n")
    sys.exit(-1)



if __name__ == '__main__':
    import time

    # handle interrupts nicely

    import signal
    signal.signal(signal.SIGINT, signal_handler)

    startTime = time.time()

    parser = optparse.OptionParser(usage)
    parser.add_option('--verbose', '-v', action="store_true", help='run in verbose mode')
    parser.add_option('--debug', '-d', action="store_true", help='run in debug mode')
    parser.add_option('--warning', action="store_true")
    parser.add_option('--ignore-checksum', action="store_true", help='dont check MD5 sum, use size and time instead')
    parser.add_option('--cache_nodes', action='store_true', help='cache node MD5 sum in an sqlite db')
    parser.add_option('--recursive', '-r', help="Do a recursive sync", action="store_true")
    parser.add_option('--nstreams', '-n', type=int, help="Number of streams to run (MAX: 30)", default=1)
    parser.add_option('--exclude', help="ignore directories or files containing any of these comma-separated patterns", default=None)
    parser.add_option('--include', help="only include files matching this pattern", default=None)
    parser.add_option('--overwrite', help="overwrite copy on server regardless of modification/size/md5 checks", action="store_true")
    parser.add_option('--load_test', action="store_true", help="Used to stress test the VOServer, also set --nstreams to a large value")
    parser.add_option("--certfile", help="location of your CADC security certificate file", default=os.path.join(os.getenv("HOME", "."), ".ssl/cadcproxy.pem"))
    parser.add_option("--version", action="store_true", help="VOS Version %s" % (__version__.version))

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit()

    (opt, args) = parser.parse_args()

    if opt.version:
        sys.stdout.write("vsync version %s \n\n" % (__version__.version))
        sys.exit(0)

    # Most verbose requested level wins, so "-v -d" really produces debug output.
    if opt.debug:
        log_level = logging.DEBUG
    elif opt.verbose:
        log_level = logging.INFO
    elif opt.warning:
        log_level = logging.WARNING
    else:
        log_level = logging.ERROR

    logging.getLogger('vos').setLevel(log_level)
    logging.getLogger('vos').addHandler(logging.StreamHandler())
    logger = logging.getLogger('vos')

    if len(args) < 2:
        parser.error("requires one or more source files and a single destination directory")

    # With no workers nothing drains the queue, so queue.join() would block forever.
    if opt.nstreams < 1:
        parser.error("At least 1 stream is required")

    if opt.nstreams > 30 and not opt.load_test:
        parser.error("Maximum of 30 streams exceeded")

    if opt.cache_nodes:
        from vos import md5_cache
        md5Cache = md5_cache.MD5_Cache()


    dest = args.pop()
    if dest[0:4] != "vos:":
        parser.error("Only allows sync FROM local copy TO VOSpace")
    ## Currently we don't create nodes in sync and we don't sync onto files
    logger.info("Connecting to VOSpace")
    c = vos.Client(certFile=opt.certfile)
    logger.info("Confirming Destination is a directory")
    destIsDir = c.isdir(dest)

    # Bounded so file discovery cannot run arbitrarily far ahead of the uploads.
    queue = Queue.Queue(maxsize=10 * opt.nstreams)
    goodDirs = []
    nodeDict = {}

    streams = startStreams(opt.nstreams)


    ### build a complete file list given all the things on the command line
    for filename in args:
        filename = os.path.abspath(filename)
        thisRoot = dest
        if os.path.isdir(filename):
            # NOTE(review): abspath() strips any trailing "/", so this test is
            # true for every directory except "/" itself; an rsync-style
            # "dir/" (sync contents only) is therefore not honoured.  Left
            # unchanged because fixing it would move where existing
            # invocations put their files.
            if filename[-1] != "/" :
                if os.path.basename(filename) != os.path.basename(dest):
                    thisRoot = os.path.join(dest, os.path.basename(filename))
            mkdirs(thisRoot)
            nodeDict[thisRoot] = c.getNode(thisRoot)
            try:
                buildFileList(filename, destRoot=thisRoot, recursive=opt.recursive, ignore=opt.exclude)
            except Exception as e:
                logger.error(str(e))
                logger.error("ignoring error")
        elif os.path.isfile(filename):
            if destIsDir:
                thisRoot = os.path.join(dest, os.path.basename(filename))
            copy(filename, thisRoot)
        else:
            logger.error("%s: No such file or directory." % (filename))


    logger.info("\nWaiting for transfers to complete.\nCTRL-\ to interrupt\n")

    queue.join()
    endTime = time.time()
    bytesSent = 0
    filesSent = 0
    bytesSkipped = 0
    filesSkipped = 0
    filesErrored = 0
    for stream in streams:
        bytesSent += stream.bytesSent
        bytesSkipped += stream.bytesSkipped
        filesSent += stream.filesSent
        filesSkipped += stream.filesSkipped
        filesErrored += stream.filesErrored

    logger.info("\n\n==== TRANSFER REPORT ====\n\n")

    if bytesSent > 0:
        rate = bytesSent / (endTime - startTime) / 1024.0
        logger.info("Sent %d files (%8.1f kbytes @ %8.3f kBytes/s)" % (filesSent, bytesSent / 1024.0, rate))
        # float() avoids Python 2 integer division truncating the ratio.
        speedUp = float(bytesSkipped + bytesSent) / bytesSent
        logger.info("Speedup:  %f (skipped %d files)" % (speedUp, filesSkipped))
    else:
        logger.info("No files needed sending ")

    if filesErrored > 0:
        # Error level so failures are reported even without --verbose.
        logger.error("Error transferring %d files, please try again" % (filesErrored))


