#!/usr/bin/env python

import boto.ec2
import boto.ec2.elb
import paramiko
import os
import re
import socket
import sys
from optparse import OptionParser
import logging
import subprocess
from multiprocessing.pool import ThreadPool

# Module-wide logger used by the scanners and by Cloudily; its level is
# raised to DEBUG in Cloudily.configure().
logger = logging.getLogger('cloudily')

def abort(message):
    """Stop the program by raising SystemExit carrying ``message``."""
    sys.exit(message)

class Node(object):
    """Base class for a graph node.

    A node is identified by ``id``: two nodes with the same id compare
    equal and hash alike, whatever their name or aliases.

    Attributes:
        id: unique identifier (a string).
        name: human readable label.
        aliases: lookup keys for this node; always starts with
            ``'id:' + id`` followed by any extra aliases supplied.
    """
    def __init__(self, id, name, aliases=None):
        self.id = id
        self.name = name
        # Copy the caller's aliases so no list is ever shared between
        # nodes (the previous ``aliases=[]`` default was a shared object).
        self.aliases = ['id:' + id] + list(aliases or [])

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, a):
        # Comparing with a non-node (e.g. a string dict key that happens to
        # share our hash) must answer "not equal", not raise AttributeError.
        if not isinstance(a, Node):
            return NotImplemented
        return self.id == a.id

    def __ne__(self, a):
        # Python 2 does not derive != from ==, so keep the two consistent.
        result = self.__eq__(a)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return self.id

class Elb(Node):
    """Graph node standing for an ELB; behaves exactly like a plain Node."""

class Host(Node):
    """A physical host node that can be reached over SSH.

    The alias list is stored exactly as given, so callers must pass every
    alias under which the host should be found.

    Attributes:
        id: unique identifier of the host.
        hostname: name used to look the host up in ``~/.ssh/config``.
        aliases: lookup keys for this host.
        name: human readable label, also used by ``str()``.
        ssh: the connected paramiko client, or None until first use.
    """
    def __init__(self, id, hostname, aliases, name):
        self.id = id
        self.hostname = hostname
        self.aliases = aliases
        self.name = name
        self.ssh = None

    def __str__(self):
        return self.name

    def connect(self):
        """Open an SSH connection to the host and store it in ``self.ssh``.

        Username and real hostname come from ``~/.ssh/config`` when that
        file exists; otherwise the current user's login name is used.
        Any exception raised by paramiko propagates, and ``self.ssh`` is
        left untouched in that case.
        """
        # Imported locally: only this method needs it.
        import getpass

        # use the ssh config for username, hostname, etc.
        config = paramiko.SSHConfig()
        configfile = os.path.expanduser('~/.ssh/config')
        if os.path.exists(configfile):
            with open(configfile) as fin:
                config.parse(fin)
        o = config.lookup(self.hostname)

        # Only fall back to the local login name when the config gives no
        # user. getpass.getuser() also works without a controlling
        # terminal (cron, nohup), where os.getlogin() raises OSError.
        username = o.get('user')
        if not username:
            username = getpass.getuser()

        client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification; acceptable for discovery, unsafe on hostile networks.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        client.connect(o['hostname'], username=username, timeout=10)

        # Publish the client only once it is connected, so a failed attempt
        # is retried on the next access instead of caching a dead client.
        self.ssh = client

    @property
    def connection(self):
        """The SSH client for this host, connecting on first access."""
        if not self.ssh:
            self.connect()
        return self.ssh

class HostScanner(object):
    """Common base for scanners that run a command on a single host."""

    @property
    def types(self):
        """Node classes this kind of scanner is applied to."""
        return (Host,)

    def get_output(self, host, cmd):
        """Run ``cmd`` on ``host`` over its SSH connection; return its stdout."""
        logger.debug('Executing: %s on %s', cmd, host)
        streams = host.connection.exec_command(cmd)
        _remote_in, remote_out, _remote_err = streams
        return remote_out.read()

class ArpScanner(HostScanner):
    """Derive host-to-host links from the entries of a host's arp cache."""

    # the address is whatever 'arp -an' prints between parentheses
    re_arp_ip = re.compile(r'\(([^)]+)\)')

    def list(self, host):
        """Yield ``(host, 'ip:<addr>', True, None)`` per arp cache entry."""
        arp_table = self.get_output(host, 'arp -an')
        hits = (self.re_arp_ip.search(entry) for entry in arp_table.split('\n'))
        for hit in hits:
            if hit is None:
                continue
            yield host, 'ip:' + hit.group(1), True, None

class LoginScanner(HostScanner):
    """Derive user-to-host links from the login history of a host."""

    def __init__(self, mode):
        # stored as given; nothing in this class reads it
        self.mode = mode

    def list(self, host):
        """Yield ``(host, 'user:<name>', False, None)`` per line of 'last -w'.

        Empty lines, the 'wtmp begins' trailer and 'reboot' entries are
        ignored.
        """
        ignored_prefixes = ('wtmp begins', 'reboot')
        history = self.get_output(host, 'last -w')
        for entry in history.split('\n'):
            if entry and not entry.startswith(ignored_prefixes):
                # the user name is everything before the first space
                yield host, 'user:' + entry.partition(' ')[0], False, None

class NetstatScanner(HostScanner):
    """Get host-host connections by open connections.

    Runs ``netstat -nut`` on a host and yields one edge tuple per
    connection whose remote or local port maps to a known service name.
    """

    # some additional services along with /etc/services
    known_services = {
        '9200/tcp': 'elasticsearch',
        '9300/tcp': 'elasticsearch node',
        '9500/tcp': 'elasticsearch thrift',
        '6379/tcp': 'redis',
        '26379/tcp': 'redis sentinel',
        '8649/tcp': 'ganglia gmetad',
        '8649/udp': 'ganglia gmetad',
        '27017/tcp': 'mongo',
        '27018/tcp': 'mongo slave',
    }

    re_ws = re.compile(r'\s+')

    # local addresses whose connections never leave the host
    loopback_addresses = frozenset(['127.0.0.1', '::1'])

    # prefix netstat shows for IPv4 peers seen through an IPv6 socket
    ipv4_mapped_prefix = '::ffff:'

    def __init__(self, ports):
        """``ports`` is a comma separated port list, or None/'' for all."""
        if ports:
            self.ports = set(ports.split(','))
        else:
            self.ports = None

    def getservbyport(self, port, proto):
        """Return the service name for ``port``/``proto``, or None.

        The class-level ``known_services`` table is consulted first, then
        the system services database; system hits are cached in the table.
        """
        key = '%s/%s' % (port, proto)
        # lookup in known_services
        if key in self.known_services:
            return self.known_services[key]
        try:
            # lookup in /etc/services, restricted to the right protocol
            v = socket.getservbyport(int(port), proto)
        except (socket.error, ValueError, OverflowError, TypeError):
            # unknown port/protocol, or a port that is not a valid number
            return None
        self.known_services[key] = v
        return v

    def _split_endpoint(self, endpoint):
        """Split ``address:port`` into its two parts.

        The split is made on the LAST colon so IPv6 addresses stay intact,
        and IPv4-mapped addresses are reduced to their plain IPv4 form.
        Raises ValueError when there is no colon at all.
        """
        address, port = endpoint.rsplit(':', 1)
        if address.startswith(self.ipv4_mapped_prefix):
            address = address[len(self.ipv4_mapped_prefix):]
        return address, port

    def list(self, host):
        """Yield ``(host, 'ip:<remote>', forward, label)`` per connection.

        ``forward`` is True when the remote port names the service (this
        host is the client) and False when the local port does.
        Connections on a loopback address, outside the configured port
        list, or with unparsable columns are skipped; connections where
        neither port is known are logged as warnings.
        """
        output = self.get_output(host, 'netstat -nut')
        for line in output.split('\n'):
            if not line.startswith(('tcp', 'udp')):
                continue
            parts = self.re_ws.split(line)
            try:
                # the state column is not needed and may be absent
                proto, _recvq, _sendq, local, remote = parts[:5]
                lip, lport = self._split_endpoint(local)
                rip, rport = self._split_endpoint(remote)
            except ValueError:
                continue
            # tcp6/udp6 share the service table of tcp/udp
            proto = proto.rstrip('6')
            if lip in self.loopback_addresses:
                continue
            if self.ports and not (rport in self.ports or lport in self.ports):
                continue
            label = self.getservbyport(rport, proto)
            if label:
                yield host, 'ip:' + rip, True, label
                continue
            label = self.getservbyport(lport, proto)
            if label:
                yield host, 'ip:' + rip, False, label
                continue
            # unknown protocol on either end
            logger.warning('Unknown source/dest port on connection: %s (%r)', line, parts)

class NodeScanner(object):
    """Common base for scanners that discover nodes on their own."""

    def __str__(self):
        # scanners are identified by the name of their concrete class
        return type(self).__name__

    @property
    def types(self):
        """Always None: a node scanner is not applied to existing nodes."""
        return None

class EC2NodeScanner(NodeScanner):
    """Scanner for EC2 resources.

    Supports:
    - instances
    - ELBs

    The region is read from the ``EC2_REGION`` environment variable and
    defaults to ``us-east-1``.
    """
    def __init__(self, key, secret, options):
        """Store credentials and the resource kinds to scan.

        ``key``/``secret`` may be None, in which case boto looks the
        credentials up itself. ``options`` is a collection containing
        'instances' and/or 'elb'.
        """
        self.key = key
        self.secret = secret
        self.options = options

    def _connect(self, connect_to_region):
        """Call a boto ``connect_to_region`` function for our region.

        Raises ValueError when boto does not know the region (it returns
        None in that case).
        """
        region = os.getenv('EC2_REGION', 'us-east-1')
        conn = connect_to_region(region,
                                 aws_access_key_id=self.key,
                                 aws_secret_access_key=self.secret)
        if conn is None:
            raise ValueError('Unknown EC2 region: %s' % region)
        return conn

    def get_instances(self, filters=None):
        """Return an iterator over the instances matching ``filters``.

        Only running instances are returned unless ``filters`` overrides
        'instance-state-name'.
        """
        conn = self.get_connection()
        effective = {'instance-state-name': 'running'}
        if filters:
            effective.update(filters)
        return (inst for r in conn.get_all_instances(filters=effective) for inst in r.instances)

    def get_elbs(self):
        """Return all load balancers of the region."""
        conn = self._connect(boto.ec2.elb.connect_to_region)
        return conn.get_all_load_balancers()

    def get_connection(self):
        """Return an EC2 connection for the region."""
        return self._connect(boto.ec2.connect_to_region)

    def _host_from_instance(self, inst):
        """Build a Host from a boto instance, tolerating missing addresses."""
        # instances without a public address have no public DNS name or IP
        hostname = inst.public_dns_name or inst.private_ip_address or inst.id
        name = inst.tags.get('Name') or hostname
        aliases = ['id:' + inst.id]
        if inst.public_dns_name:
            aliases.append('host:' + inst.public_dns_name)
        for ip in (inst.ip_address, inst.private_ip_address):
            if ip:
                aliases.append('ip:' + ip)
        return Host(inst.id, hostname, aliases, name)

    def _elb_aliases(self, dns_name):
        """Return host/ip aliases for an ELB DNS name.

        Falls back to the DNS name alone when it cannot be resolved.
        """
        try:
            hostname, aliaslist, ipaddrlist = socket.gethostbyname_ex(dns_name)
        except socket.error as ex:
            logger.warning('Could not resolve ELB %s: %s', dns_name, ex)
            return ['host:' + dns_name]
        return ['host:' + hostname] + ['host:' + h for h in aliaslist] + ['ip:' + ip for ip in ipaddrlist]

    def list(self):
        """Yield Host and Elb nodes plus ``(elb, 'id:<instance>', True, None)``
        edge tuples, depending on ``self.options``."""
        if 'instances' in self.options:
            for inst in self.get_instances():
                yield self._host_from_instance(inst)

        if 'elb' in self.options:
            for lb in self.get_elbs():
                elb = Elb('elb:' + lb.name, lb.name + ' ELB', self._elb_aliases(lb.dns_name))
                yield elb
                for inst in lb.instances:
                    yield elb, 'id:' + inst.id, True, None

class FileNodeScanner(NodeScanner):
    """Scanner using static host list.

    The file holds one hostname per line; blank lines and lines starting
    with '#' are ignored.
    """
    def __init__(self, filename):
        self.filename = filename

    def list(self):
        """Yield one Host per hostname listed in the file.

        Each host gets a 'host:' alias and, when the name resolves, an
        'ip:' alias. A name that does not resolve is logged as a warning
        and still yielded, so one bad entry does not abort the scan.
        """
        with open(self.filename) as fin:
            for line in fin:
                hostname = line.strip()
                if not hostname or hostname.startswith('#'):
                    continue
                aliases = ['host:' + hostname]
                try:
                    aliases.append('ip:' + socket.gethostbyname(hostname))
                except socket.error as ex:
                    logger.warning('Could not resolve %s: %s', hostname, ex)
                yield Host(hostname, hostname, aliases, hostname)

class Edge(object):
    """Represents a directed, optionally labelled edge in the graph.

    Two edges are equal when source, target and label are all equal, so
    edges can be de-duplicated by storing them in a set.
    """
    def __init__(self, frm, to, label):
        self.frm = frm
        self.to = to
        self.label = label

    def __hash__(self):
        # Hash the tuple rather than XOR-ing the parts: XOR is symmetric,
        # so a->b and b->a (and every self-loop) would always collide.
        return hash((self.frm, self.to, self.label))

    def __eq__(self, a):
        # Anything that is not an edge is simply "not equal".
        if not isinstance(a, Edge):
            return NotImplemented
        return self.frm == a.frm and self.to == a.to and self.label == a.label

    def __ne__(self, a):
        # Python 2 does not derive != from ==, so keep the two consistent.
        result = self.__eq__(a)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return '%s -> %s' % (self.frm, self.to)

def unique(seq):
    """Return the items of ``seq`` as a list without duplicates.

    The first occurrence of every item is kept and the original order is
    preserved; items must be hashable.
    """
    already_seen = set()
    kept = []
    for item in seq:
        if item in already_seen:
            continue
        already_seen.add(item)
        kept.append(item)
    return kept

class Cloudily(object):
    """Main class: discover nodes and edges, then write a Graphviz graph.

    ``run()`` drives the whole process: parse the command line, build the
    scanners, scan, save the dot file and optionally render images.
    """

    # Graphviz shape per node class; any other node is drawn as an ellipse.
    shapes = {
        Host: 'box',
        Elb: 'trapezium',
    }

    def run(self):
        """Execute the complete command-line workflow."""
        self.parse_args()
        self.configure()
        self.scan()
        self.save(self.opts.out)
        if self.opts.png:
            self.render(self.opts.out, self.opts.png, self.opts.layout)
        if self.opts.preview:
            self.preview(self.opts.out)

    def parse_args(self):
        """Parse ``sys.argv`` into ``self.opts`` / ``self.args``.

        Defaults: netstat scanning when no scanner is selected, and EC2
        instances plus ELBs when no node source is selected.
        """
        parser = OptionParser()
        parser.add_option('--unknown', action='store_true', help='Include unknown nodes')
        parser.add_option('--out', default='output.dot', help='Dot output filename')
        parser.add_option('--hosts', help='A static host list: file.txt')
        parser.add_option('--ec2', help='Scan EC2 for instances, elb,etc.')
        parser.add_option('--ec2creds', default='env', help='Use EC2 for host list')
        parser.add_option('--arp', action='store_true', help='Scan arp cache')
        parser.add_option('--logins', help='Scan user logins. Value currently unused.')
        parser.add_option('--conns', action='store_true', help='Scan open connections (netstat)')
        parser.add_option('--connsports', help='Port list to include')
        parser.add_option('--labels', action='store_true', help='Add edge labels (if available)')
        parser.add_option('--preview', action='store_true', help='Generate a preview montage of graphviz layouts')
        parser.add_option('--png', help='Generate a png image using graphviz')
        parser.add_option('--layout', default='dot', help='Select a layout to use for the png image')
        self.opts, self.args = parser.parse_args()
        if not self.opts.arp and not self.opts.logins and not self.opts.conns:
            self.opts.conns = True
        if not self.opts.hosts and not self.opts.ec2:
            self.opts.ec2 = 'instances,elb'

    def configure(self):
        """Build ``self.scanners`` from the options and set up logging.

        Node scanners are added first so that host scanners, which run on
        already discovered nodes, find something to work on.
        """
        # run() has already parsed the options; parse here only when
        # configure() is used on its own.
        if not hasattr(self, 'opts'):
            self.parse_args()
        self.scanners = s = []

        if self.opts.hosts:
            s.append(FileNodeScanner(self.opts.hosts))
        elif self.opts.ec2:
            # With any other --ec2creds value no explicit credentials are
            # passed on (previously this path failed on unbound variables).
            key = secret = None
            if self.opts.ec2creds == 'env':
                key = os.getenv('AWS_ACCESS_KEY_ID')
                if not key:
                    abort('AWS_ACCESS_KEY_ID not set')
                secret = os.getenv('AWS_SECRET_ACCESS_KEY')
                if not secret:
                    abort('AWS_SECRET_ACCESS_KEY not set')
            options = self.opts.ec2.split(',')
            s.append(EC2NodeScanner(key, secret, options))

        if self.opts.arp:
            s.append(ArpScanner())
        if self.opts.logins:
            s.append(LoginScanner(self.opts.logins))
        if self.opts.conns:
            s.append(NetstatScanner(self.opts.connsports))

        logging.basicConfig(level=logging.WARN, format='%(levelname)-8s %(message)s')
        logger.setLevel(logging.DEBUG)

    def _scan_hosts(self, pool, scanner, hosts):
        """Run a host scanner over ``hosts`` in parallel; return all results.

        A host whose scan raises is logged and contributes no results.
        """
        def scan_node(node):
            logger.info('Scanning node %s', node)
            try:
                found = list(scanner.list(node))
            except Exception:
                logger.exception('Exception scanning node: %s', node)
                return []
            logger.debug('Finished node %s', node)
            return found
        return [r for rs in pool.map(scan_node, hosts) for r in rs]

    def scan(self):
        """Run every scanner and collect ``self.nodes`` / ``self.edges``.

        Scanners yield either Node objects or ``(node, alias, forward,
        label)`` tuples. A tuple becomes an edge between ``node`` and the
        node registered under ``alias``; unknown aliases are dropped unless
        ``--unknown`` is set, in which case a plain Node is created for them.
        """
        pool = ThreadPool(8)
        nodes = []
        edges = set()
        lookup = {}  # alias -> node
        try:
            for scanner in self.scanners:
                if scanner.types is None:
                    logger.info('Scanning using %s', scanner)
                    results = scanner.list()
                else:
                    matches = [node for node in nodes if isinstance(node, scanner.types)]
                    results = self._scan_hosts(pool, scanner, matches)

                for result in results:
                    if isinstance(result, Node):
                        logger.debug('Discovered node: %s', result)
                        nodes.append(result)
                        for a in result.aliases:
                            lookup[a] = result
                        continue

                    node, alias, forward, label = result
                    target = lookup.get(alias)
                    if target is None and self.opts.unknown:
                        target = Node(alias, alias)
                        nodes.append(target)
                        # Register under the alias so the same unknown peer
                        # is reused instead of being created again.
                        lookup[alias] = target
                    if target is None:
                        continue

                    if not self.opts.labels:
                        label = None
                    if forward:
                        edge = Edge(node, target, label)
                    else:
                        edge = Edge(target, node, label)

                    if edge not in edges:
                        logger.debug('Discovered edge: %s', edge)
                        edges.add(edge)
        finally:
            # release the worker threads even when a scanner fails
            pool.close()
            pool.join()

        self.nodes = nodes
        self.edges = edges
        logger.info('Done')

    def save(self, filename):
        """Write the graph in dot format to ``filename`` ('-' is stdout)."""
        if filename == '-':
            self.to_dot(sys.stdout)
        else:
            with open(filename, 'w') as fout:
                self.to_dot(fout)

    def render(self, dot, png, prog):
        """Render dot file ``dot`` to image ``png`` with graphviz ``prog``.

        Failures (missing program, non-zero exit status, unwritable file)
        are logged as errors and not raised.
        """
        try:
            ps = subprocess.Popen([prog, '-Tpng', dot], stdout=subprocess.PIPE)
            output = ps.communicate()[0]
            if ps.returncode != 0:
                raise RuntimeError('%s exited with status %s' % (prog, ps.returncode))
            # PNG data is binary; text mode would corrupt it on some platforms
            with open(png, 'wb') as fout:
                fout.write(output)
            logger.info("Generated: %s (layout: %s)", png, prog)
        except Exception as ex:
            logger.error("Problem running graphviz '%s' command.\n\n"
                "Please ensure graphviz is installed:\n"
                "  %s" % (prog, str(ex)))

    def preview(self, filename):
        """Render ``filename`` with every layout and tile them in montage.png.

        Needs imagemagick's ``montage``; failures are logged, not raised.
        """
        files = []
        for prog in ('dot', 'neato', 'circo', 'fdp', 'sfdp', 'twopi'):
            self.render(filename, prog+'.png', prog)
            files.append(prog+'.png')

        montage = 'montage.png'
        # No shell is involved, so the label format must not carry shell
        # quotes: they would end up literally in every label.
        args = ['montage', '-label', '%f']
        args.extend(files)
        args.extend(['-geometry', '800x800', montage])
        try:
            status = subprocess.call(args)
            if status != 0:
                raise RuntimeError('montage exited with status %s' % status)
            logger.info("Generated: %s (from: %s)", montage, ', '.join(files))
        except Exception as ex:
            logger.error("Problem running imagemagick 'montage' command.\n\n"
                "Please ensure imagemagick is installed:\n"
                "  %s" % str(ex))

    @staticmethod
    def _dot_quote(value):
        """Return ``value`` as text safe to put inside a double-quoted dot string."""
        return ('%s' % (value,)).replace('"', '\\"')

    def to_dot(self, fout):
        """Write the scanned graph to file object ``fout`` as a dot digraph."""
        quote = self._dot_quote
        write = fout.write
        write('digraph network {\n')
        write('  graph [overlap=false];\n')
        for node in self.nodes:
            shape = self.shapes.get(type(node), 'ellipse')
            write('  "%s" [label="%s", shape=%s];\n' % (quote(node.id), quote(node.name), shape))
        write('\n')
        for edge in self.edges:
            if edge.label:
                write('  "%s" -> "%s" [label="%s"];\n' % (quote(edge.frm.id), quote(edge.to.id), quote(edge.label)))
            else:
                write('  "%s" -> "%s";\n' % (quote(edge.frm.id), quote(edge.to.id)))
        write('}\n')

def main():
    """Command-line entry point: create the application object and run it."""
    app = Cloudily()
    app.run()

# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
