#!/home/travis/miniconda/bin/python

# Copyright (C) 2012 Atsushi Togo
# All rights reserved.
#
# This file is part of phonopy.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in
#   the documentation and/or other materials provided with the
#   distribution.
#
# * Neither the name of the phonopy project nor the names of its
#   contributors may be used to endorse or promote products derived
#   from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import sys
import warnings
import numpy as np
from phonopy.interface.vasp import get_born_OUTCAR, get_born_vasprunxml

def fracval(frac):
    """Convert a number or a simple fraction written as text to a float.

    Parameters
    ----------
    frac : str
        Either a plain number such as ``"0.5"`` or a fraction of the form
        ``"numerator/denominator"`` such as ``"1/2"``.

    Returns
    -------
    float
        The numeric value of ``frac``.

    Raises
    ------
    ValueError
        If a part cannot be parsed by ``float`` or if ``frac`` contains
        more than one ``/``.
    ZeroDivisionError
        If the denominator evaluates to zero.
    """
    numerator, slash, denominator = frac.partition('/')
    if not slash:
        return float(frac)
    # Reject input like "1/2/3" rather than silently ignoring everything
    # after the second slash.
    if '/' in denominator:
        raise ValueError("Invalid fraction: %r" % frac)
    return float(numerator) / float(denominator)

from optparse import OptionParser

# Command-line interface.  Options that take a value rely on optparse's
# defaults (action "store", type "string") unless stated otherwise.
parser = OptionParser()
option_defaults = {
    'num_atoms': None,
    'primitive_axis': None,
    'supercell_matrix': None,
    'symmetrize_tensors': True,
    'read_outcar': False,
    'symprec': 1e-5,
}
parser.set_defaults(**option_defaults)
parser.add_option(
    "--dim",
    dest="supercell_matrix",
    help="Same behavior as DIM tag")
parser.add_option(
    "--pa", "--primitive_axis",
    dest="primitive_axis",
    help="Same as PRIMITIVE_AXIS tags")
parser.add_option(
    "--nost", "--no_symmetrize_tensors",
    dest="symmetrize_tensors",
    action="store_false",
    help="Prevent from symmetrizing tensors")
parser.add_option(
    "--tolerance",
    dest="symprec",
    type="float",
    help="Symmetry tolerance to search")
parser.add_option(
    "--outcar",
    dest="read_outcar",
    action="store_true",
    help=("Read OUTCAR instead of vasprun.xml. "
          "POSCAR is necessary in this case."))
# ``options`` holds the parsed option values, ``args`` the remaining
# positional arguments.
options, args = parser.parse_args()

# The first positional argument names the file holding the tensors.  When
# it is omitted, fall back to "OUTCAR" if --outcar was given and to
# "vasprun.xml" otherwise.
if args:
    outcar_filename = args[0]
elif options.read_outcar:
    outcar_filename = "OUTCAR"
else:
    outcar_filename = "vasprun.xml"

# The optional second positional argument names the structure file.
poscar_filename = args[1] if len(args) > 1 else "POSCAR"

# Build the 3x3 primitive-axis matrix from --pa, which must supply nine
# whitespace-separated numbers or fractions; default to the identity.
if not options.primitive_axis:
    primitive_axis = np.eye(3)
else:
    axis_values = [fracval(token)
                   for token in options.primitive_axis.split()]
    if len(axis_values) != 9:
        print("Primitive axes are incorrectly set.")
        sys.exit(1)
    primitive_axis = np.array(axis_values).reshape(3, 3)

# Build the 3x3 integer supercell matrix from --dim.  Three integers give
# the diagonal, nine integers give the full matrix in row order, and the
# identity is used when the option is absent.
if not options.supercell_matrix:
    supercell_matrix = np.eye(3, dtype='intc')
else:
    dim_values = np.array(
        [int(token) for token in options.supercell_matrix.split()],
        dtype='intc')
    num_values = len(dim_values)
    if num_values == 3:
        supercell_matrix = np.diag(dim_values)
    elif num_values == 9:
        supercell_matrix = np.reshape(dim_values, (3, 3))
    else:
        print("Supercell matrix is incorrectly set.")
        sys.exit(1)

# Read the dielectric tensor and Born effective charges, then print them:
# a header line with 1-based atom indices, one line for epsilon and one
# line per atom, each holding the nine tensor components.
with warnings.catch_warnings():
    # "always" shows every warning without raising it.  To catch warnings
    # as errors, set warnings.simplefilter("error"); the UserWarning
    # handler below then reports the warning and aborts.
    warnings.simplefilter("always")

    try:
        if options.read_outcar:
            borns, epsilon, atom_indices = get_born_OUTCAR(
                poscar_filename=poscar_filename,
                outcar_filename=outcar_filename,
                primitive_matrix=primitive_axis,
                supercell_matrix=supercell_matrix,
                symmetrize_tensors=options.symmetrize_tensors,
                symprec=options.symprec)
        else:
            # Despite its name, outcar_filename holds the vasprun.xml
            # path on this branch.
            borns, epsilon, atom_indices = get_born_vasprunxml(
                filename=outcar_filename,
                primitive_matrix=primitive_axis,
                supercell_matrix=supercell_matrix,
                symmetrize_tensors=options.symmetrize_tensors,
                symprec=options.symprec)
    except UserWarning as warning:
        # Report the reason on stderr instead of exiting silently, so
        # stdout stays clean for the tensor output.
        sys.stderr.write("%s\n" % warning)
        sys.exit(1)
    else:
        text = "# epsilon and Z* of atoms "
        # atom_indices are 0-based; the header lists them 1-based.
        text += ' '.join(["%d" % n for n in atom_indices + 1])
        print(text)
        print(("%13.8f" * 9) % tuple(epsilon.flatten()))
        for z in borns:
            print(("%13.8f" * 9) % tuple(z.flatten()))
        sys.exit(0)

# Fallback exit status; both branches above already call sys.exit.
sys.exit(1)
