#!python
from datetime import datetime
from os import getcwd, mkdir, chdir
from os.path import isdir, isfile
from find_spg import find_crystal_system
from calc_elastic_constants import calc_elastic_constants
from deform_cell_ohess_strains import deform_cell_ohess_strains
from deform_cell_asess_strains import deform_cell_asess_strains
from deform_cell_ulics import deform_cell_ulics
from make_conv_cell import make_conventional_cell
from read_input import indict
from relax_atoms_pos import relax_atoms_pos
from calc_stress import calc_stress
from optimize_initial_str import optimize_initial_str
from extract_mean_values import mean_stress,mean_pressure,mean_temperature,mean_volume
from ase.io import vasp
from equilibrium_md import equil_md
from stability_criteria import criteria
from sound_velocity import sound_velocity
import warnings
warnings.filterwarnings('ignore')

# Accumulators filled in by the later stages: the stresses collected for each
# strain magnitude and the mean pressure of the reference structure.
stress_set_dict = {}
mean_press = 0

# Controlling parameters taken from read_input.indict; only the first element
# of each entry is used here.
method_stress_statistics = indict['method_stress_statistics'][0]
num_last_samples = int(indict['num_last_samples'][0])
run_mode = int(indict['run_mode'][0])
dimensional = indict['dimensional'][0]

# The static method always works with exactly one sample, whatever the
# input requested.
if method_stress_statistics == 'static':
    num_last_samples = 1

# All OPT/, NO_STRAIN_MD/ and STRESS/ paths below are built from this directory.
cwd = getcwd()
structure_file = indict['structure_file'][0]

print("\nReading controlling parameters from elastool.in...\n")

# Obtain the optimized starting structure.  A fresh fixed-pressure
# optimization is launched only in run mode 1 when OPT/CONTCAR does not exist
# yet; in every other case the existing OPT/CONTCAR is read back.
opt_contcar = cwd + '/OPT/CONTCAR'
if run_mode != 1 or isfile(opt_contcar):
    pos_optimized = vasp.read_vasp(opt_contcar)
else:
    pos_conv = make_conventional_cell(structure_file)
    pos_optimized = optimize_initial_str(pos_conv, cwd, 'fixed-pressure-opt')




# The strain type is hard-coded for now; reading it from
# indict['tubestrain_type'][0] is disabled.
tubestrain_type = "Longitudinal"

latt_system = find_crystal_system(pos_optimized, dimensional, tubestrain_type)

if dimensional == '1D':
    print("Experimental do not use, not yet debugged")
    # For 1D systems the detected lattice system decides which strain type is
    # reported; anything outside this table aborts the run.
    strain_type_by_lattice = {
        'Hydrostatic': 'Hydrostatic',
        'Biaxial': 'Biaxial',
        'any1D': 'Longitudinal',
        'Nanotube': 'Generalized',
        'Radial': 'Radial',
    }
    if latt_system not in strain_type_by_lattice:
        print('Choose a type of strain for the nanotube!!!')
        exit(1)
    tubestrain_type = strain_type_by_lattice[latt_system]

# Dynamic statistics: run an MD simulation of the unstrained structure and use
# the averages over its last `num_last_samples` samples as the reference state.
if method_stress_statistics == 'dynamic':
    equil_md(pos_optimized, cwd)
    tag = 'Total+kin.'
    md_outcar = '%s/NO_STRAIN_MD/OUTCAR' % cwd
    stress_0 = mean_stress(md_outcar, num_last_samples, tag)
    # The 0.1 factor presumably converts kB to GPa (the value is later written
    # with a GPa unit) -- TODO confirm against mean_pressure().
    mean_press = 0.1 * mean_pressure(md_outcar, num_last_samples)
    mean_temp = mean_temperature(md_outcar, num_last_samples)
    # Stored under its own name so the imported mean_volume() function is not
    # overwritten by its result.
    mean_vol = mean_volume('%s/NO_STRAIN_MD/vasprun.xml' % cwd, num_last_samples)
    stress_set_dict[0] = [stress_0]

    # Ratio between the MD-averaged volume and the volume of the MD input cell.
    pos0 = vasp.read_vasp('%s/NO_STRAIN_MD/POSCAR' % cwd)
    vol0 = pos0.get_volume()
    vol_scale = mean_vol / vol0

    # Rescale the optimized cell (atoms follow the cell) before re-optimizing
    # at fixed volume.
    # NOTE(review): the cell vectors are multiplied by a *volume* ratio, which
    # changes the volume by vol_scale**3; a cube root may be intended -- confirm.
    pos_opt = vasp.read_vasp('%s/OPT/CONTCAR' % cwd)
    cell_new = pos_opt.get_cell() * vol_scale
    pos_opt.set_cell(cell_new, scale_atoms=True)

    pos_optimized_v = optimize_initial_str(pos_opt, cwd, 'fixed-volume-opt')
    # Build the supercell from the three repeat counts given in the input.
    repeat = [int(indict['repeat_num'][axis]) for axis in range(3)]
    pos_optimized = pos_optimized_v.repeat(repeat)


# Static statistics (run modes 1 and 3): the zero-strain stress and the mean
# pressure are taken from the OUTCAR of the initial optimization.
if run_mode in (1, 3) and method_stress_statistics == 'static':
    tag = 'in kB'
    opt_outcar = '%s/OPT/OUTCAR' % cwd
    stress_0 = mean_stress(opt_outcar, num_last_samples, tag)
    mean_press = 0.1 * mean_pressure(opt_outcar, num_last_samples)
    stress_set_dict[0] = [stress_0]

# Strain magnitudes to apply, converted to floats.
delta_list = list(map(float, indict['strains_list']))

# Dynamic runs always use the 'ohess' strain set (with the full strain list);
# otherwise the strain set named in the input is used.
strains_matrix = (
    'ohess'
    if method_stress_statistics == 'dynamic'
    else indict['strains_matrix'][0]
)

# Start of the timed section; read again when the elapsed time is logged.
time_start = datetime.now()

# Strain-matrix name -> routine that generates the deformed cells.
# 'asesss' is the misspelling the original code tested for; it is kept as an
# alias so inputs that relied on it keep working, while the name matching the
# deform_cell_asess_strains routine ('asess') is now accepted too.
strain_deformers = {
    'ohess': deform_cell_ohess_strains,
    'asess': deform_cell_asess_strains,
    'asesss': deform_cell_asess_strains,
    'ulics': deform_cell_ulics,
}

if not isdir('STRESS'):
    mkdir('STRESS')

chdir('STRESS')

if run_mode != 2:
    print("Calculating stresses using the %s strain matrices..." % strains_matrix.upper())
else:
    print("Preparing necessary files using the %s strain matrices..." % strains_matrix.upper())

# Directory layout: STRESS/strain_<delta>/matrix_<index>/ -- one directory per
# strain magnitude and per deformed cell.
for up in delta_list:
    print("strain = %.3f" % up)
    # A zero strain has no deformed cells and is skipped here.
    if up == 0:
        continue

    # Fail with a clear message instead of hitting an undefined (or stale)
    # deformed_cell_list when the strain-matrix name is not recognised.
    if strains_matrix not in strain_deformers:
        raise ValueError(
            "Unknown strains_matrix %r; expected 'ohess', 'asess' or 'ulics'."
            % (strains_matrix,)
        )

    strain_dir = 'strain_%s' % str(up)
    if not isdir(strain_dir):
        mkdir(strain_dir)
    chdir(strain_dir)

    cell = pos_optimized.get_cell()
    deformed_cell_list = strain_deformers[strains_matrix](latt_system, cell, up)

    stress_set_dict[up] = []
    for num_cell, cell_strain in enumerate(deformed_cell_list):
        matrix_dir = 'matrix_%s' % str(num_cell)
        if not isdir(matrix_dir):
            mkdir(matrix_dir)
        chdir(matrix_dir)
        # NOTE: the separate relax_atoms_pos() step is disabled; calc_stress()
        # is handed the optimized structure together with the strained cell
        # and returns the updated stress dictionary.
        stress_set_dict = calc_stress(pos_optimized, cell_strain, method_stress_statistics, stress_set_dict, num_last_samples, up, cwd)
        chdir('..')
    chdir('..')
chdir('..')

def _write_section_header(out, title):
    """Write a dashed separator row followed by a left-aligned *title* row."""
    out.write("|%s|\n" % ('-' * 48))
    out.write("%s|\n" % ("|" + title).ljust(49))


def _write_row(out, label, text, width):
    """Write one ``|label = text|`` row with *text* padded to *width* columns."""
    out.write("|%s = %s|\n" % (label, text.ljust(width)))


# Post-processing needs the collected stresses, so it only runs in modes 1
# and 3.  The report is a text box whose interior is 48 columns wide.
if run_mode == 1 or run_mode == 3:
    print("")
    print("Fitting the first-order function to the collected \nstress-strain data according to Hooke's law...")
    elastic_constants_dict = {}
    elastic_constants_dict = calc_elastic_constants(pos_optimized, latt_system, elastic_constants_dict, stress_set_dict)
    elastic_constants_dict = sound_velocity(elastic_constants_dict, cwd, dimensional, latt_system)

    longdash = '-' * 48
    longequal = '=' * 48
    with open('elastool.out', 'w') as ec_file:
        ec_file.write("\n")
        ec_file.write("+%s+\n" % longequal)
        if dimensional == '1D':
            ec_file.write("|This is a %2s %s |\n" % (dimensional, ('lattice with ' + tubestrain_type + ' strain.').ljust(34)))
        else:
            ec_file.write("|This is a %2s %s |\n" % (dimensional, (latt_system + ' lattice.').ljust(34)))
        ec_file.write("|%s|\n" % longdash)
        ec_file.write("|Mean Pressure = %s|\n" % (str("%.2f" % mean_press) + ' GPa').ljust(32))

        if method_stress_statistics == 'dynamic':
            # Padded as a whole row to the 48-column interior; the previous
            # fixed 32-column value field made this row wider than the box.
            ec_file.write("|%s|\n" % ("Mean Temperature =  " + str(mean_temp) + ' K').ljust(48))

        # The anisotropy indices need all four Voigt/Reuss entries; they are
        # simply omitted when an entry is missing or unusable.
        print_anisotropy = False
        try:
            G_V = elastic_constants_dict['G_v']
            G_R = elastic_constants_dict['G_r']
            B_V = elastic_constants_dict['B_v']
            B_R = elastic_constants_dict['B_r']
            A_U = 5 * G_V / G_R + B_V / B_R - 6
            A_C = (G_V - G_R) / (G_V + G_R)
        except (KeyError, TypeError, ArithmeticError):
            pass
        else:
            print_anisotropy = True

        # Section headers for these groups are written once, before the first
        # row of the group.
        has_print_ec = False
        has_print_moduli = False
        has_print_sound = False

        # The first character of each key selects the kind of row to write.
        # The "Possion's" spelling is kept unchanged because other tools may
        # parse this file.
        for key, value in elastic_constants_dict.items():
            initial = key[0]
            if dimensional == '3D':
                if initial == 'c':
                    if not has_print_ec:
                        _write_section_header(ec_file, "Elastic constants:")
                        has_print_ec = True
                    _write_row(ec_file, key.capitalize(), "%.2f GPa" % value, 42)
                elif initial == 'B' or initial == 'G':
                    if not has_print_moduli:
                        _write_section_header(ec_file, "Elastic moduli:")
                        has_print_moduli = True
                    # Three-character keys get a 42-column value field, longer
                    # keys a 40-column one.
                    _write_row(ec_file, key.upper(), "%.2f GPa" % value, 42 if len(key) == 3 else 40)
                elif initial == 'E':
                    _write_row(ec_file, "Young's modulus (%s)" % key.capitalize(), "%.2f GPa" % value, 26)
                elif initial == 'v':
                    _write_row(ec_file, "Possion's ratio (%s)" % key.capitalize(), "%.4f" % value, 26)
                elif initial == 'V':
                    if not has_print_sound:
                        _write_section_header(ec_file, "Sound velocity:")
                        has_print_sound = True
                    _write_row(ec_file, key.upper(), "%.2f Km/s" % value, 42)
                elif initial == 'T':
                    _write_section_header(ec_file, "Debye temperature:")
                    _write_row(ec_file, key.upper(), "%.2f K" % value, 42)

            elif dimensional == '2D':
                if initial == 'c':
                    if not has_print_ec:
                        _write_section_header(ec_file, "Elastic constants:")
                        has_print_ec = True
                    _write_row(ec_file, key.capitalize(), "%.2f N/m" % value, 42)
                elif initial == 'Y':
                    _write_row(ec_file, "Young's modulus (%s)" % key.capitalize(), "%.2f N/m" % value, 23)
                elif initial == 'v':
                    _write_row(ec_file, "Possion's ratio (%s)" % key.capitalize(), "%.2f" % value, 26)
                elif initial == 'B':
                    _write_row(ec_file, "Bulk modulus (%s)" % key.capitalize(), "%.2f N/m" % value, 29)
                elif initial == 'G':
                    _write_row(ec_file, "Shear modulus (%s)" % key.capitalize(), "%.2f N/m" % value, 28)
                elif initial == 'V':
                    if not has_print_sound:
                        _write_section_header(ec_file, "Sound velocity:")
                        has_print_sound = True
                    _write_row(ec_file, key.upper(), "%.2f Km/s" % value, 42)
                elif initial == 'T':
                    _write_section_header(ec_file, "Debye temperature:")
                    _write_row(ec_file, key.upper(), "%.2f K" % value, 42)

            elif dimensional == '1D':
                if initial == 'c':
                    if not has_print_ec:
                        _write_section_header(ec_file, "Elastic constants:")
                        has_print_ec = True
                    _write_row(ec_file, key.capitalize(), "%.2f GPa" % value, 42)
                elif initial == 'Y':
                    _write_row(ec_file, "Young's modulus (%s)" % key.capitalize(), "%.2f GPa" % value, 26)
                elif initial == 'v':
                    _write_row(ec_file, "Possion's ratio (%s)" % key.capitalize(), "%.2f" % value, 26)
                elif initial == 'B':
                    _write_row(ec_file, "Bulk modulus (%s)" % key.capitalize(), "%.2f GPa" % value, 29)
                elif initial == 'G':
                    _write_row(ec_file, "Shear modulus (%s)" % key.capitalize(), "%.2f GPa" % value, 28)
                elif initial == 'C':
                    _write_row(ec_file, "Compliance (%s)" % key.capitalize(), "%.2e GPa^-1" % value, 31)
                elif initial == 'R':
                    _write_row(ec_file, "Resonance Frequency (%s)" % key.capitalize(), "%.2e Hz" % value, 22)
                elif initial == 'V':
                    if not has_print_sound:
                        _write_section_header(ec_file, "Sound velocity:")
                        has_print_sound = True
                    _write_row(ec_file, key.upper(), "%.2f Km/s" % value, 42)
                elif initial == 'T':
                    _write_section_header(ec_file, "Debye temperature:")
                    _write_row(ec_file, key.upper(), "%.2f K" % value, 42)

        if print_anisotropy:
            _write_section_header(ec_file, "Elastic anisotropy:")
            _write_row(ec_file, "A_U", "%.4f" % A_U, 42)
            _write_row(ec_file, "A_C", "%.4f" % A_C, 42)

        stable = criteria(elastic_constants_dict, latt_system)
        _write_section_header(ec_file, "Structure stability analysis...")
        if stable:
            ec_file.write("|This structure is mechanically STABLE.          |\n")
        else:
            # Padded to the 48-column interior; the previous literal was three
            # columns short, which broke the right border of the box.
            ec_file.write("|%s|\n" % "This structure is NOT mechanically stable.".ljust(48))

        ec_file.write("+%s+\n" % longequal)
        ec_file.write("\n")

# Record how long the stress stage took.  timedelta.seconds only holds the
# sub-day remainder, so total_seconds() is used to keep whole days for runs
# lasting longer than 24 hours.
time_now = datetime.now()
time_used = int((time_now - time_start).total_seconds())

with open('time_used.log', 'w') as time_record_file:
    time_record_file.write("The stress calculations used %d seconds.\n" % time_used)



# Closing console message: echo the report unless this run only prepared the
# input files (run mode 2).
if run_mode != 2:
    print("")
    print("The final results are as follows:")
    # Context manager so the report file is closed once it has been echoed.
    with open('elastool.out', 'r') as results_file:
        for line in results_file:
            print(line.strip('\n'))

    print("")
    print("Results are also saved in the elastool.out file.")
    print("")
    print("")
    print("Well done! GOOD LUCK!")
    print("")
else:
    print("")
    print("All necessary files are prepared in the STRESS directory.")
    print("Run VASP in each subdirectory and rerun elastool with run_mode = 3.")
    print("")
    print("GOOD LUCK!")
    print("")
