#! /usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import os
import traceback
import shutil
import re
import argparse
import time
import yaml
import ConfigParser
import jinja2

from os import path
from subprocess import check_call, check_output, CalledProcessError

VERSION = '0.2.0'
BASEDIR = path.expanduser('~/.augploy')
SECRET_VARS = ['mysql.root']


class AugployError(Exception):
    def __init__(self, type_, msg):
        self.msg = 'augploy %s error: %s' % (type_, msg)

    def __str__(self):
        return self.msg


class RelaxedUndefined(jinja2.Undefined):
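    ''' Render attribute access on undefined variables as an empty string
    instead of raising, so templates can be evaluated against a vars dict
    that only contains fake secret values (see check_secret_vars_access). '''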
    def __getattr__(self, attribute):
        return ''


def log(msg):
    print(msg)


def makedirs(dir_path):
    if not path.isdir(dir_path):
        os.makedirs(dir_path)


# shutil's copytree recipe references WindowsError, which only exists on
# Windows; define it as None elsewhere so the isinstance check below is safe.
try:
    WindowsError
except NameError:
    WindowsError = None


def copytree(src, dst, symlinks=False, ignore=None):
    names = os.listdir(src)
    if ignore is not None:
        ignored_names = ignore(src, names)
    else:
        ignored_names = set()

    if not os.path.isdir(dst):
        os.makedirs(dst)

    errors = []
    for name in names:
        if name in ignored_names:
            continue
        src_path = os.path.join(src, name)
        dst_name = os.path.join(dst, name)
        try:
            if symlinks and os.path.islink(src_path):
                linkto = os.readlink(src_path)
                os.symlink(linkto, dst_name)
            elif os.path.isdir(src_path):
                copytree(src_path, dst_name, symlinks, ignore)
            else:
                # Will raise a SpecialFileError for unsupported file types.
                shutil.copy2(src_path, dst_name)
        # Catch the Error from the recursive copytree so that we can continue with other files.
        except shutil.Error as err:
            errors.extend(err.args[0])
        except EnvironmentError as why:
            errors.append((src_path, dst_name, str(why)))
    try:
        shutil.copystat(src, dst)
    except OSError as why:
        if WindowsError is not None and isinstance(why, WindowsError):
            # Copying file access times may fail on Windows.
            pass
        else:
            errors.append((src, dst, str(why)))
    if errors:
        raise shutil.Error(errors)


def load_yml_file(yml_path):
    with open(yml_path) as yml_file:
        return yaml.load(yml_file)


def dump_yml_file(obj, yml_path):
    with open(yml_path, 'wb') as yml_file:
        yml_file.write('---\n')
        yml_file.write(yaml.dump(obj, default_flow_style=False))


def deep_in_place_update(d, s):
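    ''' Recursively update dict d with dict s in place and return d; nested
    dicts are merged rather than replaced, e.g.
    deep_in_place_update({'a': {'x': 1}}, {'a': {'y': 2}})
    == {'a': {'x': 1, 'y': 2}}. '''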
    for k, v in s.iteritems():
        if type(v) is dict:
            d[k] = deep_in_place_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d


def deep_in_place_merge(obj1, obj2):
    # obj1 takes precedence over obj2: merge obj1 into obj2 first, then merge
    # the combined result back into obj1 so both end up identical.
    deep_in_place_update(obj2, obj1)
    deep_in_place_update(obj1, obj2)


def is_scp_style(repo_url):
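    ''' Roughly detect scp-style urls like git@host:path; in practice anything
    containing a ':' matches, which is enough to tell remote urls apart from
    local directory paths. '''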
    return re.match(r'.*:.*', repo_url)


def parse_repo_url(repo_url):
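    ''' Split a repo url into (url, name, revision), e.g.
    'git@git.augmentum.com.cn:ops/augploy.git#dev' ->
    ('git@git.augmentum.com.cn:ops/augploy.git', 'ops-augploy', 'dev');
    the revision defaults to 'master'. '''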
    tmp = repo_url.split('#')
    repo_url = tmp[0]
    repo_revision = 'master'
    if len(tmp) > 1:
        repo_revision = '#'.join(tmp[1:])
    repo_name = repo_url.split(':')[1].replace('.git', '').replace('/', '-')
    return (repo_url, repo_name, repo_revision)


def prepare_repo(repo_url, tmp_path):
    log('prepare repo %s ...' % repo_url)

    # If repo_url doesn't match scp style, treat it as a local file path and
    # copy it to the tmp path. If it does, fetch it from the remote and export
    # it to the tmp path at the specified revision.
    if not is_scp_style(repo_url):
        repo_name = path.basename(repo_url.rstrip('/'))
        tmp_repo_path = path.join(tmp_path, repo_name)
        copytree(repo_url, tmp_repo_path, symlinks=True,
                 ignore=lambda src, names: names if path.basename(src) == '.git' else [])
        return tmp_repo_path

    repos_path = path.join(BASEDIR, 'repos')
    makedirs(repos_path)

    (repo_url, repo_name, repo_revision) = parse_repo_url(repo_url)
    repo_path = path.join(repos_path, repo_name)
    tmp_repo_path = path.join(tmp_path, repo_name)

    if not path.isdir(repo_path):
        check_call('git clone --mirror %s %s' % (repo_url, repo_path), shell=True)

    os.chdir(repo_path)
    log('    update ...')
    check_output('git remote update', shell=True)

    makedirs(tmp_repo_path)
    log('    checkout %s ...' % repo_revision)
    check_output('git archive %s | (cd %s && tar xf -)' % (repo_revision, tmp_repo_path), shell=True)
    return tmp_repo_path


def check_bin_version(ap_repo_path):
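    ''' Compare the VERSION constant in the exported augploy repo's
    bin/augploy against this script's VERSION and fail if they differ. '''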
    with open(path.join(ap_repo_path, 'bin', 'augploy')) as augploy_file:
        augploy_content = augploy_file.read()
    version_matches = re.findall(r"^VERSION\s*=\s*'([\.\d]+)'$", augploy_content, re.MULTILINE)
    version = version_matches[0]

    if version != VERSION:
        raise AugployError('bin', 'this utility is outdated, please run `pip install -U augploy` to upgrade')


def parse_include(config, config_dir):
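    ''' Recursively expand 'include' keys: each listed path (relative to
    config_dir) is loaded, expanded itself, and deep-merged into the
    containing dict. Returns the merged dict when anything was included,
    otherwise None; lists are expanded in place. '''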
    if type(config) is dict:
        include_configs = []
        for k, v in config.iteritems():
            if k == 'include':
                if type(v) is not list:
                    raise AugployError('config', 'include path(%s) should be a list' % v)

                for include_path in v:
                    include_path = path.join(config_dir, include_path)
                    if not path.isfile(include_path):
                        raise AugployError('config', 'include path(%s) not found' % include_path)

                    include_config = load_yml_file(include_path)
                    parse_include(include_config, config_dir)
                    include_configs.append(include_config)
            else:
                if type(v) is dict or type(v) is list:
                    parse_include(v, config_dir)

        if len(include_configs) == 0:
            return

        for include_config in include_configs:
            deep_in_place_merge(config, include_config)

        config.pop('include', None)
        return config
    elif type(config) is list:
        include_configs = {}
        for i, v in enumerate(config):
            include_config = parse_include(v, config_dir)
            if include_config is not None:
                include_configs[i] = include_config

        if len(include_configs) == 0:
            return

        # Splice in place, in reverse index order so earlier indexes stay
        # valid; rebinding the local name `config` would silently discard
        # the expansion.
        for include_index in sorted(include_configs, reverse=True):
            config[include_index:include_index + 1] = list(include_configs[include_index])


def check_secret_vars_access(string, variable_start_string, variable_end_string):
    ''' Check whether the string contains jinja2 expressions accessing secret vars. '''
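    # Probe technique: build a fake nested dict for each secret var whose leaf
    # is a unique timestamp value, render the string against it (all other
    # vars render as empty via RelaxedUndefined), and if the timestamp shows
    # up in the output the template referenced the secret var.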
    for var in SECRET_VARS:
        var_keys = var.split('.')
        fake_value = str(int(time.time()))
        fake_var_dict = fake_value
        for i in range(len(var_keys) - 1, -1, -1):
            fake_var_dict = {
                var_keys[i]: fake_var_dict
            }

        env = jinja2.Environment(
            undefined=RelaxedUndefined,
            variable_start_string=variable_start_string, variable_end_string=variable_end_string)
        tpl = env.from_string(string.decode('utf-8'))
        rendered = tpl.render(fake_var_dict)
        if rendered.find(fake_value) != -1:
            raise AugployError('config', 'accessing of secret var(%s) is denied' % var)


def parse_vars(config, repo_path=None):
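    ''' Render the config's own jinja2 expressions (using the << ... >>
    delimiters) against its global_vars, injecting repo path info first,
    and return the re-parsed config. '''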
    global_vars = config.get('global_vars', {})
    if 'repo' in config:
        repo = config['repo']
        repo['path_local'] = repo_path
        repo['path'] = '/srv/%s' % repo['name']
        global_vars['repo'] = repo

    # Change the default '{{' / '}}' delimiters to avoid conflicts with yaml
    # syntax, so items can be referenced without quotes and can therefore be
    # rendered as dicts/lists.
    env = jinja2.Environment(
        block_start_string='<%', block_end_string='%>',
        variable_start_string='<<', variable_end_string='>>',
        comment_start_string='<#', comment_end_string='#>')
    config_str = yaml.dump(config)
    check_secret_vars_access(config_str, '<<', '>>')
    tpl = env.from_string(config_str)
    rendered = tpl.render(global_vars)
    config = yaml.load(rendered)
    return config


def parse_config(config_dir, config_path, repo_path=None):
    config = load_yml_file(config_path)
    parse_include(config, config_dir)
    config = parse_vars(config, repo_path)
    return config


def sort_deploys(deploys):
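    ''' Order deploys so that repo-independent ones run first, followed by a
    synthesized 'repo' deploy covering every host that needs the repo, and
    finally the repo-dependent deploys. '''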
    deploy_types = set()
    deploys_dependon_repo = []
    deploys_not_dependon_repo = []
    hosts_to_deploy_repo = set()
    for deploy in deploys:
        deploy_types.add(deploy['type'])
        if deploy.get('deploy_repo') is True:
            deploys_dependon_repo.append(deploy)

            hosts = deploy['hosts']
            hosts_yaml_str = yaml.dump(hosts)
            # Yeah, this is dirty...
            host_names = re.findall(r"name:\s*(.+)\n", hosts_yaml_str, re.MULTILINE)
            host_names = [host_name.rstrip('}') for host_name in host_names]
            hosts_to_deploy_repo = hosts_to_deploy_repo.union(set(host_names))
        else:
            deploys_not_dependon_repo.append(deploy)

    if len(deploy_types) < len(deploys):
        raise AugployError('config', 'multiple deploys of the same type in a single config file are not supported')

    deploy_of_repo = [{
        'type': 'repo',
        'hosts': [{'name': host} for host in hosts_to_deploy_repo]
    }]
    sorted_deploys = deploys_not_dependon_repo + deploy_of_repo + deploys_dependon_repo
    return sorted_deploys


def gen_common_play():
    return {
        'name': 'setup common environment',
        'hosts': 'all',
        'roles': ['common']
    }


def gen_repo_play(config):
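    ''' Build the ansible play that sets up the repo on its hosts: engine
    roles, rsync of the exported repo, ownership, rendered templates and
    build steps. '''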
    repo = config['repo']
    repo_path_local = repo['path_local']
    repo_path = repo['path']
    tasks = []
    play = {
        'name': 'setup repo',
        'hosts': 'repo',
        'roles': [],
        'tasks': tasks
    }

    for engine in repo['engines']:
        engine_type = engine['type']
        role = {
            'role': engine_type
        }
        if 'vars' in engine:
            # setdefault so engine vars land in the config, not a detached dict.
            global_vars = config.setdefault('global_vars', {})
            deep_in_place_update(global_vars, engine['vars'])
        play['roles'].append(role)

    tasks.append({
        'name': 'repo | rsync',
        'rsync_repo': 'src=%s/ dest=%s' % (repo_path_local, repo_path),
        'register': 'rsync_repo_result'
    })
    tasks.append({
        'name': 'repo | set permission',
        'file': 'path=%s owner=www-data group=www-data recurse=yes' % repo_path
    })

    templates = repo.get('templates', [])
    for template in templates:
        src = template['src']
        dest = template['dest']

        with open(path.join(repo_path_local, src)) as template_file:
            template_str = template_file.read()
        check_secret_vars_access(template_str, '{{', '}}')

        tasks.append({
            'name': 'repo | templates | from %s to %s' % (src, dest),
            'template': 'src=%s/%s dest=%s/%s' % (repo_path_local, src, repo_path, dest)
        })

    build_steps = repo.get('build_steps', [])
    for step in build_steps:
        tasks.append({
            'name': 'repo | build_steps | %s' % step['name'],
            'shell': 'chdir=%s %s' % (repo_path, step['shell'])
        })

    return play


def parse_inventory_and_vars(group, inventory, group_vars, host_vars, group_chain=None):
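    ''' Walk the group tree, writing a [group] section of hosts or a
    [group:children] section of subgroups into the ini-style inventory, and
    collecting group_vars/host_vars along the way; each host's '__groups'
    entry records the chain of groups it was reached through. '''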
    group_name = group['group']
    hosts = group['hosts']
    group_chain = [] if group_chain is None else group_chain
    group_chain.append(group_name)

    if 'vars' in group:
        # Multiple deploys of the same type in a single config file are not
        # supported, so the group name is a unique key here.
        group_vars[group_name] = group['vars']

    host_names = []
    group_names = []
    for item in hosts:
        if 'name' in item:
            host = item
            host_name = host['name']
            host_names.append(host_name)
            hv = host_vars.get(host_name, {})
            hv_group_chain = hv.get('__groups', [])
            # Used to obtain group chain of this host.
            hv['__groups'] = hv_group_chain + group_chain
            if 'vars' in host:
                host_vars[host_name] = deep_in_place_update(hv, host['vars'])
        else:
            sub_group = item
            sub_group_name = sub_group['group']
            group_names.append(sub_group_name)
            parse_inventory_and_vars(sub_group, inventory, group_vars, host_vars, group_chain)

    if len(host_names) > 0 and len(group_names) > 0:
        raise AugployError('config', 'host lists and group lists are not compatible')

    if len(host_names) > 0:
        inventory.add_section(group_name)
        for host in host_names:
            inventory.set(group_name, host)
    else:
        group_name = '%s:children' % group_name
        inventory.add_section(group_name)
        for group in group_names:
            inventory.set(group_name, group)


def run_playbook(args, ansible_args, tmp_path):
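    ''' Export the repos, merge and render the config, generate the
    inventory, host_vars and playbook files inside the exported augploy repo,
    then invoke ansible-playbook against them. '''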
    repo_path = prepare_repo(args.repo, tmp_path) if not args.no_repo else None
    ap_repo_path = prepare_repo(args.ap_repo, tmp_path)
    check_bin_version(ap_repo_path)

    # Parse config file.
    config_dir = path.join(ap_repo_path, 'merged_configs')
    if repo_path is not None:
        repo_augploy_path = path.join(repo_path, 'augploy')
        if path.isdir(repo_augploy_path):
            copytree(repo_augploy_path, config_dir, symlinks=True)
    copytree(path.join(ap_repo_path, 'configs'), config_dir, symlinks=True)
    # splitext, not rstrip('.yml'): rstrip strips characters, not a suffix.
    config_name = path.splitext(path.basename(args.config_file))[0]
    config_path = path.join(config_dir, args.config_file)
    if not path.isfile(config_path):
        raise AugployError('config', 'config file(%s) not found' % args.config_file)
    config = parse_config(config_dir, config_path, repo_path)

    # Parse deploys to generate inventory, host_vars and playbook.
    deploys = config['deploys']
    sorted_deploys = deploys
    if 'repo' in config:
        sorted_deploys = sort_deploys(deploys)

    inventory = ConfigParser.RawConfigParser(allow_no_value=True)
    group_vars = {}
    host_vars = {}
    plays = [gen_common_play()]
    for deploy in sorted_deploys:
        deploy_type = deploy['type']
        group = {
            'group': deploy_type,
            'hosts': deploy['hosts']
        }
        if 'vars' in deploy:
            group['vars'] = deploy['vars']
        parse_inventory_and_vars(group, inventory, group_vars, host_vars)

        if deploy['type'] == 'repo':
            plays.append(gen_repo_play(config))
        else:
            play_path = path.join(ap_repo_path, 'plays', '%s.yml' % deploy_type)
            play = load_yml_file(play_path)
            if type(play) is dict:
                plays.append(play)
            else:
                plays = plays + play

    inventory_path = path.join(ap_repo_path, '%s_hosts' % config_name)
    with open(inventory_path, 'wb') as inventory_file:
        inventory.write(inventory_file)

    # Recursively merge global vars, group vars and host vars into each host's
    # host_vars file, in that order of precedence. ansible-playbook does not
    # merge vars recursively, so duplicate keys across those levels would
    # otherwise clobber each other... damn it, this is complicated.
    for host, vars_ in host_vars.iteritems():
        merged_vars = deep_in_place_update({}, config.get('global_vars', {}))
        groups = vars_.pop('__groups')
        for group in groups:
            tmp_group_vars = group_vars.get(group, {})
            if not tmp_group_vars:
                continue
            deep_in_place_update(merged_vars, tmp_group_vars)
        deep_in_place_update(merged_vars, vars_)

        if merged_vars:
            host_vars_path = path.join(ap_repo_path, 'host_vars', '%s.yml' % host)
            dump_yml_file(merged_vars, host_vars_path)

    playbook_path = path.join(ap_repo_path, '%s.yml' % config_name)
    dump_yml_file(plays, playbook_path)

    # Actually run playbook.
    os.chdir(ap_repo_path)
    ansible_cmd = 'ansible-playbook -i %s %s %s' % (inventory_path, playbook_path, ' '.join(ansible_args))
    log('run: %s' % ansible_cmd)
    check_call(ansible_cmd, shell=True)


def main(origin_args):
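    ''' Parse command line args, run the playbook inside a per-run tmp dir,
    and exit with the playbook's return code (or 1 on augploy errors). '''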
    # Args parsing.
    parser = argparse.ArgumentParser(
        description='augploy - AUGmentum dePLOYment automation tool, powered by ansible',
        epilog='''
                repo url format:
                1. scp style, e.g. git@git.augmentum.com.cn:ops/augploy.git#master; the '#master' part is
                optional and specifies the git revision (branch name or commit id), defaulting to master
                2. local absolute directory path, e.g. /home/user/workspace/ops/augploy;
                git revisions are not supported for this type.
                ''',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args_repo_group = parser.add_mutually_exclusive_group()
    args_repo_group.add_argument(
        '-r', '--repo', default=os.getcwd(),
        help='url of the repo to deploy; defaults to the current working directory for quick tests')
    args_repo_group.add_argument(
        '-n', '--no_repo', action='store_true',
        help='omit the default value of -r (--repo), for when there is no repo to deploy')
    parser.add_argument(
        '-R', '--ap_repo', default='git@git.augmentum.com.cn:ops/augploy.git',
        help='url of the augploy repo in which ansible playbooks are stored; \
              this option is for dev purposes, otherwise keep the default value')
    parser.add_argument(
        'config_file',
        help='config file path, relative to the augploy directory in the repo to deploy, \
              or to the root directory of the augploy repo')
    parser.add_argument('-V', action='version', version='%%(prog)s %s' % VERSION)

    args, ansible_args = parser.parse_known_args(origin_args)

    # To avoid conflicts when running multiple config files at the same time,
    # prepare a per-run tmp dir for exporting repos.
    tmp_path = path.join(BASEDIR, str(time.time()))
    makedirs(tmp_path)
    exit_code = 0

    try:
        run_playbook(args, ansible_args, tmp_path)
    except CalledProcessError as err:
        exit_code = err.returncode
    except AugployError as err:
        exit_code = 1
        log(err)
    except Exception:
        exit_code = 1
        traceback.print_exc()
    finally:
        shutil.rmtree(tmp_path)
        sys.exit(exit_code)

if __name__ == '__main__':
    main(sys.argv[1:])
