#!/usr/bin/env python
# encoding:utf8
"""Read and type Keepass database entries using dmenu or rofi

"""
from contextlib import closing
import errno
import itertools
import locale
from multiprocessing import managers  # pylint: disable=unused-import
import os
from os.path import exists, expanduser
import random
import shlex
import socket
import string
import sys
from subprocess import Popen, PIPE
from threading import Timer
import time
from pykeepass import PyKeePass
from pykeyboard import PyKeyboard
if sys.version_info.major < 3:
    # hack to reduce client connect timeout for python 2.7 (defaults to 20 seconds)
    # https://stackoverflow.com/questions/6512884/properly-disconnect-multiprocessing-remote-manager#9936835
    # Every invocation first tries client() and only starts a new daemon when
    # that connection fails, so a missing server should be detected quickly.
    def _new_init_timeout():
        """Replacement for multiprocessing.connection._init_timeout.

        Returns: time.time() + 0.1 (instead of the 20 second default)

        """
        return time.time() + 0.1
    # pragma pylint: disable=protected-access,line-too-long
    sys.modules['multiprocessing'].__dict__['managers'].__dict__['connection']._init_timeout = _new_init_timeout
    # pragma pylint: enable=protected-access,line-too-long

# pragma pylint: disable=ungrouped-imports,wrong-import-order,wrong-import-position
from multiprocessing.managers import BaseManager
from multiprocessing import Event, Process
# pragma pylint: enable=ungrouped-imports,wrong-import-order,wrong-import-position

try:
    import configparser as configparser
except ImportError:
    import ConfigParser as configparser


if sys.version_info.major < 3:
    # Python 2: make ``str`` mean ``unicode`` throughout this module so that
    # str(...) calls build text strings, as they do on Python 3.
    str = unicode  # pylint: disable=undefined-variable, invalid-name, redefined-builtin


def find_free_port():
    """Ask the OS for an unused TCP port on the loopback interface.

    Binding to port 0 lets the kernel choose a free port. The probe socket is
    closed again before returning, so the port is only very likely (not
    guaranteed) to still be free when the BaseManager server binds it.

    Returns: int Port

    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('127.0.0.1', 0))  # pylint:disable=no-member
        _, port = probe.getsockname()  # pylint:disable=no-member
    finally:
        probe.close()
    return port


def random_str(length=15):
    """Generate random auth string for BaseManager

    The string is used as the BaseManager authkey, so it is drawn from the
    OS random source (random.SystemRandom) instead of the default, predictable
    Mersenne Twister generator. SystemRandom exists on Python 2 and 3 alike.

    Args: length - number of characters to generate (default 15)
    Returns: string of lowercase ASCII letters

    """
    rng = random.SystemRandom()
    letters = string.ascii_lowercase
    return ''.join(rng.choice(letters) for _ in range(length))


def process_config():
    """Set global variables. Read the config file. Create default config file if
    one doesn't exist.

    Globals set: ENV (copy of os.environ with LC_ALL=C), ENC (the locale's
    preferred encoding), CACHE_PERIOD_DEFAULT_MIN, CONF_FILE, CONF (the parsed
    config), CACHE_PERIOD_MIN (from [database] pw_cache_period_min, default
    CACHE_PERIOD_DEFAULT_MIN) and DMENU_LEN (from [dmenu] l, default 24).

    """
    # pragma pylint: disable=global-variable-undefined
    global CACHE_PERIOD_MIN, \
           CACHE_PERIOD_DEFAULT_MIN, \
           CONF, \
           CONF_FILE, \
           DMENU_LEN, \
           ENV, \
           ENC
    # pragma pylint: enable=global-variable-undefined
    ENV = os.environ.copy()
    ENV['LC_ALL'] = 'C'
    ENC = locale.getpreferredencoding()
    CACHE_PERIOD_DEFAULT_MIN = 360
    CONF_FILE = expanduser("~/.config/keepmenu/config.ini")
    CONF = configparser.ConfigParser()
    if not exists(CONF_FILE):
        # makedirs (not mkdir) so a missing ~/.config is created too. Python 2
        # has no exist_ok, so only an "already exists" error is tolerated.
        try:
            os.makedirs(os.path.dirname(CONF_FILE))
        except OSError as err:
            if err.errno != errno.EEXIST:
                raise
        CONF.add_section('dmenu')
        CONF.set('dmenu', 'dmenu_command', 'dmenu')
        CONF.add_section('dmenu_passphrase')
        CONF.set('dmenu_passphrase', 'nf', '#222222')
        CONF.set('dmenu_passphrase', 'nb', '#222222')
        CONF.set('dmenu_passphrase', 'rofi_obscure', 'True')
        CONF.add_section('database')
        CONF.set('database', 'database_1', '')
        CONF.set('database', 'keyfile_1', '')
        CONF.set('database', 'pw_cache_period_min', str(CACHE_PERIOD_DEFAULT_MIN))
        # Only open (and truncate) the file once the contents are ready.
        with open(CONF_FILE, 'w') as conf_file:
            CONF.write(conf_file)
    CONF.read(CONF_FILE)
    if CONF.has_option("database", "pw_cache_period_min"):
        CACHE_PERIOD_MIN = CONF.getint("database", "pw_cache_period_min")
    else:
        CACHE_PERIOD_MIN = CACHE_PERIOD_DEFAULT_MIN
    if CONF.has_option("dmenu", "l"):
        DMENU_LEN = CONF.getint("dmenu", "l")
    else:
        DMENU_LEN = 24


def get_auth():
    """Generate and save port and authkey to ~/.cache/.keepmenu-auth

    If the file does not exist yet, it is created (together with ~/.cache if
    needed) with a free localhost port and a random key. The file holds the
    BaseManager authkey, so it is created readable by the owner only (0600).

    Returns: int port, bytestring authkey

    """
    auth_file = expanduser("~/.cache/.keepmenu-auth")
    auth = configparser.ConfigParser()
    if not exists(auth_file):
        # Python 2 has no exist_ok; tolerate only "already exists".
        try:
            os.makedirs(os.path.dirname(auth_file))
        except OSError as err:
            if err.errno != errno.EEXIST:
                raise
        auth.set('DEFAULT', 'port', str(find_free_port()))
        auth.set('DEFAULT', 'authkey', random_str())
        # os.open with an explicit mode instead of open(): the secret must not
        # inherit a permissive umask.
        fdesc = os.open(auth_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fdesc, 'w') as a_file:
            auth.write(a_file)
    auth.read(auth_file)
    port = int(auth.get('DEFAULT', 'port'))
    authkey = auth.get('DEFAULT', 'authkey').encode()
    return port, authkey


def dmenu_cmd(num_lines, prompt="Entries"):  # pylint: disable=too-many-branches
    """Parse config.ini for dmenu options

    Options in the [dmenu] section other than dmenu_command, l and pinentry
    are passed through as "-<option> <value>". For the "Passphrase" prompt the
    [dmenu_passphrase] options are layered on top. Without a [dmenu] section a
    plain dmenu capped at DMENU_LEN lines is used.

    Args: num_lines - number of lines to display
          prompt - prompt to show
    Returns: command invocation (as a list of strings) for
                dmenu -i -l <num_lines> -p <prompt> ...

    """
    dmenu_command = "dmenu"
    dmenu_args = []
    # Initialised up front so that the "no [dmenu] section" case can also use
    # the shared code below (it used to raise NameError there).
    args_dict = {}
    if not CONF.has_section('dmenu'):
        lines = "-i -l {}".format(min(DMENU_LEN, num_lines))
    else:
        args_dict = dict(CONF.items('dmenu'))
        command = shlex.split(args_dict.pop("dmenu_command", ""))
        if command:  # an empty dmenu_command falls back to plain dmenu
            dmenu_command = command[0]
            dmenu_args = command[1:]
        is_rofi = "rofi" in dmenu_command
        if is_rofi:
            lines = "-i -dmenu -lines"
            # rofi doesn't support 0 length line, it requires at least -lines=1
            # see https://github.com/DaveDavenport/rofi/issues/252
            num_lines = num_lines or 1
        else:
            lines = "-i -l"
        if "l" in args_dict:
            configured_l = args_dict.pop('l')
            if is_rofi:
                configured_l = min(num_lines, int(configured_l)) or 1
            lines = "{} {}".format(lines, configured_l)
        else:
            lines = "{} {}".format(lines, num_lines)
        # pinentry is a keepmenu setting, not a dmenu flag
        args_dict.pop("pinentry", None)
    if prompt == "Passphrase":
        if CONF.has_section('dmenu_passphrase'):
            args_dict.update(CONF.items('dmenu_passphrase'))
        rofi_obscure = True
        if CONF.has_option('dmenu_passphrase', 'rofi_obscure'):
            rofi_obscure = CONF.getboolean('dmenu_passphrase', 'rofi_obscure')
        # rofi_obscure is a keepmenu setting, never a dmenu/rofi flag
        args_dict.pop("rofi_obscure", None)
        if rofi_obscure and "rofi" in dmenu_command:
            dmenu_args.append("-password")
    extras = itertools.chain.from_iterable(
        ["-" + str(k), str(v)] for (k, v) in args_dict.items())
    dmenu = [dmenu_command] + lines.split() + ["-p", str(prompt)] + dmenu_args + list(extras)
    # Drop empty arguments (e.g. options configured with an empty value)
    return list(filter(None, dmenu))


def dmenu_err(prompt):
    """Pops up a dmenu prompt with an error message

    The message is shown as the prompt of an empty one-line menu. Whatever the
    user enters is discarded.

    Args: prompt - error text to display

    """
    # stdin is a binary pipe, so pass bytes like every other dmenu call here
    Popen(dmenu_cmd(1, prompt), stdin=PIPE, stdout=PIPE,
          env=ENV).communicate(input=b'')


def get_database():
    """Read databases from config or ask for user input.

    Every non-empty ``database_<n>`` option in the [database] section is a
    candidate. It is paired with the matching ``keyfile_<n>`` and
    ``password_<n>`` options, both defaulting to "". When several databases are
    configured the user picks one in dmenu. When none are, get_initial_db()
    asks for one.

    Returns: (database name, keyfile, passphrase)
             Returns (None, None, None) on error selecting database

    """
    args_dict = dict(CONF.items('database'))
    dbs = []
    for key in args_dict:
        if not key.startswith('database'):
            continue
        dbn = expanduser(args_dict[key])
        if not dbn:
            continue
        idx = key.rsplit('_', 1)[-1]
        keyfile = expanduser(args_dict.get('keyfile_{}'.format(idx), ''))
        passw = args_dict.get('password_{}'.format(idx), '')
        dbs.append((dbn, keyfile, passw))
    if not dbs:
        if get_initial_db() is not True:
            return (None, None, None)
        dbs = [get_database()]
    if len(dbs) > 1:
        inp_bytes = "\n".join(i[0] for i in dbs).encode(ENC)
        sel = Popen(dmenu_cmd(len(dbs), "Select Database"),
                    stdin=PIPE,
                    stdout=PIPE,
                    env=ENV).communicate(input=inp_bytes)[0].decode(ENC).rstrip('\n')
        # Both a cancelled menu (empty sel) and free text that matches no
        # configured database select nothing (the latter used to raise
        # IndexError on dbs[0]).
        dbs = [i for i in dbs if i[0] == sel]
        if not dbs:
            return (None, None, None)
    return dbs[0]


def get_initial_db():
    """Ask for initial database name and keyfile if not entered in config file

    The answers are stored as database_1 / keyfile_1 in the [database]
    section and written back to CONF_FILE.

    Returns: True if a database path was entered and saved, False otherwise

    """
    db_name = Popen(dmenu_cmd(0, "Enter path to existing "
                                 "Keepass database. ~/ for $HOME is ok"),
                    stdin=PIPE,
                    stdout=PIPE).communicate()[0].decode(ENC).rstrip('\n')
    if not db_name:
        dmenu_err("No database entered. Try again.")
        return False
    keyfile_name = Popen(dmenu_cmd(0, "Enter path to keyfile. "
                                   "~/ for $HOME is ok"),
                         stdin=PIPE,
                         stdout=PIPE).communicate()[0].decode(ENC).rstrip('\n')
    # Update the in-memory config before opening the file: open(..., 'w')
    # truncates it, so a failure in between would leave the config empty.
    CONF.set('database', 'database_1', db_name)
    if keyfile_name:
        CONF.set('database', 'keyfile_1', keyfile_name)
    with open(CONF_FILE, 'w') as conf_file:
        CONF.write(conf_file)
    return True


def get_entries(dbo):
    """Open keepass database and return the PyKeePass entries

        Prompts for a passphrase (get_passphrase) when none is configured.
        Known failures are reported with dmenu_err.

        Args: dbo: tuple (db path, keyfile path, password)
        Returns: PyKeePass list of Entry objects, or None if the database
                 path is None or the database could not be opened

    """
    dbf, keyfile, password = dbo
    if dbf is None:
        return None
    if not password:
        password = get_passphrase()
    try:
        kpo = PyKeePass(dbf, password, keyfile=keyfile)
    except (IOError, OSError, IndexError) as err:
        # IndexError has no errno attribute and exceptions may have no args,
        # so read both defensively instead of raising a second exception here.
        message = err.args[0] if err.args else None
        if message in ("Master key invalid.", "No credentials found."):
            dmenu_err("Invalid Password or keyfile")
        elif getattr(err, 'errno', None) == errno.ENOENT:
            dmenu_err("Database does not exist. Edit ~/.config/keepmenu/config.ini")
        return None
    return kpo.entries


def get_passphrase():
    """Get a database password from dmenu or pinentry

    Returns: string

    """
    pinentry = None
    if CONF.has_option("dmenu", "pinentry"):
        pinentry = CONF.get("dmenu", "pinentry")
    if pinentry:
        password = ""
        out = Popen(pinentry,
                    stdout=PIPE,
                    stdin=PIPE).communicate( \
                            input=b'setdesc Enter database password\ngetpin\n')[0]
        if out:
            res = out.decode(ENC).split("\n")[2]
            if res.startswith("D "):
                password = res.split("D ")[1]
    else:
        password = Popen(dmenu_cmd(0, "Passphrase"),
                         stdin=PIPE,
                         stdout=PIPE).communicate()[0].decode(ENC).rstrip('\n')
    return password


def type_entry(entry):
    """Type the entry's username and/or password, then press Enter.

    When both fields are present they are separated by a Tab keypress.
    Keystrokes are synthesized with PyUserInput's PyKeyboard.

    """
    keyboard = PyKeyboard()
    user = entry.username
    secret = entry.password
    if user:
        keyboard.type_string(user)
    if user and secret:
        keyboard.tap_key(keyboard.tab_key)
    if secret:
        keyboard.type_string(secret)
    # Not sure why we need n=2, but only seems to work that way
    keyboard.tap_key(keyboard.enter_key, n=2)


def type_text(data):
    """Type the given text with PyUserInput's keyboard emulation.

    Args: data - string to type

    """
    PyKeyboard().type_string(data)


def view_all_entries(options, kp_entries):
    """Show the menu options followed by a numbered list of all entries in dmenu.

    Each entry line reads "<index> - <path> - <username> - <url>". The index is
    right-aligned and is what identifies the entry, so duplicates are still
    distinguishable.

    Args: options - extra menu items listed before the entries
          kp_entries - list of Keepass Entry objects
    Returns: dmenu selection

    """
    width = len(str(len(kp_entries)))
    rows = []
    for idx, kp_entry in enumerate(kp_entries):
        rows.append("{:>{na}} - {} - {} - {}".format(
            idx, kp_entry.path, kp_entry.username, kp_entry.url, na=width))
    menu_b = "\n".join(rows).encode(ENC)
    if options:
        menu_b = ("\n".join(options) + "\n").encode(ENC) + menu_b
    visible = min(DMENU_LEN, len(options) + len(kp_entries))
    proc = Popen(dmenu_cmd(visible), stdin=PIPE, stdout=PIPE, env=ENV)
    chosen = proc.communicate(input=menu_b)[0]
    return chosen.decode(ENC).rstrip('\n')


def view_entry(kp_entry):
    """Show title, username, password, url and notes for an entry.

    The password is masked, and the notes sit behind a placeholder that opens
    view_notes(). Selecting the masked password returns the real password.
    Selecting any "<Field>: None" placeholder returns "" so the placeholder
    text itself is never typed.

    Returns: dmenu selection

    """
    masked = '**********'
    notes_hint = "Notes: <Enter to view>"
    placeholders = ("Title: None", "Username: None", "Password: None",
                    "URL: None", "Notes: None")
    fields = [kp_entry.path or "Title: None",
              kp_entry.username or "Username: None",
              masked if kp_entry.password else "Password: None",
              kp_entry.url or "URL: None",
              notes_hint if kp_entry.notes else "Notes: None"]
    kp_entries_b = "\n".join(fields).encode(ENC)
    sel = Popen(dmenu_cmd(len(fields)), stdin=PIPE, stdout=PIPE,
                env=ENV).communicate(input=kp_entries_b)[0].decode(ENC).rstrip('\n')
    if sel == notes_hint:
        return view_notes(kp_entry.notes)
    if sel == masked:
        return kp_entry.password
    if sel in placeholders:
        return ""
    return sel


def view_notes(notes):
    """Show the 'Notes' field in dmenu, one menu line per line of text.

    Args: notes - the entry's notes text
    Returns: text of the selected line for typing

    """
    line_count = notes.count('\n') + 1
    proc = Popen(dmenu_cmd(min(DMENU_LEN, line_count)), stdin=PIPE, stdout=PIPE,
                 env=ENV)
    chosen = proc.communicate(input=notes.encode(ENC))[0]
    return chosen.decode(ENC).rstrip('\n')


def client():
    """Define client connection to server BaseManager

    Connects with the port and authkey from get_auth(). The __main__ block
    treats a socket.error raised here as "no daemon running".

    Returns: BaseManager object
    """
    port, auth = get_auth()
    # Connect to loopback explicitly, matching the address Server binds to
    # (an empty host means INADDR_ANY, which is not a portable connect target).
    mgr = BaseManager(address=('127.0.0.1', port), authkey=auth)
    mgr.register('set_event')
    mgr.register('expire_cache')
    mgr.register('kill')
    mgr.connect()
    return mgr


class DmenuRunner(Process):
    """Listen for dmenu calling event and run keepmenu

    Each time the server's start_flag is set, the main menu is shown. Entries
    are kept between calls and reloaded once the cache timer has expired, or
    whenever a previous load left no entries.

    """
    def __init__(self, server, entries):
        Process.__init__(self)
        self.server = server
        self.database = (None, None, None)
        self.entries = entries
        # The timer is started in run(), i.e. inside the child process, so it
        # can be cancelled from the same process on reload and the parent is
        # not multithreaded when it forks.
        self.cache_timer = None

    def run(self):
        self._start_cache_timer()
        while True:
            self.server.start_flag.wait()
            # Reload when the cache expired, or when an earlier (re)load
            # failed and left no entries, so the user is prompted again.
            if self.server.cache_time_expired.is_set() or not self.entries:
                self.cache_timer_reset()
            if self.entries:
                self.dmenu_run()
            if self.server.kill_flag.is_set():
                break
            self.server.start_flag.clear()

    def cache_time(self):
        """Reset keepass Entries list when cache timer expires

        Only flags the cache as expired; the reload happens in run() on the
        next menu invocation.

        """
        self.server.cache_time_expired.set()

    def _start_cache_timer(self):
        """(Re)start the cache-expiry timer, cancelling a still pending one."""
        if self.cache_timer is not None:
            self.cache_timer.cancel()
        self.cache_timer = Timer(CACHE_PERIOD_MIN * 60, self.cache_time)
        self.cache_timer.daemon = True
        self.cache_timer.start()

    def cache_timer_reset(self):
        """Reset cache timer and reload entries

        On failure (no entries) the timer and the expired flag are left as
        they are.

        """
        self.database = get_database()
        self.entries = get_entries(self.database)
        if not self.entries:
            return
        self._start_cache_timer()
        self.server.cache_time_expired.clear()

    def _selected_entry(self, sel):
        """Map a numbered menu line ("<idx> - path - user - url") to its Entry.

        Returns: the Entry, or None if the line carries no valid index
        """
        try:
            return self.entries[int(sel.split('-', 1)[0])]
        except (ValueError, IndexError):
            return None

    def dmenu_run(self):
        """Run dmenu with the cached list of Keepass Entry objects

        Offers the entries plus options to view/type single fields, reload
        the database or kill the daemon. Picking an entry types its
        username/password.

        """
        options = ['View/Type Individual entries',
                   'Reload database',
                   'Kill Keepmenu daemon']
        sel = view_all_entries(options, self.entries)
        if not sel:
            return
        if sel == options[0]:
            sel = view_all_entries([], self.entries)
            entry = self._selected_entry(sel)
            if entry is None:
                return
            text = view_entry(entry)
            type_text(text)
        elif sel == options[1]:
            self.cache_timer_reset()
            # A failed reload leaves no entries; showing the menu then would
            # crash this process with len(None).
            if self.entries:
                self.dmenu_run()
        elif sel == options[2]:
            try:
                self.server.kill_flag.set()
            except (EOFError, IOError):
                return
        else:
            entry = self._selected_entry(sel)
            if entry is None:
                return
            type_entry(entry)


class Server(Process):
    """Run BaseManager server to listen for dmenu calling events

    The manager exposes three callables to clients: set_event (show the
    menu), expire_cache (reload the database on the next invocation) and
    kill (stop the daemon).

    """
    def __init__(self):
        Process.__init__(self)
        self.port, self.authkey = get_auth()
        self.start_flag = Event()
        self.kill_flag = Event()
        self.cache_time_expired = Event()
        # Set initially so the menu is shown for the invocation that started
        # the daemon.
        self.start_flag.set()

    def run(self):
        mgr = self.server()
        try:
            self.kill_flag.wait()
        finally:
            # Stop the manager process explicitly rather than relying on the
            # exit-time finalizer.
            mgr.shutdown()

    def server(self):
        """Set up BaseManager server

        Returns: the started BaseManager, listening on 127.0.0.1:self.port

        """
        mgr = BaseManager(address=('127.0.0.1', self.port),
                          authkey=self.authkey)
        mgr.register('set_event', callable=self.start_flag.set)
        mgr.register('expire_cache', callable=self.cache_time_expired.set)
        mgr.register('kill', callable=self.kill_flag.set)
        mgr.start()
        return mgr


def run():
    """Main entrypoint. Start the background Manager and Dmenu runner processes.

    The auth file (port + key) is removed however this function exits,
    including the early exit when no entries could be loaded. The next start
    then picks a fresh port instead of reusing a stale one.

    """
    auth_file = expanduser("~/.cache/.keepmenu-auth")
    try:
        database = get_database()
        entries = get_entries(database)
        if not entries:
            sys.exit()
        server = Server()
        dmenu = DmenuRunner(server, entries)
        dmenu.daemon = True
        server.start()
        dmenu.start()
        server.join()
    finally:
        if exists(auth_file):
            os.remove(auth_file)


if __name__ == '__main__':
    # If a keepmenu daemon is already listening, just ask it to show the menu
    # (set_event sets the server's start_flag). Otherwise connecting raises
    # socket.error, and this process reads the config and becomes the daemon.
    try:
        MANAGER = client()
        MANAGER.set_event()  # pylint: disable=no-member
    except socket.error:  ## Use socket.error for Python 2 & 3 compat.
        process_config()
        run()

# vim: set et ts=4 sw=4 :
