#!/usr/bin/env python
"""Download a given SSH certificate and put it in the right place.

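Downloads the certificate from the URL given on argv, then searches the user's
.ssh directory for the private key matching the certificate's public key
fingerprint. The certificate is moved next to that key as "<key>-cert.pub" and
the key is re-added to the SSH agent via ssh-add.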
"""
import datetime
import os
import shutil
import subprocess
import sys
import tempfile
import urllib

import ssh_ca.agent_client


class CertMetadata(object):
    """Plain record of the certificate details this script cares about.

    Both attributes start out as ``None`` and are filled in by whoever
    parses the certificate.
    """

    def __init__(self):
        # public_key_fingerprint: fingerprint of the key the cert was issued for.
        # valid_for_seconds: remaining lifetime of the certificate, in seconds.
        self.public_key_fingerprint = self.valid_for_seconds = None


def download_cert_to_tempfile(url):
    """Fetch the certificate at ``url`` and store it in a new temporary file.

    When the server answers with a status outside 200-299, the status code and
    the response body are printed and the process exits with status 1.

    Args:
        url: Location of the signed certificate to download.

    Returns:
        Path of the temporary file holding the response body. The file is
        created with ``delete=False``, so it stays on disk until the caller
        moves or removes it.
    """
    resp = urllib.urlopen(url)
    try:
        if resp.code > 299 or resp.code < 200:
            print("Bad response code: HTTP/%d" % (resp.code,))
            print(resp.read())
            sys.exit(1)
        # Read the body before creating the temp file so that a failed read
        # does not leave an empty file behind.
        cert_data = resp.read()
    finally:
        # Runs on the sys.exit() path too, so the connection is always closed.
        resp.close()

    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_file.write(cert_data)
    return temp_file.name


def get_cert_metadata(cert_path):
    """Extract fingerprint and remaining lifetime from a certificate file.

    Runs ``/usr/bin/ssh-keygen -L -f cert_path`` and scans its output for the
    ``Public key:`` and ``Valid:`` lines. If ssh-keygen reports that the file
    "is not a public key", a message is printed and the process exits with
    status 1.

    Args:
        cert_path: Path of the signed certificate file to inspect.

    Returns:
        A CertMetadata instance. ``public_key_fingerprint`` is the text that
        follows the key type on the ``Public key:`` line, and
        ``valid_for_seconds`` is the whole number of seconds between now and
        the end of the validity range (negative if that time has passed).
        Either attribute stays ``None`` when its line is missing or has no
        parseable value.
    """
    proc = subprocess.Popen(['/usr/bin/ssh-keygen', '-L', '-f', cert_path],
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # communicate() drains both pipes together and waits for the child;
    # reading stderr to EOF before touching stdout can deadlock if the child
    # blocks on a full stdout pipe.
    out, err = proc.communicate()

    if 'is not a public key' in err:
        print("Invalid signed ssh certificate file: %s" % (cert_path,))
        sys.exit(1)

    metadata = CertMetadata()
    for line in out.splitlines():
        if 'Public key:' in line:
            # Expected layout: "Public key: <TYPE> <fingerprint>". Skipping
            # the type token, whatever it is, avoids hard-coding "RSA-CERT".
            # NOTE(review): assumes non-RSA certificates use the same layout.
            fields = line.split('Public key:', 1)[1].split(None, 1)
            if len(fields) == 2:
                metadata.public_key_fingerprint = fields[1].strip()
        if 'Valid:' in line and ' to ' in line:
            # Only "... to <end time>" lines carry an end time; anything else
            # (e.g. "Valid: forever") leaves valid_for_seconds unset.
            expire_time = line[line.find(' to ') + 4:].strip()
            expire_dt = datetime.datetime.strptime(expire_time,
                                                   '%Y-%m-%dT%H:%M:%S')
            # NOTE(review): compares against naive local time; presumably
            # ssh-keygen prints local time as well -- confirm.
            delta = expire_dt - datetime.datetime.now()
            # timedelta.seconds is only the sub-day remainder; total_seconds()
            # includes whole days and keeps the sign for expired certs.
            metadata.valid_for_seconds = int(delta.total_seconds())

    return metadata


def find_private_key_for_public_key(pub_fingerprint):
    """Locate the key file in ~/.ssh whose fingerprint matches.

    Every regular file in the user's ``.ssh`` directory that does not end in
    ``.pub`` is passed to ``/usr/bin/ssh-keygen -l -f``; the first file whose
    output contains ``pub_fingerprint`` wins.

    Args:
        pub_fingerprint: Fingerprint text to look for in ssh-keygen's output.

    Returns:
        Full path of the first matching file, or ``None`` when the
        fingerprint is empty, the ``.ssh`` directory does not exist, or no
        file matches.
    """
    if not pub_fingerprint:
        # An empty string would match every line and None cannot be searched
        # for at all, so there is nothing meaningful to look up.
        return None

    # expanduser() honours $HOME and still works when it is unset.
    ssh_dir = os.path.join(os.path.expanduser('~'), '.ssh')
    if not os.path.isdir(ssh_dir):
        return None

    for filename in os.listdir(ssh_dir):
        key_filename = os.path.join(ssh_dir, filename)
        # Skip public halves/certificates and anything that is not a plain
        # file (sub-directories, sockets, ...).
        if filename.endswith('.pub') or not os.path.isfile(key_filename):
            continue
        proc = subprocess.Popen(
            ['/usr/bin/ssh-keygen', '-l', '-f', key_filename],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # communicate() drains both pipes and reaps the child process.
        out, _ = proc.communicate()
        for line in out.splitlines():
            if pub_fingerprint in line:
                return key_filename
    return None


def move_cert_into_place(cert_path, private_key_filename):
    """Move the certificate next to its private key.

    The destination is ``private_key_filename`` with ``-cert.pub`` appended.

    Args:
        cert_path: Current location of the certificate file.
        private_key_filename: Path of the private key the certificate
            belongs to.
    """
    new_cert_filename = private_key_filename + '-cert.pub'
    # shutil.move falls back to copy-and-delete when source and destination
    # are on different filesystems, where a plain os.rename raises OSError.
    shutil.move(cert_path, new_cert_filename)


def re_add_identity(private_key_filename, valid_for_seconds):
    """Drop the key from the SSH agent (best effort) and add it again.

    If ``<private_key_filename>.pub`` exists, its second whitespace-separated
    field is handed to the agent client's ``remove_key``. Failures to reach
    the agent or to remove the key are reported but do not stop the function.
    Afterwards ``/usr/bin/ssh-add`` is run on the private key.

    Args:
        private_key_filename: Path of the private key to (re-)add.
        valid_for_seconds: Lifetime passed to ``ssh-add -t``. When ``None``,
            the ``-t`` option is omitted.

    Raises:
        subprocess.CalledProcessError: If ssh-add exits with a non-zero status.
    """
    public_filename = private_key_filename + '.pub'
    if os.path.exists(public_filename):
        agent_client = ssh_ca.agent_client.Client()
        try:
            agent_client.connect()
        except ssh_ca.agent_client.SshClientFailure:
            print("Unable to find SSH agent, things probably aren't working.")
        else:
            # Only talk to the agent once the connection actually succeeded.
            with open(public_filename, 'r') as f:
                fields = f.read().split()
            # NOTE(review): the second field is presumably the base64 key
            # blob of a "<type> <blob> [comment]" line -- confirm against
            # what ssh_ca.agent_client expects.
            if len(fields) > 1:
                try:
                    agent_client.remove_key(fields[1])
                except ssh_ca.agent_client.SshClientFailure:
                    print('Unable to delete existing key, '
                          'this is probably benign')

    command = ['/usr/bin/ssh-add']
    if valid_for_seconds is not None:
        command += ['-t', '%d' % (valid_for_seconds,)]
    command.append(private_key_filename)
    subprocess.check_output(command)


def main(argv):
    """Download a certificate, install it next to its key and re-add the key.

    Args:
        argv: Command-line arguments; ``argv[0]`` is the program name and
            ``argv[1]`` the URL of the certificate.

    Returns:
        Process exit status: 0 on success, 1 on a usage error or when the
        matching private key cannot be determined.
    """
    if len(argv) != 2:
        print('Usage: %s <URL to cert>' % (argv[0],))
        return 1

    cert_filename = download_cert_to_tempfile(argv[1])
    try:
        cert_metadata = get_cert_metadata(cert_filename)
        key_fingerprint = cert_metadata.public_key_fingerprint
        if not key_fingerprint:
            print('Unable to read public key fingerprint from certificate.')
            return 1

        private_key_filename = find_private_key_for_public_key(
            key_fingerprint)
        if not private_key_filename:
            print('Unable to find private key matching certificate.')
            return 1

        move_cert_into_place(cert_filename, private_key_filename)
    finally:
        # Runs on every exit path, including sys.exit() raised by a helper.
        # After a successful move the temporary path is gone; otherwise
        # remove the download so failed runs do not leave files behind.
        if os.path.exists(cert_filename):
            os.remove(cert_filename)

    re_add_identity(private_key_filename, cert_metadata.valid_for_seconds)
    return 0


if __name__ == '__main__':
    sys.exit(main(sys.argv))
