#!/usr/local/bin/python2.7
# encoding: utf-8
'''
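scripts.pophod -- populate dark-matter halo catalogues with galaxies

scripts.pophod is a command-line script that reads one or more halo
catalogues (ASCII or SubFind format) and writes a galaxy catalogue for each
of them, drawn from a given HOD model.

It defines main(), populate() and the readers read_ascii() and read_subfind().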
'''

import sys
import os
import traceback
import numpy as np
from hod import hod
from hod import tools
import ast
from hod.fort.routines import hod_routines as hr
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter

import cosmolopy as cp

from multiprocessing import Pool, cpu_count
# Nothing is exported by ``from <module> import *``.
__all__ = []
# Version number and last-updated date; main() combines these into the
# message printed by the -V/--version option.
__version__ = 0.1
__date__ = "2014 - 01 - 21"
__updated__ = "2014 - 01 - 21"

# Development switches, all off by default.
# DEBUG:   main() re-raises exceptions instead of reporting them, and the
#          script entry point appends "-h" and "-v" to sys.argv.
# TESTRUN: main() re-raises exceptions, and the script entry point runs the
#          module doctests first.
# PROFILE: the script entry point runs main() under cProfile and exits.
DEBUG = 0
TESTRUN = 0
PROFILE = 0

class CLIError(Exception):
    '''Generic exception to raise and log different fatal errors.

    The message is stored on ``msg`` with an ``"E: "`` prefix, and that
    prefixed text is what ``str()`` / ``unicode()`` return.
    '''
    def __init__(self, msg):
        # The two-argument form of super() is required here: the one-argument
        # form builds an unbound super object, so Exception.__init__ would
        # never run on this instance.
        super(CLIError, self).__init__(msg)
        self.msg = "E: %s" % msg

    def __str__(self):
        return self.msg

    def __unicode__(self):
        # Only consulted by Python 2's unicode(); harmless elsewhere.
        return self.msg

def _broadcast_per_file(values, nfiles, errmsg):
    '''Return one entry of ``values`` per input file.

    A single value is repeated ``nfiles`` times, a list of exactly ``nfiles``
    values is returned unchanged, and any other length raises
    ``ValueError(errmsg)``.
    '''
    if len(values) == 1:
        return nfiles * list(values)
    if len(values) == nfiles:
        return values
    raise ValueError(errmsg)


def main(argv=None):
    '''
    Populate a dark-matter simulation with galaxies based on a given HOD.

    The script will take input files in a number of formats. Currently these
    include:

        1. SubFind format: a fofcat file and a pos file.
        2. ASCII format: 4 columns: x,y,z,mass

    If other formats are in demand, they will be implemented. Otherwise, first
    convert any required files to one of these formats.

    To use the ASCII format, just pass the catalogue filename(s). To use the
    SubFind format, pass the fofcat file(s) as the infile and select
    ``--filetype subfind``; the name of each pos file is derived from the
    fofcat filename by replacing "fofcat" with "pos".

    Per-simulation options (--filetype, --redshift, --omegam, --delta-wrt)
    take either one value, applied to every input file, or one value per
    input file.

    Parameters
    ----------
    argv : list of str, optional
        Extra command-line arguments. They are appended to ``sys.argv``,
        which is what the argument parser reads.

    Returns
    -------
    int
        0 on success or keyboard interrupt, 2 if an error was reported.
    '''

    if argv is not None:
        # NOTE: this mutates the process-wide sys.argv; parse_args() below
        # reads it implicitly.
        sys.argv.extend(argv)

    program_name = os.path.basename(sys.argv[0])
    program_version = "v%s" % __version__
    program_build_date = str(__updated__)
    program_version_message = '%%(prog)s %s (%s)' % (program_version, program_build_date)

    # The short description is the second line of the running script's
    # docstring. Fall back to this module's docstring (or nothing) so that
    # calling main() from a script without a docstring does not crash.
    doc_lines = (__import__('__main__').__doc__ or __doc__ or "").split("\n")
    program_shortdesc = doc_lines[1] if len(doc_lines) > 1 else ""
    program_license = '''%s

     Created by Steven Murray on 21-01-2014.
     Copyright 2014 organization_name. All rights reserved.

     Licensed under the Apache License 2.0
     http://www.apache.org/licenses/LICENSE-2.0

     Distributed on an "AS IS" basis without warranties
     or conditions of any kind, either express or implied.

    USAGE
    ''' % (program_shortdesc)

    try:
        # Setup argument parser
        parser = ArgumentParser(description=program_license, formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument("-v", "--verbose", dest="verbose", action="count", help="set verbosity level [default: %(default)s]")
        parser.add_argument('-V', '--version', action='version', version=program_version_message)
        parser.add_argument("infile", nargs="+", help="the main input halo catalogue(s). If SubFind format, this is the fofcat file.")
        parser.add_argument("hodmod_dict", help="dictionary of HOD parameters")

        parser.add_argument("--outfile", nargs="*", help="the output galaxy catalogue")

        parser.add_argument("--filetype", nargs="*", default=["ascii"], choices=["ascii", "subfind"], help="the type of input file(s). Must be len(1) or len(infile)")
        parser.add_argument("--hodmod", default="zheng", help="the HOD model used")
        parser.add_argument("--delta_h", default=200.0, type=float, help="the average halo overdensity")
        parser.add_argument("--halo_profile", default="NFW", help="the halo halo_profile to use")
        # populate() reads this option as ``args.cm_relation``; without an
        # explicit dest argparse would store it as ``halo_cm_relation``.
        parser.add_argument("--halo_cm-relation", dest="cm_relation", default="zehavi", help="the Concentration-Mass relation")
        parser.add_argument("--no-truncate", action="store_true", help="don't truncate halo_profile")
        parser.add_argument("--omegam", nargs="*", default=[0.3], type=float, help="the matter density")
        parser.add_argument("--redshift", nargs="*", default=[0.0], type=float, help="the redshift")
        parser.add_argument("--nthreads", type=int, help="number of threads to use. Default auto.")
        parser.add_argument("-m", "--particle-mass", default=1.0, type=float, help="particle mass in the simulation. For ascii, leave as 1.0 unless 4th column is number of particles.")
        parser.add_argument("--delta-wrt", nargs="*", default=["mean"], choices=["mean", "crit"], help="what delta_h is with respect to")

        # Process arguments
        args = parser.parse_args()
        nfiles = len(args.infile)

        # One file type per input file.
        ftypes = _broadcast_per_file(args.filetype, nfiles,
                                     "filetype must be len(1) or len(infile)")

        # One output file per input file; default is "<infile>_gal".
        if args.outfile is None:
            outfiles = [f + "_gal" for f in args.infile]
        elif len(args.outfile) != nfiles:
            raise ValueError("outfile must be None or len(infile)")
        else:
            outfiles = args.outfile

        # Get posfiles if type is subfind.
        posfiles = [f.replace("fofcat", "pos") if ftype == "subfind" else None
                    for f, ftype in zip(args.infile, ftypes)]

        # omegam and redshift depend on the simulation, so they are per-file too.
        redshifts = _broadcast_per_file(args.redshift, nfiles,
                                        "redshift must be len(1) or len(infile)")
        omegams = _broadcast_per_file(args.omegam, nfiles,
                                      "omegam must be len(1) or len(infile)")
        delta_wrt = _broadcast_per_file(args.delta_wrt, nfiles,
                                        "length of delta_wrt must be 1 or len(infile)")

        # An overdensity given with respect to the critical density is divided
        # by Omega_m(z) for a flat cosmology; "mean" is passed through. This is
        # applied to every file, including when a single --delta-wrt is given.
        delta_h = []
        for z, omegam, wrt in zip(redshifts, omegams, delta_wrt):
            if wrt == "mean":
                delta_h.append(args.delta_h)
            else:
                delta_h.append(args.delta_h / cp.density.omega_M_z(z, omega_M_0=omegam,
                                                                   omega_lambda_0=1 - omegam,
                                                                   omega_k_0=0.0))

        print(delta_h)

        # Define the HOD model
        hodparams = ast.literal_eval(args.hodmod_dict)
        if not isinstance(hodparams, dict):
            raise ValueError("hodmod_dict must be a dictionary literal")
        hodmod = hod.HOD(args.hodmod, **hodparams)

        popfunc = _function_wrapper(populate, [ftypes, args, delta_h, posfiles, omegams,
                                                redshifts, hodmod, outfiles])

        # auto-calculate the number of threads to use if not set. The pool is
        # only started once every argument has been validated, and is always
        # shut down again so no worker processes are left behind.
        nthreads = args.nthreads or cpu_count()
        pool = Pool(nthreads)
        try:
            flags = pool.map(popfunc, range(nfiles))
        finally:
            pool.terminate()
            pool.join()

        if not all(flags):
            raise RuntimeError("populating one or more catalogues failed")

        return 0
    except KeyboardInterrupt:
        ### handle keyboard interrupt ###
        return 0
    except Exception as e:
        if DEBUG or TESTRUN:
            # Bare raise keeps the original traceback.
            raise
        traceback.print_exc()
        indent = len(program_name) * " "
        sys.stderr.write(program_name + ": " + repr(e) + "\n")
        sys.stderr.write(indent + "  for help use --help\n")
        return 2


#===============================================================================
# FUNCTION CALLS
#===============================================================================
def populate(i, ftypes, args, delta_h, posfiles, omegams, redshifts, hodmod, outfiles):
    '''
    Populate the ``i``-th halo catalogue with galaxies and save the result.

    The catalogue ``args.infile[i]`` is read with the reader matching
    ``ftypes[i]``, the halo centres are divided by ``1 + redshifts[i]`` before
    being handed to ``tools.populate``, and the returned positions are
    multiplied by the same factor again. Positions outside the box are then
    wrapped and written to ``outfiles[i]`` with ``np.savetxt``.

    All list arguments are indexed by ``i``; ``args`` is the parsed
    command-line namespace. Returns ``True`` on completion and raises
    ``ValueError`` for a file type other than "ascii" or "subfind".
    '''
    kind = ftypes[i]
    if kind == "ascii":
        centres, masses = read_ascii(args.infile[i], args.particle_mass)
    elif kind == "subfind":
        centres, masses = read_subfind(args.infile[i], posfiles[i], args.particle_mass)
    else:
        raise ValueError("ftypes had value " + ftypes[i] + " in it!")

    # The box size is not stored in the catalogue; the largest halo coordinate
    # rounded up is the best estimate we can get.
    boxsize = np.ceil(centres.max())

    # Convert to physical co-ordinates for the population step ...
    z = redshifts[i]
    expansion = 1 + z
    centres /= expansion

    truncate = not args.no_truncate
    pos = tools.populate(centres, masses, delta_h[i], omegams[i], z,
                         args.halo_profile, args.cm_relation, hodmod, truncate)

    # ... and undo the conversion on the galaxy positions.
    pos *= expansion

    # wrap the galaxies so they stay inside the box
    below = pos < 0
    pos[below] += boxsize
    above = pos > boxsize
    pos[above] -= boxsize

    np.savetxt(outfiles[i], pos)
    return True

class _function_wrapper(object):
    """
    This is a hack to make the likelihood function pickleable when ``args``
    are also included.

    Taken from the emcee package.
    """
    def __init__(self, f, args):
        self.f = f
        self.args = args

    def __call__(self, x):
        try:
            return self.f(x, *self.args)
        except:
            raise

def read_ascii(filename, mp):
    '''
    Read an ASCII halo catalogue with columns x, y, z, mass.

    The file is loaded with ``np.genfromtxt`` and its first row is discarded
    (NOTE(review): presumably a header line -- confirm against the catalogue
    format). Returns ``(centres, masses)``: the first three columns of the
    remaining rows, and their fourth column multiplied by ``mp``.
    '''
    table = np.genfromtxt(filename)
    rows = table[1:]
    return rows[:, :3], rows[:, 3] * mp

def read_subfind(fofcat_file, pos_file, mp):
    '''
    Read a SubFind halo catalogue (fofcat file plus pos file).

    As read here, the fofcat file holds a native int giving the number of
    halos, followed by that many ints of halo sizes (particle counts) and that
    many ints of offsets into the particle list. The pos file holds a native
    int giving the number of particles in groups, followed by that many
    (x, y, z) single-precision triples.

    Parameters
    ----------
    fofcat_file, pos_file : str
        Paths of the two binary files.
    mp : float
        Particle mass; each halo mass is its particle count times ``mp``.

    Returns
    -------
    centres, masses
        Halo centres as returned by ``hr.get_subfind_centres`` and the array
        of halo masses.
    '''
    import struct
    positions_dtype = np.dtype([("x", 'f'),
                                ("y", 'f'),
                                ("z", 'f')
                                ])

    # Both files are binary, so they must be opened in "rb" mode: text mode
    # can corrupt the data on some platforms and yields str rather than bytes
    # on Python 3.
    with open(fofcat_file, "rb") as f:
        # Get number of halos
        nhalos = struct.unpack('i', f.read(4))[0]

        # Get halo sizes
        grouplen = np.fromfile(f, dtype='i', count=nhalos)

        # Get group id offsets
        groupoffsets = np.fromfile(f, dtype='i', count=nhalos)

    with open(pos_file, "rb") as f:
        # Get total number of particles in groups
        total_particles_in_groups = struct.unpack('i', f.read(4))[0]

        # Positions Block
        pos = np.fromfile(f, dtype=positions_dtype, count=total_particles_in_groups)

    # Make pos a normal (nparticles, 3) float32 ndarray
    pos = pos.view(np.float32).reshape(pos.shape + (-1,))

    # NOTE(review): presumably computes each halo's mean particle position
    # with a correction for particles wrapped across the periodic boundary
    # (as the earlier pure-Python implementation did) -- confirm against the
    # Fortran source.
    centres = hr.get_subfind_centres(groupoffsets, grouplen,
                                     np.asfortranarray(pos))

    masses = grouplen * mp

    return centres, masses


# Script entry point. The DEBUG / TESTRUN / PROFILE switches defined at the top
# of the module select the development modes handled below.
if __name__ == "__main__":
    if DEBUG:
        sys.argv.append("-h")
        sys.argv.append("-v")
    if TESTRUN:
        import doctest
        doctest.testmod()
    if PROFILE:
        import cProfile
        import pstats
        profile_filename = 'scripts.pophod_profile.txt'
        cProfile.run('main()', profile_filename)
        # pstats writes text, so the report is opened in text mode; the
        # with-block guarantees the file is closed even if reporting fails.
        with open("profile_stats.txt", "w") as statsfile:
            p = pstats.Stats(profile_filename, stream=statsfile)
            stats = p.strip_dirs().sort_stats('cumulative')
            stats.print_stats()
        sys.exit(0)
    sys.exit(main())
