#!/bin/python

import argparse
import enum
import glob
import pathlib
import shlex
import subprocess

import libtmux
import paramiko
import toml
from termcolor import colored

#TODO: Configuration can be used to control which ones are checked (default: all)
#TODO: Checks should be applied to every branch
#TODO: testing: create temp repo
#TODO: add (optional) groups for groups of repos

class History(enum.Enum):
    """Status of the local history relative to the remote

    Members:
        EQUAL      history is the same as remote
        PULL       history is behind remote
        PUSH       history is ahead of remote
        DIVERGED   history has diverged from remote
    """
    EQUAL = 1
    PULL = 2
    PUSH = 3
    DIVERGED = 4

class FetchStatus(enum.Enum):
    """Outcome of trying to fetch from a repository's remote

    Members:
        SUCCESS     fetch succeeded
        FAIL        fetch failed
        NO_REMOTE   repo has no remote
        NOT_GIT     not a git repository
    """
    SUCCESS = 1
    FAIL = 2
    NO_REMOTE = 3
    NOT_GIT = 4


def run_git_command(options, git_path):
    """Run a git command, return (stdout, stderr)

    The command is sent through the module-level paramiko client ``ssh`` when
    one has been set up, and is run as a local subprocess otherwise.

    Arguments:
        options    list of commands following git
        git_path   path to git repository

    Returns:
        (stdout, stderr) of the command, both as strings
    """
    # ``ssh`` is only assigned by the script entry point; look it up
    # defensively so that using this module without it is not a NameError.
    client = globals().get('ssh')

    if client is not None:
        # Quote every word for the remote shell: paths may contain spaces and
        # options such as '--format=%(refname)' contain shell syntax.
        git_words = ' '.join(shlex.quote(str(word)) for word in ['git', *options])
        # '&&' keeps git from running in another directory when cd fails
        command = f'cd {shlex.quote(str(git_path))} && {git_words}'
        _, remote_out, remote_err = client.exec_command(command)
        stdout = ''.join(remote_out)
        stderr = ''.join(remote_err)
    else:
        cmd = subprocess.run(['git', *options], cwd=git_path, capture_output=True)
        # git output (e.g. diffs) is not guaranteed to be valid UTF-8
        stdout = cmd.stdout.decode(errors='replace')
        stderr = cmd.stderr.decode(errors='replace')

    return stdout, stderr

def is_git_repository(git_path):
    """return True if git_path (after ~ expansion) contains a '.git' entry"""
    repo_root = pathlib.Path(git_path).expanduser().resolve()
    git_entry = repo_root.joinpath('.git')
    return git_entry.exists()

def git_remote(git_path):
    """get the list of remotes (stdout of 'git remote')"""
    remotes, _ = run_git_command(['remote'], git_path)
    return remotes

def git_diff(git_path):
    """get the stdout of 'git diff' for the repository"""
    diff_text, _ = run_git_command(['diff'], git_path)
    return diff_text

def git_fetch_status(git_path, fetch=True):
    """run git fetch (if possible), return a FetchStatus

    Arguments:
        git_path    path to git repository
        fetch       if True, run git fetch

    Returns:
        FetchStatus.NOT_GIT     git_path is not a git repository
        FetchStatus.NO_REMOTE   'git remote' lists nothing
        FetchStatus.FAIL        the fetch wrote to stderr
        FetchStatus.SUCCESS     otherwise (also when fetch is False)
    """
    if not is_git_repository(git_path):
        return FetchStatus.NOT_GIT
    if not git_remote(git_path):
        return FetchStatus.NO_REMOTE

    if fetch:
        # A successful fetch reports progress and updated refs on stderr;
        # --quiet silences that so any stderr output points at a real problem.
        _, err = run_git_command(['fetch', '--quiet'], git_path)
        if err:
            return FetchStatus.FAIL

    return FetchStatus.SUCCESS

def local_sha(git_path):
    """get the SHA that '@' resolves to, without surrounding whitespace"""
    sha, _ = run_git_command(['rev-parse', '@'], git_path)
    return sha.strip()

def remote_sha(git_path):
    """get the SHA that '@{u}' resolves to, without surrounding whitespace"""
    sha, _ = run_git_command(['rev-parse', '@{u}'], git_path)
    return sha.strip()

def base_sha(git_path):
    """get the merge-base SHA of '@' and '@{u}', without surrounding whitespace"""
    sha, _ = run_git_command(['merge-base', '@', '@{u}'], git_path)
    return sha.strip()

def git_untracked_files(git_path):
    """get the stdout of 'git ls-files --others --exclude-standard'"""
    listing, _ = run_git_command(
        ['ls-files', '--others', '--exclude-standard'],
        git_path,
    )
    return listing

def git_stash_list(git_path):
    """get the stdout of 'git stash list'"""
    stash_text, _ = run_git_command(['stash', 'list'], git_path)
    return stash_text

def get_history(git_path, fetch=False):
    """Classify the local history against the remote, return a History

    The SHAs of '@', '@{u}' and their merge-base are compared as strings.

    Arguments:
        git_path   path to git repository
        fetch      (bool) If True, fetch remote data before getting history
    """
    if fetch:
        git_fetch_status(git_path)

    head = local_sha(git_path)
    upstream = remote_sha(git_path)
    merge_base = base_sha(git_path)

    if head == upstream:
        return History.EQUAL
    if head == merge_base:
        # local tip is the common ancestor: only the remote has new commits
        return History.PULL
    if upstream == merge_base:
        # remote tip is the common ancestor: only local has new commits
        return History.PUSH
    return History.DIVERGED

def _strip_heads_prefix(ref, prefix='refs/heads/'):
    """drop the leading 'refs/heads/' from a full ref name, keep the rest"""
    return ref[len(prefix):] if ref.startswith(prefix) else ref

def get_remote_branches(git_path):
    """get a list of the remote branch names

    Names keep any '/' they contain ('feature/x' stays 'feature/x').
    """
    out = run_git_command(['ls-remote', '--heads', '-q'], git_path)[0]

    remote_branches = []
    for line in out.splitlines():
        fields = line.split()
        if len(fields) == 2:   # '<sha> <ref>'; anything else is not a ref line
            remote_branches.append(_strip_heads_prefix(fields[1]))
    return remote_branches

def get_local_branches(git_path):
    """get a list of the local branch names

    Names keep any '/' they contain ('feature/x' stays 'feature/x').
    """
    out = run_git_command(['for-each-ref', '--format=%(refname)', 'refs/heads/'], git_path)[0]
    return [_strip_heads_prefix(ref) for ref in out.split()]

def get_dangling_branches(git_path):
    """get branches that do not exist on the remote"""
    remote_branches = set(get_remote_branches(git_path))
    return [b for b in get_local_branches(git_path) if b not in remote_branches]

def run_parser():
    """run the command-line parser

    Returns the argparse namespace parsed from the command line. Besides the
    scan options there are three optional subcommands ('add', 'remove',
    'list'); the chosen one is stored in ``action`` (None when absent).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--pull', action='store_true', default=False,
            help="run 'git pull' on all repositories that do not have a conflict")

    parser.add_argument('--push', action='store_true', default=False,
            help="run 'git push' on all repositories that do not have a conflict")

    parser.add_argument('--tmux', action='store_true', default=False,
            help="open all repositories that require changes in tmux windows")

    parser.add_argument('--ssh', type=str,
            help="run git-scan on a different machine over ssh")

    parser.add_argument('--repo', type=str, nargs='*',
            help="run git-scan over the given repositories")

    parser.add_argument('--no-fetch', action='store_true', default=False,
            help="do not run any 'git fetch' commands")

    subparsers = parser.add_subparsers(dest="action")

    add_parser = subparsers.add_parser('add',
            help="add a new repository to the list of scanned repositories")
    add_parser.add_argument('repository', type=str, nargs='*',
            help='list of repositories to be added')

    remove_parser = subparsers.add_parser('remove',
            help="remove a repository from the list of scanned repositories")
    remove_parser.add_argument('repository', type=str, nargs='*',
            help='list of repositories to be removed')

    list_parser = subparsers.add_parser('list',
            help="list all scannable repositories")
    list_parser.add_argument('--resolve', action='store_true', default=False,
            help="resolve the absolute paths")

    return parser.parse_args()

def get_config_path():
    """get the path to the configuration file (with '~' expanded)"""
    unexpanded = pathlib.Path('~').joinpath('.config', 'git-scan', 'git-scan.conf')
    return unexpanded.expanduser()

def get_config():
    """load the configuration file with toml and return its contents"""
    return toml.load(get_config_path())

def add_config(repositories):
    """add new repositories to the configuration file

    Paths that are not git repositories are reported and skipped; paths that
    are already listed are not added a second time.

    Arguments:
        repositories    list of paths to add
    """
    config_path = get_config_path()
    config = toml.load(config_path)
    # a config file without the key yet is treated as an empty list
    known = config.setdefault('repositories', [])

    for repo in repositories:
        path = pathlib.Path(repo).resolve().as_posix()
        if not is_git_repository(path):
            print(f"'{path}' is not a git repository")
        elif path not in known:
            known.append(path)

    # the context manager makes sure the file is flushed and closed
    with open(config_path, 'w') as config_file:
        toml.dump(config, config_file)

def remove_config(repositories):
    """remove repositories from the configuration file

    For every given path the first matching config entry is removed. An entry
    matches when it is the same file, or, when that cannot be checked because
    a path does not exist (e.g. the entry is a glob), when the path matches
    the entry as a pattern.

    Arguments:
        repositories    list of paths to remove
    """
    config_path = get_config_path()
    config = toml.load(config_path)
    known = config.get('repositories', [])

    for repo in repositories:
        path = pathlib.Path(repo).resolve()
        for i, other_path in enumerate(known):
            # expand '~' up front so glob entries such as '~/src/*' can match
            other = pathlib.Path(other_path).expanduser()
            try:
                matched = path.samefile(other)
            except OSError:  # in case of globs (or paths that no longer exist)
                matched = path.match(str(other))

            if matched:
                known.pop(i)
                break   # the list changed; stop iterating over it

    # the context manager makes sure the file is flushed and closed
    with open(config_path, 'w') as config_file:
        toml.dump(config, config_file)

def get_paths(repositories=None):
    """get paths to repositories from the config file

    Config entries have '~' expanded and are treated as glob patterns. An
    entry that matches nothing is kept as written. Duplicates are dropped and
    the order of the config file is preserved.

    Arguments:
        repositories    list of repositories to get paths of (default: all)
    """
    config = get_config()
    repos = config['repositories']

    paths = []
    for repo in repos:
        pattern = str(pathlib.Path(repo).expanduser())
        # sorted: glob makes no promise about the order of its results
        globbed = sorted(glob.glob(pattern))
        # a pattern matching exactly one directory must be expanded as well
        paths.extend(globbed if globbed else [pattern])

    # remove duplicates while keeping a stable, reproducible order
    paths = list(dict.fromkeys(paths))

    if repositories:
        wanted = set(repositories)
        paths = [p for p in paths if pathlib.Path(p).name in wanted]

    return paths

def print_config_repos(resolve=False):
    """print repositories in the config file

    Arguments:
        resolve    if True, expand '~' and glob patterns before printing;
                   an entry that matches nothing is printed with only '~'
                   expanded
    """
    config = get_config()
    for repo in config['repositories']:
        if not resolve:
            print(repo)
            continue

        pattern = str(pathlib.Path(repo).expanduser())
        globbed = sorted(glob.glob(pattern))
        # a pattern matching exactly one directory must be expanded as well
        print('\n'.join(globbed) if globbed else pattern)

def fetch_warning_message(fetch_status):
    """build the warning text for a FetchStatus ('' for SUCCESS)"""
    prefix = colored('WARNING: ', color='red')

    if fetch_status == FetchStatus.SUCCESS:
        return ''

    reasons = (
        (FetchStatus.FAIL, 'failed to fetch from remote'),
        (FetchStatus.NOT_GIT, 'not a git repository'),
        (FetchStatus.NO_REMOTE, 'repository does not have a remote'),
    )
    for status, reason in reasons:
        if fetch_status == status:
            return prefix + reason

    # any other value only gets the bare prefix
    return prefix

def corrections_to_make(diff, history, untracked, stashes, dangling_branches):
    """Determine what corrections need to be made

    Arguments:
        diff                result of git diff
        history             status of History
        untracked           untracked files
        stashes             list of stashes
        dangling_branches   list of branch names missing on the remote

    Returns:
        list of strings, one per correction (empty if nothing to correct)
    """
    corrections = []

    if diff:
        corrections.append('diffs')
    # compare against the class member; reaching EQUAL through another
    # member (history.EQUAL) only works by accident of how enums are built
    if history != History.EQUAL:
        corrections.append(str(history))
    if untracked:
        corrections.append('untracked files')
    if stashes:
        corrections.append('stashed changes')
    if dangling_branches:
        names = ', '.join(colored(b, attrs=['underline']) for b in dangling_branches)
        corrections.append('branches dangling: ' + names)

    return corrections

def display(path, i, corrections, messages, total=None):
    """display the status of a repository

    Arguments:
        path    path to repository
        i       index label
        corrections     list of corrections to make
        messages        list of messages to print
        total           number of repositories being scanned (default: the
                        length of the module-level ``paths`` list)
    """
    if total is None:
        # backward compatible: the script entry point stores the scanned
        # repositories in the module-level name ``paths``
        known_paths = globals().get('paths')
        total = '?' if known_paths is None else len(known_paths)

    tab = ' '*4 + '• '
    print(colored(f'({i+1}/{total}) ', color='yellow'), end='')
    print(colored(pathlib.Path(path).name, attrs=['bold'], color='yellow'), end='')
    print(' - ' + path, end='')

    if corrections:
        print(colored(' ✗', color='red'))
        for c in corrections:
            print(tab + c)
    else:
        print(colored(' ✓', color='green'))

    for m in messages:
        print(tab + m)

def attempt_pull(git_path, history, diff):
    """attempt a 'git pull', return True if it was run without errors

    The pull is only attempted when the history is behind the remote and
    there are no local diffs.

    Arguments:
        git_path   path to git repository
        history    status of History
        diff       result of git diff
    """
    if history != History.PULL or diff:
        return False

    # --quiet keeps the informational output of a successful pull off stderr,
    # so that stderr content can be used as the failure signal
    _, err = run_git_command(['pull', '--quiet'], git_path)
    return not err

def attempt_push(git_path, history):
    """attempt a 'git push', return True if it was run without errors

    The push is only attempted when the history is ahead of the remote.

    Arguments:
        git_path   path to git repository
        history    status of History
    """
    if history != History.PUSH:
        return False

    # git push lists the updated refs on stderr even on success; --quiet
    # suppresses that so stderr content can be used as the failure signal
    _, err = run_git_command(['push', '--quiet'], git_path)
    return not err

def get_ssh_id(ssh_input):
    """Get (username, hostname) for given input. Can be of the following forms:
        (1) username@hostname
        (2) alias in ~/.ssh/config

    For form (1) the input is split at the last '@'. For form (2) the
    username is None when the configuration does not name a user for the
    alias (NOTE(review): paramiko's connect is expected to fall back to the
    current local user for a None username -- confirm).
    """
    username, separator, hostname = ssh_input.rpartition('@')
    if separator:
        return username, hostname

    config = paramiko.config.SSHConfig()
    ssh_config_file = pathlib.Path('~/.ssh/config').expanduser()
    # without a config file the lookup below simply has nothing to go on
    if ssh_config_file.exists():
        with open(ssh_config_file) as config_file:
            config.parse(config_file)

    lookup = config.lookup(ssh_input)
    return lookup.get('user'), lookup['hostname']

# Script entry point: handle the config subcommands, otherwise scan every
# configured repository and print its status.
if __name__ == '__main__':
    args = run_parser()

    # The subcommands only deal with the configuration file; each one exits
    # right after it has run, so no repository is scanned.
    if args.action == 'add':
        add_config(args.repository)
        exit()
    elif args.action == 'remove':
        remove_config(args.repository)
        exit()
    elif args.action == 'list':
        print_config_repos(resolve=args.resolve)
        exit()


    # Module-level name: display() reads ``paths`` to print the total count.
    # The paths come from the local configuration file, also with --ssh.
    paths = get_paths(args.repo)

    if args.ssh and args.tmux:
        print('git-scan: combining ssh and tmux options is not currently supported')
        # NOTE(review): the exit status is 0 although the request is rejected
        exit(0)

    # Module-level name: run_git_command() sends commands through ``ssh``
    # when it is not None and runs git locally otherwise.
    ssh = None
    if args.ssh:
        ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy presumably accepts unknown host keys
        # without asking -- confirm this is intended
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        username, hostname = get_ssh_id(args.ssh)
        ssh.connect(hostname=hostname, username=username, timeout=5)

    if args.tmux:
        # Reuse the tmux session named 'git-scan' or create it without
        # attaching to it.
        tmux_server = libtmux.Server()
        name = 'git-scan'
        tmux_session = tmux_server.find_where(dict(session_name=name))
        if tmux_session is None:
            tmux_session = tmux_server.new_session(session_name=name, attach=False)
        # the first repository reuses the session's first window (see below)
        first_window = True

    for i, path in enumerate(paths):
        # fetch first (unless --no-fetch) so the checks below see fresh data
        fetch_status = git_fetch_status(path, not args.no_fetch)

        # the individual checks are run regardless of the fetch status
        diff = git_diff(path)
        history = get_history(path)
        untracked = git_untracked_files(path)
        stashes = git_stash_list(path)
        dangling_branches = get_dangling_branches(path)

        messages = []

        # a successful pull/push brings the history level with the remote
        if args.pull:
            if attempt_pull(path, history, diff):
                history = History.EQUAL
                messages.append(colored('CHANGES PULLED', color='green'))

        if args.push:
            if attempt_push(path, history):
                history = History.EQUAL
                messages.append(colored('CHANGES PUSHED', color='green'))

        if fetch_status != FetchStatus.SUCCESS:
            messages.append(fetch_warning_message(fetch_status))

        # computed after pull/push so that resolved histories are not listed
        corrections = corrections_to_make(diff=diff, history=history, 
                untracked=untracked, stashes=stashes, dangling_branches=dangling_branches)

        # open a tmux window for every repository that needs attention
        if args.tmux and corrections:
            name = pathlib.Path(path).name
            if first_window:
                # reuse the session's first window: rename it and cd into
                # the repository instead of creating a new window
                tmux_server.switch_client('git-scan')
                window = tmux_session.windows[0]
                window.rename_window(name)
                pane = window.attached_pane
                pane.send_keys(f'cd {path}', enter=True)
                first_window = False
            else:
                window = tmux_session.new_window(window_name=name, start_directory=path, attach=False)

            # show the repository state in the window; the nested git-scan
            # run is limited to this repository and does not fetch again
            pane = window.attached_pane
            pane.send_keys('git status', enter=True)
            pane.send_keys(f'git-scan --no-fetch --repo {name}', enter=True)

        display(path, i, corrections=corrections, messages=messages)
