#!/usr/bin/env python
# -*- coding:utf-8 -*-
#
# Author: Pablo Saavedra
# Maintainer: Pablo Saavedra
# Contact: saavedra.pablo at gmail.com

import argparse
import ldap
import logging
import psycopg2
import sys
import ConfigParser

try:
    # reload(sys) is called before sys.setdefaultencoding(); see the link
    # below for the background of this idiom.
    reload(sys)
    sys.setdefaultencoding('utf-8')  # Forcing UTF-8 in the environment:
    # http://stackoverflow.com/questions/3828723/why-we-need-sys-setdefaultencodingutf-8-in-a-py-scrip
except Exception:
    # Best effort: any failure here is deliberately ignored and the script
    # carries on without forcing the encoding.
    pass


## GLOBAL VARS #######################################################

# Default configuration file path; can be overridden with -c/--conffile.
conffile = ".gogs-ldap-groups-syncer.cfg"

# Built-in defaults. setup() overlays whatever the configuration file
# provides on top of these values.
settings = {}
settings["global"] = {
    "loglevel": 10,  # numeric level handed to logger.setLevel()
    "logfile": "/dev/stdout",
    # Explicit default: without it, a configuration file that omits
    # "dryrun" makes every settings["global"]["dryrun"] lookup raise
    # KeyError.
    "dryrun": False,
}
settings["ldap"] = {
    "server": "ldap://ldap.local",
    "base": "ou=People,dc=example,dc=com",
    # "{group_name}" is substituted with the LDAP group being looked up.
    "filter": "(&(objectClass=posixGroup)(cn={group_name}))"
}
settings["postgres"] = {}
settings["ldap-map"] = {}
# Filled by setup() from the "org__<name>" sections of the config file.
settings["org"] = {}

sync_token = None

DEV_NULL = '/dev/null'


# Functions ##########################################################

def setup(conffile, settings):
    cfg = ConfigParser.ConfigParser()
    cfg.read(conffile)

    for s in cfg.sections():
        for k, v in cfg.items(s):
            if s == "global":
                if k == "loglevel":
                    settings[s][k] = int(v)
                elif k == "dryrun":
                    if v.lower() in ("1", "true"):
                        settings[s][k] = True
                    else:
                        settings[s][k] = False
                else:
                    settings[s][k] = v
            else:
                if s.startswith("org__"):
                    org_name = s[5:]
                    if org_name not in settings["org"]:
                        settings["org"][org_name] = {}
                    if k.startswith("team__"):
                        ldap_name = k[6:]
                        settings["org"][org_name][ldap_name] = v
                else:
                    settings[s][k] = v


def debug_conffile(settings, logger):
    """Log every configuration value at DEBUG level, one line per option.

    Each line reads "Configuration setting - <section>.<option>: <value>".
    Options whose name contains "password", "passwd" or "secret" are logged
    with a masked value so that credentials do not end up in the log.

    settings -- dict of section name -> dict of option -> value
    logger   -- object providing a logging-style ``debug(msg, *args)``
    """
    secret_markers = ("password", "passwd", "secret")
    for section, options in settings.items():
        for option, value in options.items():
            lowered = option.lower()
            if any(marker in lowered for marker in secret_markers):
                value = "********"
            logger.debug("Configuration setting - %s.%s: %s",
                         section, option, value)


def get_ldap_group_members(ldap_settings, group_name, logger=None):
    """Return the member names of an LDAP group as a list.

    The group is searched below ``ldap_settings["base"]`` (subtree scope) on
    ``ldap_settings["server"]`` using ``ldap_settings["filter"]`` with
    "{group_name}" replaced by ``group_name``. Only the first entry found is
    used; each value of its ``uniqueMember`` attribute is reduced to the
    value of its first "key=value" component (e.g. "uid=jdoe,ou=..." gives
    "jdoe").

    ldap_settings -- dict with the keys "server", "base" and "filter"
    group_name    -- name substituted into the filter
    logger        -- optional logger; nothing is logged when it is None

    Raises:
        the original LDAP exception (after logging it) when the search fails
        IndexError when the search returns no entry
        KeyError   when the entry has no ``uniqueMember`` attribute

    Failures are raised rather than turned into an empty list on purpose:
    an empty result would make the caller remove every member of the team.
    """
    def first_rdn_value(member_dn):
        return member_dn.split(",")[0].split("=")[1]

    search_filter = ldap_settings["filter"].replace('{group_name}', group_name)
    ldap_server = ldap_settings["server"]
    ldap_base = ldap_settings["base"]

    conn = None
    try:
        conn = ldap.initialize(ldap_server)
        if logger:
            logger.debug("Searching members for %s" % (group_name))
        entries = conn.search_s(ldap_base, ldap.SCOPE_SUBTREE, search_filter)
    except Exception as e:
        if logger:
            logger.error("Error getting group from LDAP: %s" % e)
        raise
    finally:
        # Release the connection whether or not the search succeeded.
        if conn is not None:
            conn.unbind_s()

    if not entries:
        raise IndexError("LDAP group %r not found with filter %s"
                         % (group_name, search_filter))

    return [first_rdn_value(dn) for dn in entries[0][1]['uniqueMember']]


def psql_conn(postgres_settings):
    """Open and return a psycopg2 connection described by ``postgres_settings``.

    The keys "dbname", "user", "host" and "password" are required (KeyError
    if one is missing); "port" is passed along when present.

    The values are handed to psycopg2 as keyword arguments instead of being
    spliced into a DSN string, so passwords containing quotes, spaces or
    backslashes no longer corrupt the connection string.
    """
    required = ("dbname", "user", "host", "password")
    params = dict((key, postgres_settings[key]) for key in required)
    if "port" in postgres_settings:
        params["port"] = postgres_settings["port"]
    return psycopg2.connect(**params)


def psql_execute(cursor, sql_statement, dry_run=False):
    """Log ``sql_statement`` and, unless in dry-run mode, execute it.

    The statement is run on ``cursor`` only when ``dry_run`` is exactly
    ``False``; for any other value it is just logged with a "DRY-RUN: "
    prefix. Logging goes through the module-level ``logger`` at DEBUG level.
    """
    if dry_run is not False:
        logger.debug("DRY-RUN: " + sql_statement)
        return

    logger.debug(sql_statement)
    cursor.execute(sql_statement)


def get_org_id(cursor, org_name):
    """Return the id of the "user" row whose lower_name matches ``org_name``.

    The name is lower-cased before the lookup and bound through
    ``cursor.mogrify`` so that quotes in it cannot alter the statement.

    Raises IndexError when no row matches.
    """
    _sql = """select id from "user" where "user".lower_name = %s"""
    psql_execute(cursor, cursor.mogrify(_sql, (org_name.lower(),)))
    rows = cursor.fetchall()
    if not rows:
        raise IndexError("No organization named %r found" % org_name)
    return rows[0][0]


def get_user_id(cursor, user_name):
    """Return the id of the "user" row whose lower_name matches ``user_name``.

    The name is lower-cased before the lookup and bound through
    ``cursor.mogrify`` so that quotes in it cannot alter the statement.

    Raises IndexError when no row matches.
    """
    _sql = """select id from "user" where "user".lower_name = %s"""
    psql_execute(cursor, cursor.mogrify(_sql, (user_name.lower(),)))
    rows = cursor.fetchall()
    if not rows:
        raise IndexError("No user named %r found" % user_name)
    return rows[0][0]


def count_user_teams_in_org(cursor, org_id, user_id):
    """Return the number of team_user rows with this org_id and uid."""
    query = """select count(*) from team_user where org_id=%s and uid=%s""" % (org_id, user_id)
    psql_execute(cursor, query)
    first_column = [record[0] for record in cursor.fetchall()]
    return first_column[0]


def is_org_user(cursor, org_id, user_id):
    """Return True when org_user has at least one row for this org_id/uid."""
    query = """select count(*) from org_user where org_id=%s and uid=%s""" % (org_id, user_id)
    psql_execute(cursor, query)
    first_column = [record[0] for record in cursor.fetchall()]
    return bool(first_column[0])


def is_repo_team(cursor, repo_id, team_id):
    """Return True when team_repo has at least one row for this repo_id/team_id."""
    query = """select count(*) from team_repo where repo_id=%s and team_id=%s""" % (repo_id, team_id)
    psql_execute(cursor, query)
    first_column = [record[0] for record in cursor.fetchall()]
    return bool(first_column[0])


def get_team_id(cursor, org_name, team_name):
    """Return the id of the team ``team_name`` owned by ``org_name``.

    Both names are lower-cased and matched against the lower_name columns
    of "team" and of the owning "user" row. They are bound through
    ``cursor.mogrify`` so that quotes in them cannot alter the statement.

    Raises IndexError when no row matches.
    """
    _sql = ("""select "team".id from team inner join "user" """
            """on "team".org_id = "user".id """
            """where "user".lower_name = %s and "team".lower_name = %s""")
    statement = cursor.mogrify(_sql, (org_name.lower(), team_name.lower()))
    psql_execute(cursor, statement)
    rows = cursor.fetchall()
    if not rows:
        raise IndexError("No team named %r found in organization %r"
                         % (team_name, org_name))
    return rows[0][0]


def get_table_attribute(cursor, table_name, row_id, attribute_name):
    """Return one column of the first row matching ``row_id``.

    table_name     -- table to read, inserted verbatim into the statement
    row_id         -- either a single value matched against the ``id``
                      column, or a dict of column name -> value whose
                      conditions are combined with "and"
    attribute_name -- column to return, inserted verbatim into the statement

    The values being compared are bound through ``cursor.mogrify``; table
    and column names cannot be bound and must come from trusted code.

    Raises IndexError when no row matches.
    """
    if isinstance(row_id, dict):
        criteria = list(row_id.items())
    else:
        criteria = [("id", row_id)]

    where_clause = " and ".join("%s = %%s" % column for column, _ in criteria)
    _sql = "select %s from %s where %s" % (attribute_name, table_name,
                                           where_clause)
    values = tuple(value for _, value in criteria)
    psql_execute(cursor, cursor.mogrify(_sql, values))
    rows = cursor.fetchall()
    if not rows:
        raise IndexError("No row in %s matches %r" % (table_name, row_id))
    return rows[0][0]


def get_organizations(cursor):
    """Return the lower_name of every "user" row whose type is 1, as a list."""
    psql_execute(cursor, """select "user".lower_name from "user" where type = 1""")
    return [record[0] for record in cursor.fetchall()]


def get_repo_ids(cursor, org_id):
    """Return the repo ids linked to the 'owners' team of the organization.

    The ids are read from team_repo rows whose org_id is ``org_id`` and
    whose team has the lower_name 'owners'.
    """
    query = """select team_repo.repo_id from team_repo inner join team on team_repo.team_id =
team.id where team_repo.org_id = '%s' and team.lower_name = 'owners'""" % org_id
    psql_execute(cursor, query)
    return [record[0] for record in cursor.fetchall()]


def get_organization_team_names(cursor, organization_name):
    """Return the lower_name of every team of the organization, as a list.

    ``organization_name`` is lower-cased and bound through
    ``cursor.mogrify`` so that quotes in it cannot alter the statement.
    The list is empty when the organization has no team or does not exist.
    """
    _sql = ("""select "team".lower_name from team inner join "user" """
            """on "team".org_id = "user".id where "user".lower_name = %s""")
    psql_execute(cursor, cursor.mogrify(_sql, (organization_name.lower(),)))
    return [row[0] for row in cursor.fetchall()]


def get_team_members(cursor, org_name, team_name):
    """Return the lower_name of every member of a team, as a list.

    The team is identified by the lower-cased ``org_name`` and
    ``team_name``; both are bound through ``cursor.mogrify`` so that quotes
    in them cannot alter the statement. The list is empty when nothing
    matches.
    """
    _sql = ("""select member.lower_name from team_user """
            """join "user" member on member.id = team_user.uid """
            """inner join "user" org on team_user.org_id = org.id """
            """inner join team on team_user.team_id = team.id """
            """where org.lower_name = %s and team.lower_name = %s""")
    statement = cursor.mogrify(_sql, (org_name.lower(), team_name.lower()))
    psql_execute(cursor, statement)
    return [row[0] for row in cursor.fetchall()]


def update_table_attribute(cursor, table_name, row_id, attribute_name, attribute_value):
    """Set one column on the row(s) selected by ``row_id``.

    table_name      -- table to update
    row_id          -- either a single id matched against the ``id`` column,
                       or a dict of column name -> value whose conditions
                       are combined with "and"
    attribute_name  -- column to set
    attribute_value -- new value; always emitted as a quoted literal so that
                       both forms of ``row_id`` produce the same kind of
                       statement

    The statement goes through psql_execute and honours the global
    "dryrun" setting. All arguments are interpolated into the SQL text
    without escaping, so they must come from trusted code, not user input.
    """
    if isinstance(row_id, dict):
        where_clause = " and ".join("%s = '%s'" % pair for pair in row_id.items())
    else:
        where_clause = "id = %s" % row_id

    _sql = """update %s set %s = '%s' where %s"""
    psql_execute(cursor,
                 _sql % (table_name, attribute_name, attribute_value, where_clause),
                 dry_run=settings["global"]["dryrun"])


def remove_org_member(cursor, org_id, user_id):
    """Delete the org_user row of a user and decrement the org's num_members.

    When the delete affects no row, a warning is logged and the member
    counter is left untouched. Honours the global "dryrun" setting.
    """
    dry_run = settings["global"]["dryrun"]
    _sql = """delete from org_user where org_id = %s and uid = %s"""
    psql_execute(cursor, _sql % (org_id, user_id), dry_run=dry_run)

    # psql_execute only runs the statement when dry_run is False; otherwise
    # cursor.rowcount still describes an earlier query and says nothing
    # about this delete.
    if dry_run is False and cursor.rowcount < 1:
        logger.warning("%s is already not a member of the %s organization" % (user_id, org_id))
        return

    num_members = int(get_table_attribute(cursor, '"user"', org_id, "num_members"))
    update_table_attribute(cursor, '"user"', org_id, "num_members", num_members - 1)


def remove_team_member(cursor, org_id, team_id, user_id):
    """Remove a user from a team and keep the related counters in sync.

    Steps:
      1. delete the team_user row; if no row was affected, log a warning
         and stop;
      2. decrement the team's num_members;
      3. if the user has no team left in the organization, remove the user
         from the organization as well; otherwise decrement num_teams on
         the user's org_user row.

    Honours the global "dryrun" setting.
    """
    dry_run = settings["global"]["dryrun"]
    _sql = """delete from team_user where team_id = %s and uid = %s"""
    psql_execute(cursor, _sql % (team_id, user_id), dry_run=dry_run)

    # psql_execute only runs the statement when dry_run is False; otherwise
    # cursor.rowcount still describes an earlier query and says nothing
    # about this delete.
    if dry_run is False and cursor.rowcount < 1:
        logger.warning("%s is already not a member of the %s team" % (user_id, team_id))
        return

    num_members = int(get_table_attribute(cursor, "team", team_id, "num_members"))
    update_table_attribute(cursor, "team", team_id, "num_members", num_members - 1)

    remaining_teams = int(count_user_teams_in_org(cursor, org_id, user_id))
    if remaining_teams < 1:
        remove_org_member(cursor, org_id, user_id)
        return

    ids = {"uid": user_id, "org_id": org_id}
    num_teams = int(get_table_attribute(cursor, 'org_user', ids, "num_teams"))
    update_table_attribute(cursor, "org_user", ids, "num_teams", num_teams - 1)


def add_team_repo(cursor, org_id, team_id, repo_id):
    """Link a repository to a team and increment the team's num_repos.

    Inserts a team_repo row for (org_id, team_id, repo_id). Honours the
    global "dryrun" setting.
    """
    # The last placeholder names the organization, so it takes org_id.
    logger.info("Added repo (%s) in team (%s) of the org (%s)" % (repo_id, team_id, org_id))
    _sql = """insert into team_repo (org_id, team_id, repo_id) values (%s, %s, %s)"""
    psql_execute(cursor, _sql % (org_id, team_id, repo_id), dry_run=settings["global"]["dryrun"])
    num_repos = int(get_table_attribute(cursor, 'team', team_id, "num_repos"))
    update_table_attribute(cursor, 'team', team_id, "num_repos", num_repos + 1)


def add_org_repos_in_team(cursor, org_id, team_id):
    """Link to the team every repo returned by get_repo_ids() that is not linked yet."""
    for repo_id in get_repo_ids(cursor, org_id):
        if is_repo_team(cursor, repo_id, team_id):
            continue
        add_team_repo(cursor, org_id, team_id, repo_id)


def add_org_member(cursor, org_id, user_id):
    """Add a user to an organization unless an org_user row already exists.

    A new org_user row starts with num_teams 0 and is neither public nor
    owner; the organization's num_members is then incremented. Honours the
    global "dryrun" setting.
    """
    already_member = is_org_user(cursor, org_id, user_id)
    if not already_member:
        insert_sql = """insert into org_user (org_id, uid, num_teams, is_public, is_owner) values (%s, %s, 0, false, false)"""
        psql_execute(cursor, insert_sql % (org_id, user_id),
                     dry_run=settings["global"]["dryrun"])
        member_count = int(get_table_attribute(cursor, '"user"', org_id, "num_members"))
        update_table_attribute(cursor, '"user"', org_id, "num_members", member_count + 1)
        return

    logger.info("%s is already a member of the %s org" % (user_id, org_id))


def add_team_member(cursor, org_id, team_id, user_id):
    """Add a user to a team and keep the related counters in sync.

    Steps:
      1. insert the team_user row; if no row was affected, log a warning
         and stop;
      2. increment the team's num_members;
      3. make sure the user is a member of the organization;
      4. increment num_teams on the user's org_user row.

    Honours the global "dryrun" setting.
    """
    dry_run = settings["global"]["dryrun"]
    _sql = """insert into team_user (org_id, team_id, uid) values (%s, %s, %s)"""
    psql_execute(cursor, _sql % (org_id, team_id, user_id), dry_run=dry_run)

    # psql_execute only runs the statement when dry_run is False; otherwise
    # cursor.rowcount still describes an earlier query and says nothing
    # about this insert.
    if dry_run is False and cursor.rowcount < 1:
        logger.warning("%s is already a member of the %s team" % (user_id, team_id))
        return

    num_members = int(get_table_attribute(cursor, "team", team_id, "num_members"))
    update_table_attribute(cursor, "team", team_id, "num_members", num_members + 1)

    add_org_member(cursor, org_id, user_id)

    ids = {"uid": user_id, "org_id": org_id}
    try:
        num_teams = int(get_table_attribute(cursor, 'org_user', ids, "num_teams"))
    except Exception:
        # Deliberate fallback: the org_user row may not be readable, e.g. in
        # dry-run mode where add_org_member() did not actually insert it.
        num_teams = 0
    update_table_attribute(cursor, "org_user", ids, "num_teams", num_teams + 1)


## command line options parser #######################################
# NOTE: everything in this section runs at module level, i.e. also when the
# file is imported rather than executed as a script.
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--conffile", dest="conffile", default=conffile,
                    help="Conffile (default: %s)" % conffile)
args = parser.parse_args()
# The command-line value replaces the module-level default path.
conffile = args.conffile

# setting up #########################################################
# Overlay the configuration file onto the built-in defaults in `settings`.
setup(conffile, settings)

## logging ###########################################################
# A single file handler writing to the configured "logfile"; `logger` is the
# module-level logger that the database helper functions use.
hdlr = logging.FileHandler(settings["global"]["logfile"])
hdlr.setFormatter(logging.Formatter('%(levelname)s %(asctime)s %(message)s'))
logger = logging.getLogger('ldap2gogs')
logger.addHandler(hdlr)
logger.setLevel(settings["global"]["loglevel"])
logger.debug("Default encoding: %s" % sys.getdefaultencoding())
# Dump the effective configuration at DEBUG level.
debug_conffile(settings, logger)

## main ##############################################################
# A missing "dryrun" option means a normal (writing) run.
dry_run = settings["global"].get("dryrun", False)


def _sync_team_members(cursor, org_id, org_name, team_id, team_name, ldap_group):
    """Make the membership of one team match the members of one LDAP group."""
    ldap_members = get_ldap_group_members(settings["ldap"],
                                          group_name=ldap_group,
                                          logger=logger)
    # Team members are read from the lower_name column and get_user_id()
    # lower-cases the names it looks up, so LDAP names are compared in lower
    # case too. The set also drops duplicate LDAP entries, which would
    # otherwise be inserted twice.
    wanted = set(name.lower() for name in ldap_members)
    current = get_team_members(cursor, org_name, team_name)

    for member_name in current:
        if member_name in wanted:
            logger.info("Skipped account %s" % (member_name))
            continue
        logger.info("Remove account %s from the team %s" % (member_name, team_name))
        user_id = get_user_id(cursor, member_name)
        remove_team_member(cursor, org_id, team_id, user_id)

    # adding the pending accounts which were not members before
    for member_name in sorted(wanted.difference(current)):
        logger.info("Add %s account in the team %s" % (member_name, team_name))
        try:
            user_id = get_user_id(cursor, member_name)
        except Exception:
            logger.warning("User %s is not logged in gogs yet" % (member_name))
            continue
        try:
            add_team_member(cursor, org_id, team_id, user_id)
        except Exception as e:
            logger.error("Error adding %s account in the team %s: %s" % (member_name, team_name, e))


def main():
    """Synchronize every team of every organization, then commit.

    For each team, the organization's repositories are linked to the team
    first; the membership is then synchronized when the team has an entry
    in the "ldap-map" settings. Changes are committed only when not in
    dry-run mode. The cursor and the connection are closed even when the
    run fails.
    """
    conn = psql_conn(settings["postgres"])
    try:
        cursor = conn.cursor()
        try:
            for org_name in get_organizations(cursor):
                org_id = get_org_id(cursor, org_name)
                for team_name in get_organization_team_names(cursor, org_name):
                    team_id = get_team_id(cursor, org_name, team_name)
                    logger.info("Processing team " + org_name + " :: " + team_name)

                    add_org_repos_in_team(cursor, org_id, team_id)

                    if team_name in settings["ldap-map"]:
                        _sync_team_members(cursor, org_id, org_name,
                                           team_id, team_name,
                                           settings["ldap-map"][team_name])

            if dry_run is False:
                logger.info("Make the changes in the database persistent")
                # Make the changes to the database persistent
                conn.commit()
        finally:
            # Close communication with the database
            cursor.close()
    finally:
        conn.close()


if __name__ == '__main__':
    main()
