#!/usr/bin/env python
"""
Launches batch job running starfit on every folder listed in a file

Uses slurm job array by default, but could be reconfigured
to another batch scheduler if desired.
"""
from __future__ import print_function, division

import os,re,sys,os.path,shutil,glob
import argparse
import numpy as np
import subprocess

try:  # Python 3.3+
    from shlex import quote as shell_quote
except ImportError:  # Python 2 fallback
    from pipes import quote as shell_quote


def format_walltime(tot_minutes):
    """Return a Slurm ``hours:minutes:00`` time limit covering *tot_minutes*.

    The value is rounded UP to a whole number of minutes and never drops
    below one minute.  Rounding each field independently (as ``{:.0f}``
    does) can produce a minutes field of ``60`` or a limit of ``00:00:00``,
    and Slurm treats a zero time limit as "no limit".
    """
    whole = int(tot_minutes)
    if whole < tot_minutes:
        # int() truncates; bump so the limit is never shorter than requested.
        whole += 1
    whole = max(whole, 1)
    return '{:02d}:{:02d}:00'.format(whole // 60, whole % 60)


def plan_layout(num_lines, nsplit, ntasks_per_node):
    """Work out the resource request for the batch job.

    Parameters
    ----------
    num_lines : int
        Number of lines in the list file.
    nsplit : int
        Total number of jobs the list is split into.
    ntasks_per_node : int
        Maximum number of tasks to place on one node.

    Returns
    -------
    (n_nodes, ntasks_per_node, num_per_job) : tuple of int
        ``n_nodes`` is ``nsplit / ntasks_per_node`` rounded up,
        ``ntasks_per_node`` is capped at ``nsplit``, and ``num_per_job``
        is ``num_lines / nsplit`` rounded up.

    Raises
    ------
    ValueError
        If ``nsplit`` or ``ntasks_per_node`` is smaller than one; both are
        used as divisors.
    """
    if nsplit < 1:
        raise ValueError('nsplit must be at least 1 (got {})'.format(nsplit))
    if ntasks_per_node < 1:
        raise ValueError(
            'ntasks_per_node must be at least 1 (got {})'.format(ntasks_per_node))

    # -(-a // b) is integer ceiling division.
    n_nodes = -(-nsplit // ntasks_per_node)
    num_per_job = -(-num_lines // nsplit)
    return n_nodes, min(nsplit, ntasks_per_node), num_per_job


def build_batch_script(job_label, listfile, n_nodes, ntasks_per_node,
                       time_string, extra=()):
    """Return the text of the Slurm batch script.

    The script loops ``i`` over ``0 .. $SLURM_NPROCS-1``.  For each ``i`` it
    uses awk to select the lines of *listfile* whose line number modulo
    ``$SLURM_NPROCS`` equals ``i`` and pipes them to ``xargs starfit``, run
    in the background; ``wait`` then blocks until every task has finished.

    *listfile* and every entry of *extra* are shell-quoted so that paths or
    arguments containing spaces or shell metacharacters reach starfit
    unchanged.  Values that need no quoting are written verbatim.
    """
    forwarded = ''.join('{} '.format(shell_quote(arg)) for arg in extra)
    pieces = [
        '#!/bin/bash\n',
        '#SBATCH -J starfit-{}\n'.format(job_label),
        '#SBATCH -N {}\n'.format(n_nodes),
        '#SBATCH --ntasks-per-node={}\n'.format(ntasks_per_node),
        '#SBATCH -t {}\n'.format(time_string),
        '\n',
        # "expr $SLURM_NPROCS-1" has no spaces, so expr echoes the text back
        # and the surrounding (( )) arithmetic context evaluates it.
        'for ((i=0; i<=$(expr $SLURM_NPROCS-1); i++)) do\n',
        ' awk "NR % ${{SLURM_NPROCS}} == $i" {} | '.format(shell_quote(listfile)),
        'xargs starfit ',
        forwarded,
        '&\n',
        'done\n',
        'wait\n',
    ]
    return ''.join(pieces)


def main(argv=None):
    """Parse the command line, write ``<listfile>.batch`` and submit it.

    Returns the exit status of ``sbatch`` (127 if it could not be started).
    """
    parser = argparse.ArgumentParser(description='Fire up a batch starfit job')

    parser.add_argument('file', type=str,
                        help='file listing the folders to run starfit on, one per line')
    parser.add_argument('-n', '--nsplit', type=int, default=None,
                        help='Total number of jobs to split into.  Will default to length of file.')
    parser.add_argument('--ntasks_per_node', type=int, default=20,
                        help='number of cores to use per node.')

    parser.add_argument('-t', '--time', type=float, default=5,
                        help='approximate time that one line will take, in minutes')
    parser.add_argument('extra', nargs=argparse.REMAINDER,
                        help='additional arguments forwarded to starfit')

    args = parser.parse_args(argv)

    listfile = os.path.abspath(args.file)

    with open(listfile) as handle:
        num_lines = sum(1 for _ in handle)
    if num_lines == 0:
        parser.error('{} is empty; nothing to submit'.format(listfile))

    nsplit = num_lines if args.nsplit is None else args.nsplit
    try:
        n_nodes, ntasks_per_node, num_per_job = plan_layout(
            num_lines, nsplit, args.ntasks_per_node)
    except ValueError as err:
        parser.error(str(err))

    # Walltime estimate: per-line time multiplied by the lines each job handles.
    time_string = format_walltime(args.time * num_per_job)

    # NOTE(review): the generated loop backgrounds its tasks with '&' and
    # does not use srun; confirm this spreads work over all requested nodes
    # on the target cluster.
    script = build_batch_script(args.file, listfile, n_nodes,
                                ntasks_per_node, time_string, args.extra)

    scriptfile = '{}.batch'.format(listfile)
    with open(scriptfile, 'w') as fout:
        fout.write(script)

    # Submit the job.  An argument list (no shell) keeps a script path that
    # contains spaces or metacharacters intact.
    try:
        return subprocess.call(['sbatch', scriptfile])
    except OSError as err:
        sys.stderr.write('could not run sbatch: {}\n'.format(err))
        return 127


if __name__ == '__main__':
    sys.exit(main())

