#!/usr/bin/python2
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import print_function
import shutil
import argus_gui
import pandas
import numpy as np
import argparse
import tempfile
import sys
from scipy.sparse import lil_matrix
from six.moves import map
from six.moves import range

def parseCSVs(csv):
    """Load a points file into a matrix, chosen by the file extension.

    ``.csv``: read with pandas (the first row is treated as a header) and
    returned as a dense NumPy array.

    ``.tsv``: sparse representation.  The first line is a header, the second
    line holds the tab-separated matrix shape (rows, columns) and every
    following line is a tab-separated ``row, column, value`` triple with
    1-based indices.  Returned as a ``scipy.sparse.lil_matrix`` in which
    every entry not listed in the file is NaN.

    Any other extension returns ``None``.
    """
    extension = csv.split('.')[-1]

    if extension == 'csv':
        # DataFrame.as_matrix() was removed in pandas 1.0; ``.values`` gives
        # the same array on both old and new pandas versions.
        return pandas.read_csv(csv, index_col = False).values

    # else check if we have sparse data representation
    if extension == 'tsv':
        # the context manager guarantees the handle is closed, even when a
        # malformed line makes the parsing below raise
        with open(csv) as fo:
            # expect a header
            fo.readline()
            # next line has shape information for the sparse matrix
            shape = list(map(int, fo.readline().split('\t')))
            ret = lil_matrix((shape[0], shape[1]))
            # missing observations are represented by NaN rather than zero
            ret[:, :] = np.nan
            for line in iter(fo.readline, ''):
                # tolerate blank lines (e.g. a trailing newline at the end)
                if not line.strip():
                    continue
                val = list(map(float, line.split('\t')))
                # indices in the file are 1-based
                ret[int(val[0]) - 1, int(val[1]) - 1] = val[2]
        return ret

    # unsupported extension
    return None

if __name__ == '__main__':
    # Command line entry point: parse the arguments, load and validate the
    # input files, then hand everything to argus_gui.sbaArgusDriver.
    parser = argparse.ArgumentParser(description="Argus-Wand command line interface")

    # NOTE(review): the help text below lists 10 columns while the shape check
    # further down requires 12 - confirm against the documentation.
    parser.add_argument("cam_profile",
                        help="Text file with intrinsic and extrinsic estimates for each camera row by row. Values must be separated by spaces. Column specification: fx px py AR s r2 r4 t1 t2 r6")
    parser.add_argument("--paired_points", default = '',
                        help="CSV file with paired UV pixel coordinates Has 4*(number of cameras) columns")

    parser.add_argument("--unpaired_points",  default = '',
                        help="CSV file with unpaired UV pixel coordinates. Has a multiple of 2*(number of cameras) columns")

    parser.add_argument("--reference_points", default = '',
                        help="CSV file with three points marked in pixel coordinates. One line, and has 6*(number of cameras) columns")

    parser.add_argument("--scale", default = '1',
                        help="Measured distance between paired points used to define a scale")

    parser.add_argument("--intrinsics_opt",
                        help="0, 1, or 2 camera intrinsics to optimize (focal length and principal point)", default = "0")

    parser.add_argument("--distortion_opt",
                        help="0, 1, 2, or 3. Specifies the mode of distortion optimization. (Optimize none, optimize r2, optimize r2 and r4, or optimize all distortion coefficients", default = "0")

    parser.add_argument("output_name",help="Output filename for points and optimized camera profile")

    parser.add_argument("-g","--graph",action="store_true",
                        help="Graph points and camera positions with Matplotlib. Must have X11 running.")
    parser.add_argument("--outliers",action="store_true",
                        help="Report on outliers and ask to re-calibrate without them")
    parser.add_argument("--tmp", default = "None",
                        help = "Temporary directory to pass results between obects. If none is specified one will be created. Directory will be destroyed uppon completion, use with caution!")
    parser.add_argument("--output_camera_profiles", action = "store_true",
                        help = "Use this argument to output camera profiles for use with other Argus programs")
    parser.add_argument("--choose_reference", action = "store_true",
                        help = "Use this argument to have Argus-Wand optimize the choice of reference camera by counting the number of triangulatable points for each choice")
    args = parser.parse_args()
    print('Loading points...')
    sys.stdout.flush()

    # Get paired points from a CSV (dense) or TSV (sparse) file
    if args.paired_points != '':
        ppts = parseCSVs(args.paired_points)
    else:
        ppts = None
    # Get unpaired points
    if args.unpaired_points != '':
        uppts = parseCSVs(args.unpaired_points)
    else:
        uppts = None
    # Make sure we have a camera profile TXT document
    try:
        cams = np.loadtxt(args.cam_profile)
    except Exception:
        # ``Exception`` rather than a bare except, so that Ctrl-C and
        # SystemExit are not reported as a badly formatted profile
        print('Could not load camera profile! Make sure it is formatted according to the documentation.')
        sys.exit()

    # Check if the camera profile is the correct shape. np.loadtxt returns a
    # 1-D array for a single-row file, so test the rank before indexing
    # shape[1] to report the problem instead of raising IndexError.
    if cams.ndim != 2 or cams.shape[1] != 12:
        print('Incorrect shape for camera profile! Consult the documentation.')
        sys.exit()

    # Format the camera profile to how SBA expects it i.e. take out camera number column, image width and height, then add in skew
    cams = np.delete(cams, [0,2,3], 1)
    cams = np.insert(cams, 4, 0., axis = 1)

    scale = float(args.scale)
    display = args.graph
    # the two option strings are concatenated into one mode string
    mode = args.intrinsics_opt + args.distortion_opt

    # Reference points. Either one, two, or three uv correspondences:
    # One - an origin
    # Two - origin and z-axis
    # Three - origin, x-axis, y-axis
    if args.reference_points != '':
        print('Defining reference points')
        # ``.values`` replaces DataFrame.as_matrix(), which was removed in pandas 1.0
        ref = pandas.read_csv(args.reference_points, index_col = False).values
        # take out every row that contains a NaN
        ref = ref[~np.isnan(ref).any(axis = 1)]
        if ref.shape[1] != 2*cams.shape[0] or not 1 <= ref.shape[0] <= 3:
            print('Incorrect shape of reference points! Make sure they are formatted according to the documentation.')
            sys.exit()
    else:
        ref = None
    # Output files location and tag
    name = args.output_name
    # boolean for outlier analysis at the end
    rep = args.outliers
    if ppts is not None:
        if ppts.shape[1] != 4*cams.shape[0]:
            print('Incorrect shape of paired points! Paired points file must have 4*(number of cameras) columns')
            sys.exit()
    if uppts is not None:
        if uppts.shape[1] % (2*cams.shape[0]) != 0:
            print('Incorrect shape of unpaired points! Unpaired points must have a multiple of 2*(number of cameras) columns')
            sys.exit()
    oCPs = args.output_camera_profiles

    # Temporary directory. Chosen only after all validation has passed, so
    # that an early exit above cannot leave a freshly created directory behind.
    if args.tmp != "None":
        tmp = args.tmp
    else:
        tmp = tempfile.mkdtemp()

    try:
        driver = argus_gui.sbaArgusDriver(ppts, uppts, cams, display, scale, mode, ref, name, tmp, rep, oCPs, args.choose_reference)
        driver.fix()
    finally:
        # remove the temporary directory even when the driver raises
        shutil.rmtree(tmp)
