#!/usr/bin/env python3

import sys
import argparse
import warnings
import networkx.algorithms.isomorphism as iso
from networkx import is_isomorphic
from pyRDTP.geomio import file_to_mol
from pyRDTP.molecule import Molecule
from pyRDTP.operations import graph

# Command-line interface. The options are declared as a table of
# (flags, add_argument keyword arguments) and registered in that order.
PARSER = argparse.ArgumentParser(
    description="""Compare the topology of two
                                 different molecules.""")

for _flags, _settings in (
        (("-i", "--initial"),
         {"type": str, "nargs": '+', "default": ['CONTCAR'],
          "help": "Location of the initial structure."}),
        (("-f", "--final"),
         {"type": str, "nargs": '+', "default": ['POSCAR'],
          "help": "Location of the final structure."}),
        (("-elem",),
         {"type": str, "nargs": '+', "default": None,
          "help": "Elements of the molecules that will be compared"}),
        (("-nelem",),
         {"type": str, "nargs": '+', "default": None,
          "help": """Elements that not compose the molecules that will
                            be compared"""}),
        (("-t", "--tolerance"),
         {"type": float, "default": 0.15,
          "help": """Tolerance (In A) that will be used during the bond
                         detection"""}),
        (("-u", "--unique"),
         {"type": str, "nargs": '+', "default": None,
          "help": """Elements that will be taking into account as unique
                         identities during the isomorphism test."""}),
        (("-v", "--verbose"),
         {"action": 'store_true', "default": False,
          "help": "Print warnings."}),
        (("-y", "--test"),
         {"action": 'store_true', "default": False,
          "help": """Print the value of the isomorphism test. If more
                    than one input files are specified this is always True"""}),
):
    PARSER.add_argument(*_flags, **_settings)

# The loop temporaries are not part of the script's namespace.
del _flags, _settings

# Parse the command line once; the rest of the script reads ARGS.
ARGS = PARSER.parse_args()

# Errors: the initial and final file lists are consumed pairwise, so they
# must have the same length.
if len(ARGS.initial) != len(ARGS.final):
    raise IndexError("Number of initial and final files must be the same")

# Warnings: only emitted when --verbose is given.
if ARGS.verbose and ARGS.elem is None and ARGS.nelem is None:
    warnings.warn("No elements specified, all atoms will be used")
if ARGS.verbose and ARGS.unique is None:
    warnings.warn("""No unique elements specified, all the elements will
                      be setted as unique""")
if ARGS.verbose and not (ARGS.elem is None or ARGS.nelem is None):
    warnings.warn("""Elements and negative elements detected, only elements
                      will be used""")


# Read both files of every pair up front. Each PAIRS entry holds the two
# loaded structures together with the file names they came from.
PAIRS = []
for initial, final in zip(ARGS.initial, ARGS.final):
    # Both files are read with the 'contcar' format, initial first.
    INITIAL, FINAL = (file_to_mol(location, 'contcar')
                      for location in (initial, final))
    PACK_TMP = [INITIAL, FINAL]
    PAIRS += [(PACK_TMP, initial, final)]


# Compare every (initial, final) pair: restrict both structures to the
# requested elements, rebuild their connectivity, turn each one into a graph
# and test the two graphs for isomorphism.

# The node-match rule does not depend on the pair, so it is built once.
# Nodes are compared on their 'elem' attribute; 'H' is the value used for
# nodes that lack the attribute.
CRITERIA = iso.categorical_node_match('elem', 'H')

for PACK, initial, final in PAIRS:
    MOLECULES = []
    GRAPHS = []
    for geom in PACK:
        # Select the proper atoms: -elem takes precedence over -nelem, and
        # with neither option the structure is used as it was read.
        if ARGS.elem is not None:
            SELECTION = geom.atom_element_filter(ARGS.elem, option='index',
                                                 invert=False)
        elif ARGS.nelem is not None:
            SELECTION = geom.atom_element_filter(ARGS.nelem, option='index',
                                                 invert=True)
        else:
            SELECTION = None

        # Generate the molecule that will be compared
        if SELECTION is None:
            MOL_EX = geom
        else:
            # Build a separate molecule from copies of the selected atoms
            # and of the cell of the source structure.
            TMP_ATOMS = [geom.atoms[atom_index].copy()
                         for atom_index in SELECTION]
            MOL_EX = Molecule('dummy')
            MOL_EX.atom_add_list(TMP_ATOMS)
            MOL_EX.cell_p_add(geom.cell_p.copy())
            MOL_EX.coords_convert_update('direct')
        MOL_EX.connectivity_search_voronoi(tolerance=ARGS.tolerance)
        MOLECULES.append(MOL_EX)
        GRAPHS.append(graph.generate(MOL_EX))

    # Get unique discrimination. The default comes from the initial
    # structure of the pair being processed (PACK[0]), not from whichever
    # file happened to be read last.
    if ARGS.unique is not None:
        UNIQUE = ARGS.unique
    else:
        UNIQUE = list(PACK[0].elements)
    # NOTE(review): UNIQUE is not passed to the isomorphism test below, so
    # -u/--unique currently has no effect on the result — confirm the
    # intended semantics before wiring it in.

    # Isomorphism test
    BOOL = is_isomorphic(GRAPHS[0], GRAPHS[1], node_match=CRITERIA)

    # With several pairs every result is printed next to its file names;
    # with a single pair the result is printed only when --test is given.
    if len(ARGS.initial) != 1:
        print(initial, final, BOOL)
    elif ARGS.test:
        print(BOOL)

# With a single pair the result doubles as the process exit status:
# 0 when the two graphs are isomorphic, 1 otherwise.
if len(ARGS.initial) == 1:
    sys.exit(0 if BOOL else 1)
