#!/Users/abel/.virtualenvs/test_lammps/bin/python

from contextlib import contextmanager
from phonolammps import Phonolammps
from lammps import lammps
from datetime import date
import numpy as np
import argparse
import os, sys
import phonopy


# Define function to silence output
@contextmanager
def silence_stdout(deactivate):
    """Temporarily rebind ``sys.stdout`` to the null device.

    :param deactivate: when true, nothing is redirected and the managed
        block prints normally; when false, ``sys.stdout`` is replaced for
        the duration of the ``with`` block.
    :yields: the open null-device stream while silencing, or ``None`` when
        ``deactivate`` is true.

    The previous ``sys.stdout`` is restored and the null-device handle is
    closed when the block exits, including when it exits with an exception.
    """
    if deactivate:
        yield None
        return

    # The ``with`` statement guarantees the devnull handle is closed; the
    # inner ``finally`` puts the original stream back before that happens.
    with open(os.devnull, "w") as new_target:
        old_target, sys.stdout = sys.stdout, new_target
        try:
            yield new_target
        finally:
            sys.stdout = old_target

@contextmanager
def show_stdout():
    """No-op context manager: yields ``None`` and leaves ``sys.stdout`` as is."""
    yield


# Command-line interface definition. Options are registered in the order in
# which they should appear in the --help output.
parser = argparse.ArgumentParser(description='phonoLAMMPS options')

# Positional input file; optional (nargs='?') so the parser accepts a bare
# '--version' call without it.
parser.add_argument(
    'lammps_input_file',
    nargs='?',
    type=str,
    metavar='lammps_file',
    help='lammps input file',
)

# Output file and cell description.
parser.add_argument(
    '-o',
    type=str,
    metavar='file',
    default='FORCE_CONSTANTS',
    help='force constants output file [default: FORCE_CONSTANTS]',
)
parser.add_argument(
    '--dim',
    type=int,
    nargs=3,
    metavar='N',
    default=[1, 1, 1],
    help='dimensions of the supercell',
)
parser.add_argument(
    '-p',
    action='store_true',
    help='plot phonon band structure',
)
parser.add_argument(
    '-c', '--cell',
    type=str,
    metavar='file',
    default=None,
    help='generates a POSCAR type file containing the unit cell',
)
parser.add_argument(
    '-pa', '--primitive_axis',
    type=float,
    nargs=9,
    metavar='F',
    # Nine values, reshaped to a 3x3 matrix by the code that consumes them.
    default=[1, 0, 0,
             0, 1, 0,
             0, 0, 1],
    help='primitive axis',
)

# Finite-temperature (molecular dynamics) settings.
parser.add_argument(
    '-t',
    type=float,
    metavar='F',
    default=None,
    help='temperature in K',
)
parser.add_argument(
    '--amplitude',
    type=float,
    metavar='F',
    default=0.01,
    help='displacement distance [default: 0.01 angstrom]',
)
parser.add_argument(
    '--total_time',
    type=float,
    metavar='F',
    default=20,
    help='total MD time in picoseconds [default: 20 ps]',
)
parser.add_argument(
    '--relaxation_time',
    type=float,
    metavar='F',
    default=5,
    help='MD relaxation time in picoseconds [default: 5 ps]',
)
parser.add_argument(
    '--timestep',
    type=float,
    metavar='F',
    default=0.001,
    help='MD time step in picoseconds [default: 0.001 ps]',
)

# Plain on/off switches, all False unless given on the command line.
for _switch, _switch_help in (
    ('--logshow', 'show LAMMPS & dynaphopy log on screen'),
    ('--no_symmetrize', 'deactivate force constant symmetrization'),
    ('--use_NAC', 'include non analytical corrections (Requires BORN file in work directory)'),
    ('--write_force_sets', 'write FORCE_SETS file'),
    ('--version', 'print version'),
):
    parser.add_argument(_switch, action='store_true', help=_switch_help)


# Parse the command line; every option used below is read from ``args``.
args = parser.parse_args()

# --version: report the versions of the tool chain and stop before any
# calculation is set up (no LAMMPS input file is required for this).
if args.version:
    from phonolammps import __version__ as ph_lammps_version

    # A LAMMPS instance is created with '-screen none' only to query its version.
    lmp = lammps(cmdargs=['-screen', 'none'])
    # The value is sliced below as an 8-character YYYYMMDD stamp.
    version_date = str(lmp.version())

    date_lammps = date(int(version_date[0:4]),
                       int(version_date[4:6]),
                       int(version_date[6:8]))

    print('PHONOLAMMPS: {}'.format(ph_lammps_version))
    print('LAMMPS: {}'.format(date_lammps.strftime("%d %B %Y")))
    print('PHONOPY: {}'.format(phonopy.__version__))
    print('PYTHON: {}'.format(sys.version.split('\n')[0]))

    # sys.exit() rather than the builtin exit(): the latter is injected by
    # the ``site`` module for interactive sessions and is not guaranteed to
    # exist (e.g. ``python -S`` or frozen executables).
    sys.exit()

# The positional argument is declared optional so that --version works on
# its own; for an actual calculation it is mandatory.
if args.lammps_input_file is None:
    parser.error("too few arguments")

# Set up the phonon calculation from the command-line options:
# --dim becomes a diagonal supercell matrix and the nine --primitive_axis
# values are arranged as a 3x3 matrix.
phlammps = Phonolammps(
    args.lammps_input_file,
    supercell_matrix=np.diag(args.dim),
    primitive_matrix=np.reshape(np.array(args.primitive_axis), (3, 3)),
    displacement_distance=args.amplitude,
    use_NAC=args.use_NAC,
    symmetrize=not args.no_symmetrize,
    show_log=args.logshow,
    show_progress=True,
)



# If temperature is set then do dynaphopy calculation
if args.t is not None:
    # dynaphopy is only needed for finite-temperature runs, so it is imported
    # here rather than at the top of the script.
    from dynaphopy import Quasiparticle
    from dynaphopy.atoms import Structure
    from dynaphopy.interface.lammps_link import generate_lammps_trajectory
    from dynaphopy.interface.phonopy_link import ForceConstants
    from phonopy.file_IO import write_FORCE_CONSTANTS, write_force_constants_to_hdf5

    # Build the dynaphopy structure from the unit cell and the harmonic
    # force constants obtained by phonoLAMMPS.
    unitcell = phlammps.get_unitcell()
    force_constants = phlammps.get_force_constants()
    structure = Structure(cell=unitcell.get_cell(),  # cell_matrix, lattice vectors in rows
                          scaled_positions=unitcell.get_scaled_positions(),
                          atomic_elements=unitcell.get_chemical_symbols(),
                          primitive_matrix=np.array(args.primitive_axis).reshape((3, 3)))

    structure.set_force_constants(ForceConstants(force_constants,
                                                 supercell=np.diag(args.dim)))

    structure.set_supercell_matrix(np.diag(args.dim))
    # generate LAMMPS MD trajectory (requires dynaphopy development)
    # The MD run must use the same LAMMPS input file the user supplied for
    # the force-constant calculation, not a fixed 'in.lammps' file name.
    trajectory = generate_lammps_trajectory(structure, args.lammps_input_file,
                                            total_time=args.total_time,
                                            time_step=args.timestep,
                                            relaxation_time=args.relaxation_time,
                                            silent=False,
                                            supercell=args.dim,
                                            memmap=False,
                                            velocity_only=True,
                                            temperature=args.t)

    # Calculate renormalized force constants with dynaphopy
    calculation = Quasiparticle(trajectory)
    # Python-level output of this step is hidden unless --logshow was given.
    with silence_stdout(args.logshow):
        # NOTE(review): algorithm index 2 is hard-coded; see the dynaphopy
        # documentation for the meaning of the index.
        calculation.select_power_spectra_algorithm(2)
        renormalized_force_constants = calculation.get_renormalized_force_constants().get_array()

    if args.p:
        calculation.plot_renormalized_phonon_dispersion_bands()

    # NOTE(review): --write_force_sets is not honoured in this branch; only
    # the renormalized force constants are written.
    print('writing renormalized force constants')
    write_FORCE_CONSTANTS(renormalized_force_constants, filename=args.o)

else:
    # No temperature given: plain harmonic calculation.
    if args.p:
        phlammps.plot_phonon_dispersion_bands()

    if args.write_force_sets:
        print('writing force sets')
        phlammps.write_force_sets(filename='FORCE_SETS')
    print('writing force constants')
    phlammps.write_force_constants(filename=args.o)


# Optionally export the unit cell to the file given with -c/--cell.
poscar_file = args.cell
if poscar_file is not None:
    print('writing POSCAR file: {}'.format(poscar_file))
    phlammps.write_unitcell_POSCAR(poscar_file)
