#!/usr/bin/env python

"""Convert Bruker ParaVision raw data (2dseq + parameter files) to NIfTI-1."""
__author__ = 'SungHo Lee (shlee@unc.edu)'
__version_info__ = ('2017', '03', '25')
__version__ = '-'.join(__version_info__) + 'REV04'

import re
import os
import argparse
import sys
try:
    import nibabel as nib
    import numpy as np
    import json
except ImportError:
    # print(...) with one argument behaves the same on Python 2 and 3, so this
    # line no longer makes the whole module a syntax error on Python 3.
    print("Required libraries (numpy, nibabel, json) are not installed")
    # Non-zero status so shells/pipelines can detect the failure.
    sys.exit(1)


class BrukerRawData(object):
    """A Bruker ParaVision scan loaded as a NIfTI-1 image.

    The scan folder ``path`` must contain ``acqp``, ``method`` and
    ``pdata/<pid>/reco`` + ``pdata/<pid>/2dseq``; ``subject`` is read from the
    parent folder of ``path``. Each parameter file is parsed into a dict (see
    :meth:`parsing`), ``2dseq`` is loaded with the dtype described by the reco
    file, and the result is wrapped in a ``nibabel.Nifti1Image``.

    Parameters
    ----------
    path : str
        Scan folder.
    pid : int or str
        Processed-data id, i.e. the sub-folder of ``pdata`` to load.
    ori : bool, optional
        If true, replace the diagonal affine with one built from the slice
        geometry (see :meth:`correct_orient`).
    """

    # JCAMP-DX record patterns. A parameter name ends at the first '='.
    _P_SCALAR = re.compile(r'^##\$([^=]*)=([^(].*[^)])')
    _P_ARRAY = re.compile(r'^##\$([^=]*)=\((.*)\)')
    _P_STRING = re.compile(r'^<(.*)>$')
    _P_INT = re.compile(r'^-?[0-9]+$')
    _P_FLOAT = re.compile(r'^-?(\d+\.?)?\d+([eE][-+]?\d+)?$')
    _P_GROUPS = re.compile(r'\(([^)]*)\)')

    # RECO_wordtype -> numpy type code. Unknown values fall back to 'f8',
    # which was the previous catch-all.
    _WORDTYPES = {
        '_8BIT_UNSGN_INT': 'u1',
        '_16BIT_SGN_INT': 'i2',
        '_32BIT_SGN_INT': 'i4',
        '_32BIT_FLOAT': 'f4',
    }
    # RECO_byte_order -> numpy byte-order prefix; missing/unknown = native.
    _BYTE_ORDERS = {'littleEndian': '<', 'bigEndian': '>'}
    # PVM_ObjOrderScheme -> NIfTI-1 slice_code
    # (0 unknown, 1 SEQ_INC, 2 SEQ_DEC, 3 ALT_INC, 4 ALT_DEC).
    _SLICE_CODES = {
        'User_defined_slice_scheme': 0,
        'Sequential': 1,
        'Reverse_sequential': 2,
        'Interlaced': 3,
        'Reverse_interlaced': 4,
        'Angiography': 0,
    }

    def __init__(self, path, pid, ori=False):
        pdata = os.path.join(path, 'pdata', str(pid))
        # normpath first: with a trailing slash os.path.split would return the
        # scan folder itself instead of its parent.
        parent = os.path.dirname(os.path.normpath(path))

        # Read every parameter file before parsing so a missing one fails fast.
        acqp = self._readlines(os.path.join(path, 'acqp'))
        method = self._readlines(os.path.join(path, 'method'))
        reco = self._readlines(os.path.join(pdata, 'reco'))
        subject = self._readlines(os.path.join(parent, 'subject'))

        # Parse header information
        self._acqp = self.parsing(acqp, 'acqp')
        self._method = self.parsing(method, 'method')
        self._reco = self.parsing(reco, 'reco')
        self._subject = self.parsing(subject, 'subject')

        # Load the binary image with the word size/byte order from 'reco'
        self._2dseq = np.fromfile(os.path.join(pdata, '2dseq'),
                                  dtype=self._get_dtype())

        # Voxel size and the default (diagonal) affine
        self._resol = self.get_resol()
        self._affine = np.diag(self.resol + [1])

        # Convert to Nifti
        self._nii = nib.Nifti1Image(self.img, self.affine)

        # Optionally rebuild the affine from the slice geometry
        if ori:
            self.correct_orient()

        # Fill in NIfTI header defaults (units, timing, slice info)
        self.set_default_header()

    @staticmethod
    def _readlines(filename):
        """Return all lines of ``filename`` (line endings kept)."""
        with open(filename) as handle:
            return handle.readlines()

    def _get_dtype(self):
        """Return the numpy dtype of ``2dseq`` as described by the reco file."""
        code = self._WORDTYPES.get(self._reco['RECO_wordtype'], 'f8')
        order = self._BYTE_ORDERS.get(self._reco.get('RECO_byte_order'), '=')
        return np.dtype(order + code)

    @property
    def acqp(self):
        """Parsed ``acqp`` parameters."""
        return self._acqp

    @property
    def method(self):
        """Parsed ``method`` parameters."""
        return self._method

    @property
    def reco(self):
        """Parsed ``reco`` parameters."""
        return self._reco

    @property
    def subject(self):
        """Parsed ``subject`` parameters."""
        return self._subject

    @property
    def resol(self):
        """Voxel size ``[dx, dy, dz]`` (see :meth:`get_resol`)."""
        return self._resol

    @property
    def img(self):
        """``2dseq`` reshaped to :meth:`get_shape` and transposed."""
        return self._2dseq.reshape(self.get_shape()).T

    @property
    def affine(self):
        """Current 4x4 affine."""
        return self._affine

    @property
    def nii(self):
        """The ``nibabel.Nifti1Image`` built from this scan."""
        return self._nii

    def parsing(self, profiles, key):
        """Parse the lines of a JCAMP-DX (ParaVision) parameter file.

        Parameters
        ----------
        profiles : list of str
            Lines of the file, e.g. from ``readlines()``.
        key : str
            Label of the file ('acqp', 'method', ...). Not used by the
            parser; kept so existing callers keep working.

        Returns
        -------
        dict
            Parameter name (without the ``##$`` prefix) -> value. Scalars
            become int/float/str, ``<...>`` strings lose their brackets, 1-D
            arrays become lists and arrays with a multi-entry size header
            become reshaped numpy arrays when possible.
        """
        output_obj = dict()
        for i, line in enumerate(profiles):
            scalar = self._P_SCALAR.search(line)
            if scalar:
                name = scalar.group(1).strip()
                output_obj[name] = self.check_dt(scalar.group(2).strip())
                continue
            array = self._P_ARRAY.search(line)
            if not array:
                # Title records (##TITLE=...), $$ comments and continuation
                # lines of an array are not parameter headers.
                continue
            name = array.group(1).strip()
            n_value = self._parse_dims(array.group(2).strip())
            values = self._continuation(profiles, i + 1)
            output_obj[name] = self._convert_array(n_value, values)
        return output_obj

    @staticmethod
    def _parse_dims(text):
        """Split the ``( ... )`` part of an array header into its entries.

        Returns a list of ints (or of stripped strings when any entry is not
        an integer), collapsed to the single element when there is just one.
        """
        parts = [part.strip() for part in text.split(',')]
        try:
            dims = [int(part) for part in parts]
        except ValueError:
            dims = parts
        return dims[0] if len(dims) == 1 else dims

    @staticmethod
    def _continuation(profiles, start):
        """Join the stripped value lines that follow an array header.

        Collection stops at the next line starting with '##' (any JCAMP-DX
        record, including the final ##END=) or '$$' (comment), so trailing
        records can no longer leak into the last parameter's value.
        """
        collected = []
        idx = start
        while idx < len(profiles) and not profiles[idx].startswith(('##', '$$')):
            collected.append(profiles[idx].strip())
            idx += 1
        return ' '.join(collected)

    def _convert_array(self, n_value, values):
        """Convert the joined value text of an array record."""
        if isinstance(n_value, list):
            # Multi-entry size header: build an ndarray of that shape, or fall
            # back to scalar conversion when the data does not fit it.
            try:
                return np.array(self.check_array(n_value, values)).reshape(n_value)
            except (ValueError, TypeError):
                return self.check_dt(values)
        string = self._P_STRING.match(values)
        if string:
            return string.group(1)
        if n_value == 1:
            return self.check_dt(values)
        return self.check_array(n_value, values)

    def check_dt(self, value):
        """Convert one scalar token to int, float or an unwrapped string.

        ``<text>`` becomes ``text``; anything else that is not numeric is
        returned unchanged (after stripping surrounding spaces).
        """
        value = value.strip(' ')
        if self._P_FLOAT.match(value):
            return int(value) if self._P_INT.match(value) else float(value)
        try:
            return int(value)
        except ValueError:
            string = self._P_STRING.match(value)
            return string.group(1).strip(' ') if string else value

    def check_array(self, n_value, values):
        """Split array text into converted elements.

        ``"(a, b) (c, d)"`` becomes a list of lists (one per parenthesised
        group); otherwise whitespace-separated tokens become a flat list.
        ``n_value`` is accepted for interface compatibility and not used.
        """
        if self._P_GROUPS.match(values):
            return [[self.check_dt(item) for item in group.split(', ')]
                    for group in self._P_GROUPS.findall(values)]
        return [self.check_dt(item) for item in values.split()]

    def get_shape(self):
        """Return the C-order shape used to reshape ``2dseq``.

        From slowest to fastest axis: repetitions (for DtiEpi, folded into
        the diffusion-volume axis), diffusion volumes (DtiEpi only), slices
        (2D only), echoes (when PVM_NEchoImages > 1), then RECO_size
        reversed.
        """
        method = self.method
        matrix = np.atleast_1d(self.reco['RECO_size']).tolist()[::-1]
        if method['PVM_NEchoImages'] > 1:
            matrix.insert(0, method['PVM_NEchoImages'])
        if method['PVM_SpatDimEnum'] == '2D':
            matrix.insert(0, self.acqp['NSLICES'])
        num_rep = method.get('PVM_NRepetitions')
        is_dti = method['Method'] == 'DtiEpi'
        if is_dti:
            matrix.insert(0, method['PVM_DwAoImages'] + method['PVM_DwNDiffDir'])
        if num_rep:
            if is_dti:
                matrix[0] *= num_rep
            else:
                matrix.insert(0, num_rep)
        return matrix

    def get_tempresol(self):
        """Return TR * NAverages (* NSegments when present).

        The unit is PVM_RepetitionTime's; set_default_header divides it by
        1000 to get seconds.
        """
        tr = self.method['PVM_RepetitionTime']
        num_avr = self.method['PVM_NAverages']
        num_seg = self.method.get('NSegments')
        if num_seg is None:
            return tr * num_avr
        return tr * num_seg * num_avr

    def get_resol(self):
        """Return the voxel size ``[dx, dy, dz]``.

        In-plane (and 3D through-plane) spacing is RECO_fov / RECO_size * 10.
        The division is done in floating point so integer FOV/size values are
        not truncated on Python 2. For 2D scans dz is ACQ_slice_thick; for
        any other dimensionality ``[1, 1, 1]`` is returned.
        """
        dims = self.method['PVM_SpatDimEnum']
        if dims == '2D' or dims == '3D':
            fov = np.asarray(self.reco['RECO_fov'], dtype=float)
            size = np.asarray(self.reco['RECO_size'], dtype=float)
            spacing = list(fov / size * 10)
        # NOTE(review): the first RECO axis is assigned to dy (and the second
        # to dx); this only matters for anisotropic in-plane resolution.
        if dims == '2D':
            dy, dx = spacing
            dz = self.acqp['ACQ_slice_thick']
        elif dims == '3D':
            dy, dx, dz = spacing
        else:
            dx, dy, dz = (1, 1, 1)
        return [dx, dy, dz]

    def get_center(self):
        """Return half of the field of view along each axis.

        For 2D scans the z extent is (slice thickness + slice gap) * NSLICES.
        Float division avoids Python 2 integer truncation.
        """
        dims = self.method['PVM_SpatDimEnum']
        if dims == '2D':
            center_x, center_y = np.asarray(self.method['PVM_Fov'], dtype=float) / 2
            extent_z = (self.acqp['ACQ_slice_thick'] + self.method['PVM_SPackArrSliceGap']) \
                * self.acqp['NSLICES']
            center_z = extent_z / 2.0
        elif dims == '3D':
            center_x, center_y, center_z = np.asarray(self.method['PVM_Fov'], dtype=float) / 2
        else:
            center_x, center_y, center_z = (0, 0, 0)
        return center_x, center_y, center_z

    def get_orient(self):
        """Return ``[slice_orient, read_orient] + ACQ_patient_pos.split('_')``."""
        slice_orient = self.method['PVM_SPackArrSliceOrient']
        read_orient = self.method['PVM_SPackArrReadOrient']
        pos = self.acqp['ACQ_patient_pos'].split('_')
        return [slice_orient, read_orient] + pos

    def get_geometry(self):
        """Return ``[slice, phase1, phase2, read]`` slice-package offsets."""
        phase1_offset = self.method['PVM_SPackArrPhase1Offset']
        phase2_offset = self.method['PVM_SPackArrPhase2Offset']
        slice_offset = self.method['PVM_SPackArrSliceOffset']
        read_offset = self.method['PVM_SPackArrReadOffset']
        return [slice_offset, phase1_offset, phase2_offset, read_offset]

    def set_default_header(self):
        """Fill NIfTI header units, timing and slice fields.

        EPI scans get seconds as the time unit, pixdim[4] = TR in s, slice
        timing (duration, order code, start/end). Every other method gets
        'unknown' time units with qform_code=1 and sform_code=0.
        """
        header = self.nii.header
        header.default_x_flip = False
        if self.method['Method'] == 'EPI':
            tr = self.get_tempresol()
            header.set_xyzt_units('mm', 'sec')
            header['pixdim'][4] = float(tr) / 1000
            header.set_dim_info(slice=2)
            header['slice_duration'] = float(tr) / (1000 * self.acqp['NSLICES'])
            code = self._SLICE_CODES.get(self.method['PVM_ObjOrderScheme'])
            if code is not None:
                header['slice_code'] = code
            # np.min/np.max also handle a single-slice scalar ACQ_obj_order.
            order = self.acqp['ACQ_obj_order']
            header['slice_start'] = np.min(order)
            header['slice_end'] = np.max(order)
        else:
            header.set_xyzt_units('mm', 'unknown')
            header['qform_code'] = 1
            header['sform_code'] = 0

    def correct_orient(self, human=0):
        """
        Rebuild the affine from the scan geometry and store it as the qform
        (code 1) and sform (code 0). Nothing is done when ``human`` is true.

        Readout direction: - this setup only affected axis of readout direction not other axis
            L_R: left to right (x)
            A_P: anterior to posterior (y)
            H_F: head to foot (z)
        position:
            Supine, Prone
            Head, Foot
        """
        ori = self.get_orient()
        if human:
            return
        # Negate the two in-plane voxel sizes (this also updates self.resol).
        self._resol[0] = -1 * float(self._resol[0])
        self._resol[1] = -1 * float(self._resol[1])
        affine = np.diag(self._resol + [1])
        centers = self.get_center()
        # offset = [slice, phase1, phase2, read], see get_geometry()
        offset = np.round(self.get_geometry(), decimals=4)
        affine[0, 3] = (centers[0] - offset[3])
        affine[1, 3] = (centers[1] + offset[1])
        affine[2, 3] = -1 * (centers[2] - offset[0])
        # Permute world axes according to the slice orientation.
        if ori[0] == 'axial':
            self._affine = affine[[0, 2, 1, 3], :]
        elif ori[0] == 'coronal':
            self._affine = affine
        else:
            self._affine = affine[[2, 1, 0, 3], :]
        self.nii.set_qform(self.affine, code=1)
        self.nii.set_sform(self.affine, code=0)

    def __repr__(self):
        return 'BrukerRawData'

def main(argv=None):
    """Command-line entry point: convert one Bruker reconstruction to NIfTI.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]`` (handy for testing).
    """
    parser = argparse.ArgumentParser(
        prog='brk2nifti',
        description="Convert Bruker raw data to Nifti formatted image")
    parser.add_argument("pid", help="Processed ID (in case image is reconstructed)")
    parser.add_argument("path", help="Folder location for the Bruker raw data", type=str)
    parser.add_argument("filename", help="Filename w/o extension to export NifTi image", type=str)
    parser.add_argument("-V", "--version", action="version",
                        version="%(prog)s (" + __version__ + ")")
    # store_true flags should default to a real boolean, not 0.
    parser.add_argument("-o", "--orient", help="Reorientation to correct space",
                        action='store_true', default=False)
    args = parser.parse_args(argv)

    img = BrukerRawData(args.path, args.pid, ori=args.orient)
    img.nii.to_filename(args.filename)


if __name__ == '__main__':
    main()

