#!/home/jpreto/env/bin/python
import os
import sys
import argparse
import subprocess

from mdkit.utility import mol2

# Command-line interface: the PDB input, the reference mol2 and the output
# mol2 are all mandatory options.
parser = argparse.ArgumentParser(
    description="Convert pdb to .mol2 using .mol2 reference file")

parser.add_argument(
    '-i', dest='pdbfile', required=True,
    help="PDB input file (1 structure)")

parser.add_argument(
    '-r', dest='mol2file_ref', required=True, default=None,
    help="mol2 file used as reference")

parser.add_argument(
    '-o', dest='mol2file_out', required=True,
    help="Output mol2 file")

args = parser.parse_args()

# Both input files have to exist already; stop right away otherwise.
if not os.path.isfile(args.pdbfile):
    raise ValueError(".pdb file %s does not exist!"%args.pdbfile)
pdbfile = args.pdbfile

if not os.path.isfile(args.mol2file_ref):
    raise ValueError("ref .mol2 file %s does not exist!"%args.mol2file_ref)
mol2file_ref = args.mol2file_ref

# The output path is taken as given (no existence check).
mol2file_out = args.mol2file_out

def get_neighbors(graph, node, data=False):
    """Collect the non-hydrogen atoms around ``node``, one shell at a time.

    The graph is explored breadth-first starting from ``node``. Each step
    gathers the not-yet-visited neighbors of the previous step; atoms whose
    ``'type'`` attribute is ``'H'`` or ``'h'`` are traversed but left out of
    the result.

    Parameters
    ----------
    graph
        Object providing ``neighbors(n)`` and a ``node`` mapping such that
        ``graph.node[n]['type']`` is the type of atom ``n``.
    node
        Starting node (hashable); it is not part of the result.
    data : bool
        If true, each shell lists the ``'type'`` values of its atoms;
        otherwise it lists the node identifiers.

    Returns
    -------
    list of lists
        One inner list per shell that contains at least one non-hydrogen
        atom, ordered from the closest shell outwards. Shells made only of
        hydrogens are skipped, so a position in the result is not
        necessarily the distance from ``node``. The order inside a shell is
        unspecified.
    """
    neighbors = []
    # A set keeps the "already seen" test O(1) per candidate.
    visited = set([node])
    frontier = [node]

    while True:
        # Next shell: unseen neighbors of the current shell, de-duplicated.
        frontier = list(set([nbr for n in frontier
                             for nbr in graph.neighbors(n)
                             if nbr not in visited]))
        if not frontier:
            break
        visited.update(frontier)

        heavy = [nbr for nbr in frontier
                 if graph.node[nbr]['type'] not in ['H', 'h']]
        if heavy:
            if data:
                neighbors.append([graph.node[nbr]['type'] for nbr in heavy])
            else:
                neighbors.append(heavy)
    return neighbors

# Intermediate files are written next to the inputs and named after them.
base, ext = os.path.splitext(pdbfile)
base_ref, ext_ref = os.path.splitext(mol2file_ref)

# The temporary files below are derived from both base names. If the two
# inputs share a base name, the rough mol2 written by babel would overwrite
# the reference file, which would then be deleted. Refuse to run instead.
if os.path.abspath(base) == os.path.abspath(base_ref):
    raise ValueError("%s and %s share the same base name; intermediate files would overwrite the reference file!"
                     % (pdbfile, mol2file_ref))

mol2file_query_rough = base + '.mol2'
mol2file_query_uniq = base + '_uniq_noH.mol2'
mol2file_query_named = base + '_uniq.mol2'
mol2file_ref_uniq = base_ref + '_uniq.mol2'

# make rough mol2 file from pdb with babel
# An argument list (no shell) keeps file names containing spaces or shell
# metacharacters intact; babel's output is discarded as before.
with open(os.devnull, 'w') as devnull:
    subprocess.check_call(['babel', pdbfile, mol2file_query_rough], stdout=devnull, stderr=devnull)

# remove hydrogens and give unique names to heavy atoms
mol2.update_mol2file(mol2file_query_rough, mol2file_query_uniq, unique=True, remove=['H', 'h'])
os.remove(mol2file_query_rough)

# make the atoms unique for the original structure
mol2.update_mol2file(mol2file_ref, mol2file_ref_uniq, unique=True)

G_query = mol2.get_graph(mol2file_query_uniq)
G_ref = mol2.get_graph(mol2file_ref_uniq)

# Signature of every candidate reference atom: the sorted atom types of each
# neighbor shell. It does not depend on the query atom, so it is computed
# once. Hydrogens are skipped because the query file has had its hydrogens
# removed, so a query atom must never be mapped onto a reference hydrogen.
ref_candidates = []
for node_ref, attr_ref in G_ref.nodes(data=True):
    if G_ref.node[node_ref]['type'] in ['H', 'h']:
        continue
    shells_ref = [sorted(shell) for shell in get_neighbors(G_ref, node_ref, data=True)]
    ref_candidates.append((node_ref, shells_ref))

# Map every query atom to the first unassigned reference atom whose shells
# agree with all shells of the query atom.
query2ref = {}
assigned_atoms_ref = set()
for node, attr in G_query.nodes(data=True):
    shells_query = [sorted(shell) for shell in get_neighbors(G_query, node, data=True)]
    for node_ref, shells_ref in ref_candidates:
        if node_ref in assigned_atoms_ref:
            continue
        # The slice is shorter than shells_query when the reference atom has
        # fewer shells, which makes the comparison fail (no match) instead
        # of raising an IndexError.
        if shells_ref[:len(shells_query)] == shells_query:
            query2ref[node] = node_ref
            assigned_atoms_ref.add(node_ref)
            break
    else:
        atomtype = G_query.node[node]['type']
        raise ValueError("Heavy atoms involved in .pdb file are different to those in reference file! "
                         "(no match found for atom %s of type %s)" % (node, atomtype))

mol2f_q = mol2.Reader(mol2file_query_uniq)
mol2f_r = mol2.Reader(mol2file_ref_uniq)

# only the first structure of each file is used
struct_q = mol2f_q.next()
struct_r = mol2f_r.next()

mol2f_q.close()
mol2f_r.close()

# Copy the atom names (field 1 of each ATOM line) from the reference onto the
# matching query atoms; field 0 is the atom index used as graph node above.
ref_names = {}
for line_r in struct_r['ATOM']:
    ref_names[int(line_r[0])] = line_r[1]

for idx, line_q in enumerate(struct_q['ATOM']):
    atomidx_ref = query2ref[int(line_q[0])]
    if atomidx_ref in ref_names:
        struct_q['ATOM'][idx][1] = ref_names[atomidx_ref]

mol2.Writer().write(mol2file_query_uniq, struct_q)
mol2.update_mol2file(mol2file_query_uniq, mol2file_query_named, ADupdate=mol2file_ref_uniq, unique=True)
mol2.arrange_hydrogens(mol2file_query_named, mol2file_out) # arranging

# remove intermediate files
# NOTE(review): mol2file_query_uniq (*_uniq_noH.mol2) is left on disk, as in
# the original script -- confirm whether it should be removed as well.
os.remove(mol2file_query_named)
os.remove(mol2file_ref_uniq)
