#!/usr/bin/env python

import argparse
import graphviz
import logging
import sys
import yaml
from pgspawn import Graph


# Command-line interface of the script: a repeatable verbosity flag plus
# optional input/output files that default to stdin/stdout.
parser = argparse.ArgumentParser(description="Convert pgspawn's YAML description into DOT file used by graphviz.")
parser.add_argument(
    "-v", "--verbose",
    # counted, so that `-vv` yields verbose_count == 2
    dest='verbose_count', action='count', default=0,
    help="increases log verbosity for each occurrence",
)
parser.add_argument(
    "-i", "--input",
    # argparse opens the file for reading; stdin is used when -i is omitted
    dest='input_file', type=argparse.FileType('r'), default=sys.stdin,
    help="input file",
)
parser.add_argument(
    "-o", "--output",
    # argparse opens the file for writing; stdout is used when -o is omitted
    dest='output_file', type=argparse.FileType('w'), default=sys.stdout,
    help="output file",
)

def add_fd_node(subgraph, node_id, fd):
    """Add the graphviz node for descriptor `fd` of process `node_id`.

    The node is created inside `subgraph`, named '<node_id>-<fd>' and
    labelled with the descriptor itself. Returns the node name so that the
    caller can attach edges to it.
    """
    name = '{}-{}'.format(node_id, fd)
    subgraph.node(name, label=str(fd))
    return name


def build_digraph(graph, gv):
    """Draw the process graph `graph` into the graphviz digraph `gv`.

    Every entry of ``graph.nodes`` becomes a cluster subgraph labelled with
    the entry's ``command``; every descriptor listed in its ``outputs``,
    ``inputs`` and ``sockets`` mappings becomes a node inside that cluster.
    A pipe is drawn as an edge between a descriptor and a node named after
    the pipe. The two ends of a socket are joined by a single bidirectional
    edge labelled with the socket id.

    Pipe names and socket ids go through ``str()`` before being handed to
    graphviz, so identifiers that YAML parsed as integers are treated the
    same way as textual ones.
    """
    # TODO draw global inputs, outputs and sockets
    # * draw with different style to visualize their different nature
    # * currently global, unused inputs, outputs and sockets will not be drawn

    # dict socket_id -> graphviz node of the socket end seen so far that is
    # still waiting for its other end
    socket_starts = {}

    # for each node make subgraph and add each fd as its node
    for node_id, node in enumerate(graph.nodes):
        # NOTE: the subgraph name needs to begin with 'cluster' (all lowercase)
        #       so that Graphviz recognizes it as a special cluster subgraph
        with gv.subgraph(name='cluster_{}'.format(node_id)) as subgv:
            subgv.attr(label=str(node.command))

            # Edges are added to the parent graph while the subgraph is still
            # open, exactly as before, so the emitted DOT keeps its ordering.
            for fd, pipe_name in node.outputs.items():
                gv.edge(add_fd_node(subgv, node_id, fd), str(pipe_name))

            for fd, pipe_name in node.inputs.items():
                gv.edge(str(pipe_name), add_fd_node(subgv, node_id, fd))

            for fd, socket_id in node.sockets.items():
                end = add_fd_node(subgv, node_id, fd)
                start = socket_starts.pop(socket_id, None)
                if start is None:
                    # first end of this socket: remember it until the
                    # matching end shows up
                    socket_starts[socket_id] = end
                else:
                    gv.edge(start, end, label=str(socket_id), dir='both')

    # TODO draw unused socket ends
    # * draw them RED! they have to look dangerous


def main():
    """Convert the YAML description from the input file into DOT source."""
    arguments = parser.parse_args()

    # Sets log level to WARN going more verbose for each new -v.
    # (30 = WARNING, 20 = INFO, 10 = DEBUG, never below 0)
    logging.basicConfig(
        format='%(process)d %(levelname)s: %(message)s',
        level=max(3 - arguments.verbose_count, 0) * 10,
    )

    # load graph description
    # NOTE: yaml.safe_load replaces the former bare yaml.load(stream) call:
    #       without an explicit Loader that call is rejected by PyYAML >= 6
    #       and, on older releases, may construct arbitrary Python objects
    #       from the input. The `with` closes the input even if parsing fails.
    with arguments.input_file as input_file:
        description = yaml.safe_load(input_file)
    graph = Graph.from_dict(description)

    # make visualized graph and fill it in
    gv = graphviz.Digraph(strict=True)
    build_digraph(graph, gv)

    # output result
    output_file = arguments.output_file
    print(gv.source, file=output_file)
    if output_file is not sys.stdout:
        # flush and release a file opened via -o; stdout is left open
        output_file.close()


if __name__ == '__main__':
    main()
