#!/usr/bin/env python
import subprocess
import numpy as np
import re
import os.path
import time
import argparse

import mdwc.software_tools.abinit_controller as ac
import mdwc.MD_suite.MD_suite as md_ft


def write_md_output(
    path,
    bond_const,
    angl_const,
    pressure_out,
    volu_out,
    bond_constrain_out,
    cos_constrain_out,
):
    """Write the per-step MD observables to a plain-text report at *path*.

    The report has four sections, in this order: volume per MD step,
    pressure per MD step, bond-constraint values per MD step and
    angle-constraint values per MD step. Step numbers are zero-based.

    Args:
        path: Destination file; it is created or truncated.
        bond_const: 2-D indexable; row ``j`` holds the two atom indices
            (columns 0 and 1) printed for bond constraint ``j``.
        angl_const: 2-D indexable; row ``j`` holds the three atom indices
            (columns 0, 1 and 2) printed for angle constraint ``j``.
        pressure_out: Iterable of pressures, one per MD step.
        volu_out: Iterable of cell volumes, one per MD step.
        bond_constrain_out: 2-D array (needs ``.shape``) indexed as
            ``[md_step, bond]``. The square root of each entry is what gets
            written, so the entries are presumably squared bond lengths --
            TODO confirm against the MD suite.
        cos_constrain_out: 2-D array (needs ``.shape``) indexed as
            ``[md_step, angle]``; entries are written as-is under the
            "cos of angle value" heading.

    Returns:
        None.
    """
    # The context manager guarantees the handle is closed (and the buffer
    # flushed) even if formatting one of the rows raises.
    with open(path, 'w') as mdout_file:
        mdout_file.write('md_step     volume(Bohr^3)\n')
        for i, valu in enumerate(volu_out):
            mdout_file.write('%d          %.3f\n' % (i, valu))
        mdout_file.write('\n')

        mdout_file.write('md_step     pressure(hartree/Bohr^3)\n')
        for i, valu in enumerate(pressure_out):
            mdout_file.write('%d          %.3E\n' % (i, valu))
        mdout_file.write('\n')

        mdout_file.write('bond constraints\n')
        n_steps, n_bonds = bond_constrain_out.shape[0], bond_constrain_out.shape[1]
        for md_i in range(n_steps):
            mdout_file.write('md_step   %d\n' % md_i)
            mdout_file.write('atoms in bond     bond value\n')
            for j in range(n_bonds):
                mdout_file.write(
                    '%d  %d            %.3f\n'
                    % (
                        bond_const[j, 0],
                        bond_const[j, 1],
                        bond_constrain_out[md_i, j] ** 0.5,
                    )
                )
        mdout_file.write('\n')

        mdout_file.write('angle constraints\n')
        n_steps, n_angles = cos_constrain_out.shape[0], cos_constrain_out.shape[1]
        for md_i in range(n_steps):
            mdout_file.write('md_step   %d\n' % md_i)
            mdout_file.write('atoms in angle constraint     cos of angle value\n')
            for j in range(n_angles):
                mdout_file.write(
                    '%d  %d  %d                      %.3f\n'
                    % (
                        angl_const[j, 0],
                        angl_const[j, 1],
                        angl_const[j, 2],
                        cos_constrain_out[md_i, j],
                    )
                )


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because any non-empty string is truthy, so ``-NVT False`` would silently
    switch NVT on. This converter parses the text instead.

    Accepted spellings (case-insensitive): true/t/yes/y/1/on and
    false/f/no/n/0/off. The empty string maps to False, as it did with
    ``bool``. Anything else is rejected with a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line interface: -name is the common stem of the .in/.files/.md
# inputs; -NVT and -mpirun are boolean switches that take an explicit value.
parser = argparse.ArgumentParser(description='Paramenters for MD')
parser.add_argument(
    "-name", help="name of the .in, .files, .md", dest="name", type=str, required=True
)
parser.add_argument(
    "-NVT",
    help="If true activates NVT, the default is False and activates NPT",
    dest="NVT",
    type=_str2bool,
    default=False,
)
parser.add_argument(
    "-mpirun",
    help="If true activates mpirun -np to use abinit, the default is False",
    dest="mpirun",
    type=_str2bool,
    default=False,
)
parser.add_argument("-np", help="number of processors", dest="np", type=int)
input_para = vars(parser.parse_args())
# The mpirun command line is built from -np, so fail here with a clear usage
# message rather than later, after the working directory has been created.
if input_para['mpirun'] and input_para['np'] is None:
    parser.error('-np is required when -mpirun is true')
name = input_para['name']

# Read the MD run parameters from '<name>.md'. The seven values are unpacked
# positionally, so the order must match what ac.get_md_parameters returns.
Qmass, bmass, P_ext, dt, correc_steps, md_steps, abinit_steps = ac.get_md_parameters(
    name + '.md'
)

# Read the constraint description from the same '<name>.md' file. The 24
# values are unpacked positionally -- keep this tuple in sync with
# ac.get_md_constrains. Going by the names, the groups are counts (numb_*),
# on/off flags (bool_*), constraint definitions (*_const) and target values
# (*_valu) -- NOTE(review): confirm against the helper.
(
    numb_bond_cons,
    bool_bond_cons,
    numb_angl_cons,
    bool_angl_cons,
    numb_cell_para_cons,
    bool_cell_para_cons,
    numb_cell_angl_cons,
    bool_cell_angl_cons,
    volu_cons,
    bool_volu,
    numb_atom_fix_cons,
    bool_atom_fix_cons,
    bond_const,
    angl_const,
    cell_para_const,
    cell_angl_const,
    atom_fix_const,
    bond_valu,
    angl_valu,
    cell_para_valu,
    cell_angl_valu,
    volu_valu,
    atom_fix_valu,
    atom_fix_cord,
) = ac.get_md_constrains(name + '.md')

# Total number of MD steps over the whole run: md_steps for each of the
# abinit_steps abinit calls.
md_total_steps = md_steps * abinit_steps
# Temperature data read from '<name>.md'; the main loop below indexes it
# with i_abinit_step * md_steps.
temp_arra = ac.temp_data_reader(name + '.md', md_total_steps)
# Run-wide pressure/volume log. It is left open here on purpose: the main
# loop appends one row per MD step and the handle is closed at the end of
# the script.
pressure_volu_file = open('pressure_volume.mdout', 'w')
pressure_volu_file.write(
    'md_step     abinit_step     total_step    time(fs)     pressure(hartree/Bohr^3)     volume(Bohr^3)\n'
)

# Marker searched for in the abinit .out file; once it is present the run is
# treated as finished even if the abinit process has not exited yet.
_ABINIT_DONE = re.compile(r'Calculation\s+completed')


def _run_abinit(work_dir, step):
    """Run abinit for one step inside *work_dir* and wait until it is done.

    The input is ``<work_dir>/<name><step>.files`` (fed on stdin) and stdout
    goes to ``<work_dir>/<name><step>.log``. Every 30 seconds the
    ``<name><step>.out`` file is checked for the completion marker; if it is
    found while the process is still alive, the process is killed.

    The command is passed as an argument list (no intermediate shell) so
    that ``kill()`` reaches the launched program itself rather than a
    wrapping ``sh``.

    Returns:
        The path of the ``.out`` file written by abinit.
    """
    prefix = work_dir + '/' + name + str(step)
    out_path = prefix + '.out'
    if input_para['mpirun']:
        command = ['mpirun', '-np', '%d' % input_para['np'], 'abinit']
    else:
        command = ['abinit']
    # Both handles are only needed while the child runs; close them after.
    with open(prefix + '.files') as files_in, open(prefix + '.log', 'w') as log:
        job = subprocess.Popen(command, stdout=log, stdin=files_in, cwd=work_dir)
        while job.poll() is None:
            time.sleep(30)
            if not os.path.exists(out_path):
                continue
            # Re-read the output on every poll, closing the handle each time
            # so a long run does not accumulate open file descriptors.
            with open(out_path, 'r') as output:
                finished = _ABINIT_DONE.search(output.read()) is not None
            if finished:
                job.kill()
                job.wait()  # reap the child before moving on
                break
    return out_path


# Main driver: for every abinit step, run abinit to get forces and stress for
# the current structure, then advance the structure by md_steps MD steps.
try:
    for i_abinit_step in range(abinit_steps):
        temp = temp_arra[i_abinit_step * md_steps]
        print('abinit step  ', i_abinit_step)

        if i_abinit_step == 0:
            s_t = 1.0  # thermostat degree of freedom
            s_t_dot = 0.0  # time derivative of thermostat degree of freedom
            # First abinit run: take the prototype <name>.in and <name>.files
            # and put them in what is going to be the working directory.
            work_dir = name + str(i_abinit_step)
            subprocess.call(['mkdir', work_dir])
            ac.from_prototype_in_to_in_step0(
                name + '.in', work_dir + '/' + name + str(i_abinit_step) + '.in'
            )
            ac.from_prototype_file_to_file(
                name + '.files',
                work_dir + '/' + name + str(i_abinit_step) + '.files',
                i_abinit_step,
            )
        else:
            # Later runs: the working directory is prepared by the helper
            # from the positions and cell left by the previous MD segment.
            work_dir = ac.create_directories(name, x_t, h_t, i_abinit_step, pwd='.')

        out_path = _run_abinit(work_dir, i_abinit_step)
        nat, mass, h_t, strten_in = ac.get_nat_mass_latvec_in_strten_in(out_path)
        x_t, f_t = ac.get_xred_fcart(out_path, nat)

        if i_abinit_step == 0:
            # Velocities are initialised only once, after the first abinit
            # run has provided the masses and the cell; later steps carry
            # them over from the previous MD segment.
            v_t = md_ft.npt_md_suite.init_vel_atoms(mass, temp, len(mass))
            h_t_dot = md_ft.npt_md_suite.init_vel_lattice(bmass, temp, h_t)
            x_t_dot = md_ft.npt_md_suite.get_x_dot(h_t, v_t, nat)

        if not input_para['NVT']:
            # NPT: the integrator returns the new thermostat, cell and atomic
            # state, which directly becomes the state for the next step.
            (
                s_t,
                s_t_dot,
                pressure_out,
                volu_out,
                bond_constrain_out,
                cos_constrain_out,
                h_t,
                h_t_dot,
                x_t,
                x_t_dot,
                v_t,
            ) = md_ft.npt_md_suite.md_npt_constrains(
                h_t,
                x_t,
                x_t_dot,
                f_t,
                strten_in,
                v_t,
                h_t_dot,
                bond_valu,
                angl_valu,
                cell_para_valu,
                cell_angl_valu,
                volu_valu,
                atom_fix_valu,
                atom_fix_cord,
                P_ext,
                mass,
                Qmass,
                bmass,
                dt,
                temp,
                s_t,
                s_t_dot,
                bond_const,
                angl_const,
                cell_para_const,
                cell_angl_const,
                atom_fix_const,
                correc_steps,
                md_steps,
                bool_bond_cons,
                bool_angl_cons,
                bool_cell_para_cons,
                bool_cell_angl_cons,
                bool_volu,
                bool_atom_fix_cons,
                volu_cons,
                nat,
                numb_cell_angl_cons,
                numb_cell_para_cons,
                numb_angl_cons,
                numb_bond_cons,
                numb_atom_fix_cons,
            )
            write_md_output(
                work_dir + '/' + name + str(i_abinit_step) + '.mdout',
                bond_const,
                angl_const,
                pressure_out,
                volu_out,
                bond_constrain_out,
                cos_constrain_out,
            )
            # Append this segment to the run-wide log. Step counters in the
            # log are one-based; the time column starts at zero.
            for i, pre in enumerate(pressure_out):
                vol = volu_out[i]
                md_time = (i_abinit_step * md_steps + i) * dt
                pressure_volu_file.write(
                    '%d          %d               %d             %.3f         %.3E                         %.3f\n'
                    % (
                        i + 1,
                        i_abinit_step + 1,
                        (i_abinit_step * md_steps + i + 1),
                        md_time,
                        pre,
                        vol,
                    )
                )
        else:
            # NVT: only positions, velocities and the thermostat are
            # advanced; nothing is written to the pressure/volume log.
            s_t, s_t_dot, x_t, v_t = md_ft.npt_md_suite.md_nvt_constrains(
                h_t,
                x_t,
                f_t,
                v_t,
                bond_valu,
                angl_valu,
                atom_fix_valu,
                atom_fix_cord,
                mass,
                Qmass,
                dt,
                temp,
                s_t,
                s_t_dot,
                bond_const,
                angl_const,
                atom_fix_const,
                correc_steps,
                md_steps,
                bool_bond_cons,
                bool_angl_cons,
                bool_atom_fix_cons,
                nat,
                numb_angl_cons,
                numb_bond_cons,
                numb_atom_fix_cons,
            )
finally:
    # Close the run-wide log even if an abinit run or an MD segment fails, so
    # the rows written so far are flushed to disk.
    pressure_volu_file.close()
