#!/usr/local/bin/blue-python3.5
# vim:syntax=python

import argparse
import logging
import pamphlet
import paramiko
import socket
import threading
import time

class SshDaemon(object):
    """Listening daemon that hands each accepted connection to an SshSession.

    Command-line options are parsed in the constructor; ``run()`` binds the
    listening socket and accepts clients in an endless loop.
    """

    def __init__(self):
        """Parse the command line and set up logging.

        ``parse_args()`` reads ``sys.argv``; argparse terminates the process
        when the required ``--key`` option is missing.
        """
        ap = argparse.ArgumentParser()
        ap.add_argument('--key',
                        required=True,
                        help="Path to the ssh host key")
        ap.add_argument('--bind-address',
                        default='0.0.0.0',
                        help="Address to bind to")
        ap.add_argument('--port', type=int,
                        default=22,
                        help='Port to bind to')
        ap.add_argument('--pam-service',
                        default='ssh-key-injector',
                        help="PAM service to use")
        self.args = ap.parse_args()
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger('ssh.daemon')

    def run(self):
        """Bind the listening socket and serve connections until interrupted.

        Every accepted connection is handed to a freshly started
        ``SshSession`` thread. The listening socket is closed whenever the
        method is left, whatever the reason (unreadable host key, failing
        ``accept()``, KeyboardInterrupt, ...).
        """
        self.server_socket = socket.socket()
        try:
            # Allow an immediate restart while the previous listener's
            # address is still lingering in the kernel.
            self.server_socket.setsockopt(socket.SOL_SOCKET,
                                          socket.SO_REUSEADDR, 1)
            self.server_socket.bind((self.args.bind_address, self.args.port))
            self.server_socket.listen(10)
            # The key path is replaced by the loaded key object; sessions
            # read it back from the shared args namespace.
            self.args.key = paramiko.rsakey.RSAKey(filename=self.args.key)
            self.logger.info("sshd running on %s:%s",
                             self.args.bind_address, self.args.port)
            self.session = 1

            while True:
                client_socket, peer = self.server_socket.accept()
                # Only the host part of the peer address is used.
                address = peer[0]
                self.logger.info("New connection from %s, session %d",
                                 address, self.session)
                SshSession(self.session, client_socket, address, self.args).start()
                # The counter only labels log output; it wraps at 10000.
                self.session = (self.session + 1) % 10000
        finally:
            self.server_socket.close()

class SshSession(threading.Thread):
    """Thread that drives a single client connection through the SSH handshake."""

    def __init__(self, sessno, socket, address, args):
        """Remember the connection details; no I/O happens here.

        sessno  -- session number, used only in the logger name
        socket  -- connected client socket
        address -- peer address, stored on the instance
        args    -- parsed options; ``key`` and ``pam_service`` are read later
        """
        super(SshSession, self).__init__()
        self.socket = socket
        self.address = address
        self.args = args
        self.user = None
        self.logger = logging.getLogger('ssh.session[%d]' % sessno)
        self.channel = None
        # Assigned by session(); stays None when the transport could not
        # even be constructed, so run() can tell the two cases apart.
        self.transport = None

    def send_and_log(self, message, level=logging.INFO):
        """Log ``message`` and, if a channel is open, echo it to the client."""
        self.logger.log(level, message)
        if self.channel:
            self.channel.send('\r\n%s\r\n\r\n' % message)

    def run(self):
        """Thread entry point: run the session, then always release resources."""
        try:
            self.session()
        except Exception:
            self.logger.exception("Unhandled exception. Traceback:")
            if self.channel:
                try:
                    self.channel.send("\r\nSomething went wrong\r\n\r\n")
                except Exception:
                    # Best effort only: the failure above may well be a
                    # dead connection, so a second failure is not an error.
                    self.logger.debug("Could not notify the client",
                                      exc_info=True)
        finally:
            self._teardown()

    def _teardown(self):
        """Close the channel, then the transport (or the bare socket)."""
        try:
            if self.channel:
                self.logger.info("Session finished, closing channel")
                self.channel.close()
                self.logger.info("Channel closed")
        finally:
            if self.transport is not None:
                self.transport.close()
            else:
                # No transport was ever created, so nothing else in this
                # class would close the client socket.
                self.socket.close()

    def session(self):
        """Start the SSH server side and wait for an interactive shell request."""
        self.transport = paramiko.transport.Transport(self.socket)
        self.transport.logger.level = logging.CRITICAL
        self.transport.add_server_key(self.args.key)
        # These events hang off the transport so that the auth interface can
        # reach them through ``channel.transport``.
        self.transport.has_shell = threading.Event()
        self.transport.has_agent = threading.Event()
        self.transport.start_server(server=SshAuthInterface(self.args.pam_service))

        # Wait for the client to authenticate and set up their session
        self.channel = self.transport.accept(None)
        if not self.channel:
            self.logger.error("No ssh channel established")
            return

        # Give the client five seconds to request a shell on the channel.
        if not self.transport.has_shell.wait(5):
            self.send_and_log("You must use an interactive shell. Disconnecting.")
            return

        self.send_and_log("You have succesfully authenticated")

class SshAuthInterface(paramiko.server.ServerInterface):
    """Paramiko server interface offering only keyboard-interactive auth.

    The prompts come from a ``pamphlet.ThreadedPamApplication``. Judging by
    how its results are used here, ``get_data()`` yields an Exception when
    the conversation failed, a falsy value when there is nothing left to
    ask, and otherwise an iterable of prompts -- NOTE(review): confirm
    against the pamphlet documentation.
    """

    def __init__(self, pam_service):
        """Store the PAM service name used for every authentication attempt."""
        self.pam_service = pam_service
        self.logger = logging.getLogger("ssh.auth")

    def get_allowed_auths(self, username):
        """Advertise keyboard-interactive as the only auth method."""
        return "keyboard-interactive"

    def check_auth_interactive(self, username, submethods):
        """Start a PAM conversation and return its first round of prompts.

        Returns ``paramiko.AUTH_FAILED`` when PAM reports an error or
        produces no prompts at all, otherwise an ``InteractiveQuery``.
        """
        self.pam = pamphlet.ThreadedPamApplication(self.pam_service, username)
        self.auth_thread = self.pam.authenticate()
        data = self.pam.get_data()
        if isinstance(data, Exception):
            return self._fail(data)
        if not data:
            return paramiko.AUTH_FAILED
        return self._build_query(data)

    def check_auth_interactive_response(self, responses):
        """Feed the client's answers to PAM and continue the conversation.

        Returns ``AUTH_FAILED`` on a PAM error, ``AUTH_SUCCESSFUL`` once PAM
        has no further prompts, or another ``InteractiveQuery``.
        """
        self.pam.set_input(responses)
        data = self.pam.get_data()
        if isinstance(data, Exception):
            return self._fail(data)
        if not data:
            self.auth_thread.join()
            return paramiko.AUTH_SUCCESSFUL
        return self._build_query(data)

    def _fail(self, error):
        """Log ``error``, wait for the PAM thread and report failure."""
        # logger.exception() is only meaningful inside an ``except`` block;
        # here the exception arrives as a value, so pass it explicitly.
        self.logger.error("PAM conversation returned an error: %s", error,
                          exc_info=error)
        self.auth_thread.join()
        return paramiko.AUTH_FAILED

    def _build_query(self, prompts):
        """Wrap PAM prompts in an InteractiveQuery, echoing non-password input."""
        query = paramiko.server.InteractiveQuery()
        for prompt in prompts:
            query.add_prompt(prompt, not prompt.wants_password)
        return query

    def check_channel_request(self, kind, chanid):
        """Permit 'session' channels and refuse every other kind."""
        self.logger.info("Channel request: kind=%s chanid=%s", kind, chanid)
        if kind == 'session':
            return paramiko.OPEN_SUCCEEDED
        return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED

    def check_channel_pty_request(self, *ignored):
        """Accept any pty request regardless of its parameters."""
        return True

    def check_channel_shell_request(self, channel):
        """Accept the shell request and flag it on the channel's transport."""
        channel.transport.has_shell.set()
        return True
# Start the daemon only when executed as a script, so that importing this
# module does not parse the command line, bind a socket and block forever.
if __name__ == '__main__':
    SshDaemon().run()
