#!python
import sys
from datetime import datetime
from os import getcwd, mkdir, chdir
from os.path import isdir, isfile
from find_spg import find_crystal_system
from calc_elastic_constants import calc_elastic_constants
from deform_cell_ohess_strains import deform_cell_ohess_strains
from deform_cell_asess_strains import deform_cell_asess_strains
from deform_cell_ulics import deform_cell_ulics
from make_conv_cell import make_conventional_cell
from read_input import indict
from relax_atoms_pos import relax_atoms_pos
from calc_stress import calc_stress
from optimize_initial_str import optimize_initial_str
from extract_mean_values import mean_stress,mean_pressure,mean_temperature,mean_volume
from ase.io import vasp
from equilibrium_md import equil_md
from stability_criteria import criteria
from sound_velocity import sound_velocity
import pkg_resources
import warnings
# Suppress every Python warning for the remainder of the run.
warnings.simplefilter('ignore')

# Version string of the installed ElasTool distribution; printed in the results banner.
version = pkg_resources.get_distribution("ElasTool").version
#---------------------------------------------------------
#Print out citation
def print_boxed_message(ec_file=None):
    """Write the ElasTool citation notice framed in an 80-column ASCII box.

    Parameters
    ----------
    ec_file : file-like, optional
        Object with a ``write`` method (for example the open ``elastool.out``).
        When None (the default) the box is printed to stdout instead.
    """
    inner_width = 76
    header_footer = "+" + "-" * (inner_width + 2) + "+"

    # (text, underline?) pairs; an empty text yields a blank row inside the box.
    lines = [
        (" * CITATIONS *", True),
        ("If you have used Elastool in your research, PLEASE cite:", False),
        ("", False),  # Space after the above line
        ("ElasTool: An automated toolkit for elastic constants calculation, ", False),
        ("Z.-L. Liu, C.E. Ekuma, W.-Q. Li, J.-Q. Yang, and X.-J. Li, ", False),
        ("Computer Physics Communications 270, 108180, (2022)", False),
        ("", False),
        ("", False),  # Blank line for separation
        ("Efficient prediction of temperature-dependent elastic and", False),
        ("mechanical properties of 2D materials, S.M. Kastuar, C.E. Ekuma, Z-L. Liu,", False),
        ("Nature Scientific Report 12, 3776 (2022)", False),
    ]

    def output_line(line):
        if ec_file is not None:
            ec_file.write(line + "\n")
        else:
            print(line)

    output_line(header_footer)

    for text, underline in lines:
        output_line("| " + text.center(inner_width) + " |")
        if underline:
            # Underline only the visible text (not the whole padded row),
            # centred beneath it.
            underline_str = "-" * len(text.strip())
            output_line("| " + underline_str.center(inner_width) + " |")

    output_line(header_footer)



def write_line(ec_file, content, padding=1, border_char="|", filler_char=" ", width=None):
    """Write ``content`` as one fixed-width, bordered row to ``ec_file``.

    The row is ``border + filler*padding + content + filler*padding + border``.
    ``content`` is left-justified and truncated so the row is exactly ``width``
    characters long.

    Parameters
    ----------
    ec_file : file-like
        Object with a ``write`` method; a trailing newline is appended.
    content : str
        Text to place inside the borders.
    padding : int
        Number of ``filler_char`` characters between each border and the content.
    border_char, filler_char : str
        Single characters used for the borders and the padding.
    width : int, optional
        Total row width. Defaults to the module-level ``max_width``, which keeps
        the original behaviour for existing callers.
    """
    if width is None:
        width = max_width
    # Two border characters plus the padding on each side; never negative, so
    # a too-narrow width yields an empty content area rather than tail-slicing.
    content_width = max(width - (2 * padding) - 2, 0)
    content = content[:content_width]
    line = (border_char + filler_char * padding + content.ljust(content_width)
            + filler_char * padding + border_char)
    ec_file.write(line + "\n")





def print_banner(ec_file, version):
    """Write the heart-framed "SUMMARY OF RESULTS" banner to ``ec_file``.

    The banner contains a title, the ElasTool ``version`` and the time/date at
    which it is written. It is framed above and below by rows of '❤'. Its width
    comes from the module-level ``max_width``, which must already be defined.

    Parameters
    ----------
    ec_file : file-like
        Destination passed through to ``write_line``.
    version : str
        Version string shown in the banner.
    """
    # Take a single timestamp so the reported time and date always agree
    # (two separate now() calls could straddle midnight).
    now = datetime.now()
    conclusion_msg = "Calculations ended at %s on %s" % (
        now.strftime('%H:%M:%S'), now.strftime('%Y-%m-%d'))

    message_lines = [
        "SUMMARY OF RESULTS",
        "using",
        f"ElasTool Version: {version}",
        conclusion_msg,
    ]

    heart_row = '❤' * (max_width - 2)
    write_line(ec_file, heart_row, padding=0, border_char='❤', filler_char='❤')
    for text in message_lines:
        # Centre within max_width minus the two borders and one padding space per side.
        write_line(ec_file, text.center(max_width - 4), padding=1, border_char='❤')
    write_line(ec_file, heart_row, padding=0, border_char='❤', filler_char='❤')

#---------------------------------------------------------



#--------------------------------------------------
# Main code starts
#--------------------------------------------------
# Mean pressure (GPa) reported in elastool.out; set below from the OPT or MD run.
mean_press = 0
# Strain value -> list of collected stresses; key 0 holds the unstrained reference.
stress_set_dict = {}
method_stress_statistics = indict['method_stress_statistics'][0]
num_last_samples = int(indict['num_last_samples'][0])
run_mode = int(indict['run_mode'][0])
dimensional = indict['dimensional'][0]

# Static calculations only use the final sample.
if method_stress_statistics == 'static':
    num_last_samples = 1

cwd = getcwd()
structure_file = indict['structure_file'][0]
opt_contcar = '%s/OPT/CONTCAR' % cwd

# Optimize the initial structure at fixed pressure, unless OPT/CONTCAR already
# exists or we are not in run_mode 1; otherwise reuse the stored result.
if run_mode == 1 and not isfile(opt_contcar):
    pos_conv = make_conventional_cell(structure_file)
    pos_optimized = optimize_initial_str(pos_conv, cwd, 'fixed-pressure-opt')
else:
    pos_optimized = vasp.read_vasp(opt_contcar)

# Only nanotube strain is supported for 1D systems for now
# (placeholder for a future indict['tubestrain_type'] option).
tubestrain_type = "Nanotube"

latt_system = find_crystal_system(pos_optimized, dimensional, tubestrain_type)

if dimensional == '1D' and latt_system != 'Nanotube':
    print('Choose a type of strain for the 1D system!!!')
    sys.exit(1)

if method_stress_statistics == 'dynamic':
    equil_md(pos_optimized, cwd)
    md_dir = '%s/NO_STRAIN_MD' % cwd
    tag = 'Total+kin.'
    stress_0 = mean_stress('%s/OUTCAR' % md_dir, num_last_samples, tag)
    # The 0.1 factor converts to GPa (presumably mean_pressure returns kB, as in OUTCAR).
    mean_press = 0.1 * mean_pressure('%s/OUTCAR' % md_dir, num_last_samples)
    mean_temp = mean_temperature('%s/OUTCAR' % md_dir, num_last_samples)
    # Do not rebind the name `mean_volume`: it is the imported helper function.
    mean_vol = mean_volume('%s/vasprun.xml' % md_dir, num_last_samples)
    stress_set_dict[0] = [stress_0]

    vol0 = vasp.read_vasp('%s/POSCAR' % md_dir).get_volume()
    vol_scale = mean_vol / vol0

    # Rescale the optimized cell using the MD-averaged volume, then relax it at fixed volume.
    # NOTE(review): vol_scale is a volume ratio but multiplies the lattice
    # vectors, which scales the cell volume by vol_scale**3; an isotropic 3D
    # rescale would use vol_scale**(1/3). Confirm against mean_volume() before changing.
    pos_opt = vasp.read_vasp(opt_contcar)
    pos_opt.set_cell(pos_opt.get_cell() * vol_scale, scale_atoms=True)

    pos_optimized_v = optimize_initial_str(pos_opt, cwd, 'fixed-volume-opt')
    repeat = [int(indict['repeat_num'][i]) for i in range(3)]
    pos_optimized = pos_optimized_v.repeat(repeat)


# Static runs take the zero-strain reference stress and the mean pressure from
# the fixed-pressure optimization in OPT/.
if (run_mode == 1 or run_mode == 3) and method_stress_statistics == 'static':
    tag = 'in kB'
    stress_0 = mean_stress('%s/OPT/OUTCAR' % cwd, num_last_samples, tag)
    mean_press = 0.1 * mean_pressure('%s/OPT/OUTCAR' % cwd, num_last_samples)
    stress_set_dict[0] = [stress_0]

delta_list = [float(up) for up in indict['strains_list']]

# Dynamic (MD-based) runs always use the OHESS strain matrices.
if method_stress_statistics == 'dynamic':
    strains_matrix = 'ohess'
else:
    strains_matrix = indict['strains_matrix'][0]

# Start of the timed stress stage (the elapsed time goes to time_used.log).
time_start = datetime.now()

print("")
print("Reading controlling parameters from elastool.in...")
print("")

# Strain-matrix generators. 'asesss' is kept as an alias for inputs written
# against the previously misspelled check.
deform_cell_funcs = {
    'ohess': deform_cell_ohess_strains,
    'asess': deform_cell_asess_strains,
    'asesss': deform_cell_asess_strains,
    'ulics': deform_cell_ulics,
}
if strains_matrix not in deform_cell_funcs:
    # Fail before touching the filesystem instead of a NameError inside the loop.
    print("Unknown strains_matrix '%s'; choose one of: ohess, asess, ulics." % strains_matrix)
    sys.exit(1)
deform_cell = deform_cell_funcs[strains_matrix]

if not isdir('STRESS'):
    mkdir('STRESS')
chdir('STRESS')

if run_mode != 2:
    print("Calculating stresses using the %s strain matrices..." % strains_matrix.upper())
else:
    print("Preparing necessary files using the %s strain matrices..." % strains_matrix.upper())

# Layout: STRESS/strain_<delta>/matrix_<i>/ with one directory per deformed cell.
for up in delta_list:
    print("strain = %.3f" % up)
    if up == 0:
        continue  # zero strain is the unstrained reference; nothing to deform

    strain_dir = 'strain_%s' % str(up)
    if not isdir(strain_dir):
        mkdir(strain_dir)
    chdir(strain_dir)

    deformed_cell_list = deform_cell(latt_system, pos_optimized.get_cell(), up)

    stress_set_dict[up] = []
    for num_cell, cell_strain in enumerate(deformed_cell_list):
        matrix_dir = 'matrix_%s' % str(num_cell)
        if not isdir(matrix_dir):
            mkdir(matrix_dir)
        chdir(matrix_dir)
        # The relax_atoms_pos() pre-relaxation is disabled; calc_stress receives
        # the optimized structure and the strained cell directly.
        stress_set_dict = calc_stress(pos_optimized, cell_strain, method_stress_statistics,
                                      stress_set_dict, num_last_samples, up, cwd)
        chdir('..')
    chdir('..')
chdir('..')

# Width of the framed summary in elastool.out, sized to the longest expected line.
max_width = len("|This is a 3D hexagonal lattice with simple shear. |")

# Output templates per dimensionality, selected by the first character of each
# key in the elastic-constants dictionary.
_RESULT_TEMPLATES = {
    '3D': {
        'c': "%s = %s GPa",
        'B': "%s = %s GPa",
        'G': "%s = %s GPa",
        'E': "Young's modulus (%s) = %s GPa",
        'v': "Poisson's ratio (%s) = %s",
        'V': "Sound velocity (%s) = %s Km/s",
        'T': "Debye temperature (%s) = %s K",
    },
    '2D': {
        'c': "%s = %s N/m",
        'Y': "Young's modulus (%s) = %s N/m",
        'v': "Poisson's ratio (%s) = %s",
        'B': "In-plane stiffness (%s) = %s N/m",
        'G': "Shear modulus (%s) = %s N/m",
        'V': "Sound velocity (%s) = %s Km/s",
        'T': "Debye temperature (%s) = %s K",
    },
    '1D': {
        'c': "%s = %s GPa",
        'Y': "Young's modulus (%s) = %s GPa",
        'v': "Poisson's ratio (%s) = %s",
        'B': "Bulk modulus (%s) = %s GPa",
        'G': "Shear modulus (%s) = %s GPa",
        'C': "Compliance (%s) = %s GPa^-1",
        'R': "Resonance frequency (%s) = %s GHz",
        'V': "Sound velocity (%s) = %s Km/s",
        'T': "Debye temperature (%s) = %s K",
    },
}

if run_mode == 1 or run_mode == 3:
    print("")
    print("Fitting the first-order function to the collected \nstress-strain data according to Hooke's law...")

    elastic_constants_dict = calc_elastic_constants(
        pos_optimized, latt_system, {}, stress_set_dict)

    elastic_constants_dict = sound_velocity(
        elastic_constants_dict, cwd, dimensional, latt_system)

    # write_line trims content to max_width - 4 characters, so this is a full-width rule.
    longdash = '-' * (max_width - 4)
    content_mapping = _RESULT_TEMPLATES.get(dimensional, {})

    # Explicit UTF-8: the banner contains '❤', which a locale default encoding
    # (e.g. ASCII under LANG=C, or cp1252) cannot represent.
    with open('elastool.out', 'w', encoding='utf-8') as ec_file:
        print_banner(ec_file, version)
        write_line(ec_file, "=" * max_width, border_char="+", filler_char="-")

        lattice_name = tubestrain_type if dimensional == '1D' else latt_system
        description = "This is a %2s %s" % (dimensional, lattice_name + ' lattice.')
        write_line(ec_file, description)
        write_line(ec_file, "Mean Pressure = %.2f GPa" % mean_press)

        if method_stress_statistics == 'dynamic':
            write_line(ec_file, "Mean Temperature =  %s K" % str(mean_temp))

        # Universal (A_U) and Chung-Buessem-style (A_C) anisotropy from the
        # Voigt/Reuss moduli; skipped when those entries are missing or degenerate.
        print_anisotropy = False
        try:
            G_V = elastic_constants_dict['G_v']
            G_R = elastic_constants_dict['G_r']
            B_V = elastic_constants_dict['B_v']
            B_R = elastic_constants_dict['B_r']
            A_U = 5 * G_V / G_R + B_V / B_R - 6
            A_C = (G_V - G_R) / (G_V + G_R)
            print_anisotropy = True
        except (KeyError, TypeError, ZeroDivisionError):
            pass

        has_print_ec = False
        for key, value in elastic_constants_dict.items():
            # 3D output gets an "Elastic constants:" header before the first c-key.
            if dimensional == '3D' and key[0] == 'c' and not has_print_ec:
                write_line(ec_file, longdash)
                write_line(ec_file, "Elastic constants:")
                has_print_ec = True

            content = content_mapping.get(key[0])
            if content:
                write_line(ec_file, content % (key.capitalize(), "%.2f" % value))

        if print_anisotropy:
            write_line(ec_file, longdash)
            write_line(ec_file, "Elastic anisotropy:")
            write_line(ec_file, "A_U = %.4f" % A_U)
            write_line(ec_file, "A_C = %.4f" % A_C)

        stable = criteria(elastic_constants_dict, latt_system)
        write_line(ec_file, longdash)
        write_line(ec_file, "Structure stability analysis...")

        if stable:
            write_line(ec_file, "This structure is mechanically STABLE.")
        else:
            write_line(ec_file, "This structure is NOT mechanically stable.")

        write_line(ec_file, "")
        write_line(ec_file, "=" * max_width, border_char="+", filler_char="-")

        print_boxed_message(ec_file)
        ec_file.write("\n")

# Elapsed wall-clock time of the stress stage (started at time_start).
time_now = datetime.now()
# total_seconds() includes whole days; timedelta.seconds alone wraps after 24 h,
# which long MD-based stress runs can exceed.
time_used = int((time_now - time_start).total_seconds())

with open('time_used.log', 'w') as time_record_file:
    time_record_file.write("The stress calculations used %d seconds.\n" % time_used)

if run_mode != 2:
    # Echo the saved summary; read as UTF-8 because the banner contains '❤'.
    with open('elastool.out', 'r', encoding='utf-8') as result_file:
        for line in result_file:
            print(line.rstrip('\n'))
    print("")
    print("Results are also saved in the elastool.out file.")
    print("")
    print("")
    print("Well done! GOOD LUCK!")
    print("")
else:
    print("")
    print("All necessary files are prepared in the STRESS directory.")
    print("Run VASP in each subdirectory and rerun elastool with run_mode = 3.")
    print("")
    print_boxed_message()
    print("Well done! GOOD LUCK!")
    print("")



