#!/usr/bin/env python
#
# Copyright 2018 Rick Chang <chchang915@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys, os, netrc, json, subprocess, argparse, shlex
import xml.etree.ElementTree as etree
import ssl

try:
    from urllib.parse import urlparse, urlencode, quote
    from urllib.error import HTTPError
    import urllib.request as request
except ImportError:
    from urlparse import urlparse
    from urllib import urlencode, quote
    import urllib2 as request

# ------ Customization ------
# Command used to look up a project's checkout path (`repo list -p[f] <project>`).
REPO = 'repo'
# Timeout in seconds handed to urlopen() for every Gerrit REST request.
CONNECT_TIMEOUT = 30
# Space-separated fetch protocols, tried in this order when picking a remote URL.
FETCH_PROTOCOL_ORDER = 'http ssh git'
# Credentials file consulted when -u/-p are not given.
DEFAULT_NETRC_PATH = '~/.netrc'
# 'TBD' is a placeholder: set your Gerrit base URL here or pass -g.
DEFAULT_GERRIT_SERVER = 'TBD'
# Width of the horizontal rules printed around command output.
_RULE_WIDTH = 80
DELIMITER = '-' * _RULE_WIDTH
DELIMITER_END = '=' * _RULE_WIDTH
# Command run against FETCH_HEAD in --dryrun mode instead of cherry-picking.
PREVIEW = 'git log --no-decorate -1'
# ---------------------------

# Working directory at startup; the script returns here after each project.
CUR_PATH    = os.getcwd()
# ANSI color escapes used by Loge (red) and Logm (green); NONE resets the color.
RED         = "\033[31m"
GREEN       = "\033[32m"
NONE        = "\033[0m"

def Log(s):
    """Print *s* on its own line and flush stdout right away.

    Flushing keeps this script's messages ordered relative to the output of the
    git/repo child processes it spawns, which share the same stdout.
    """
    print(s)
    out = sys.stdout
    out.flush()

def Loge(s):
    """Log *s* wrapped in red ANSI color (used for errors and warnings)."""
    Log("".join((RED, s, NONE)))

def Logm(s):
    """Log *s* wrapped in green ANSI color (used for per-project headers)."""
    Log("".join((GREEN, s, NONE)))

def RunShell(cmd):
    """Run *cmd* through the system shell and return its exit status.

    The string is interpreted by the shell, so it must come from a trusted
    source (here: the user's own -x/--exec argument).
    """
    status = subprocess.call(cmd, shell=True)
    return status

def Run(cmd):
    """Run *cmd* (split with shell-like quoting, no shell) and return its exit code.

    The child inherits this process's stdout/stderr. If the command cannot be
    split or started, an error is logged and 1 is returned.
    """
    try:
        child = subprocess.Popen(shlex.split(cmd), bufsize=0)
        child.communicate()
    except Exception as err:
        Loge("Run cmd '%s' fail. (%s)" % (cmd, err))
        return 1
    return child.returncode

def RunLog(cmd, silent=False):
    """Run *cmd* and return its combined stdout/stderr as text with trailing whitespace removed.

    Returns None if the command cannot be run, exits non-zero, or its output is
    not valid UTF-8. The failure is logged unless *silent* is true.
    """
    try:
        raw = subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT)
        text = raw.decode('utf-8')
    except Exception as err:
        if not silent:
            Loge("Run '%s' fail. (%s)" % (cmd, err))
        return None
    return text.rstrip()

class GerritRest:
    """Minimal client for the Gerrit REST API.

    https://github.com/GerritCodeReview/gerrit/blob/master/Documentation/rest-api.txt
    """

    def __init__(self, url, username, password):
        """Register *username*/*password* for *url* and install a global urllib opener.

        The opener answers both HTTP digest and basic auth challenges and is
        installed process-wide with ``request.install_opener``. A trailing '/'
        on *url* is dropped before it is stored in ``self.server``.
        """
        url = url.rstrip('/')
        password_mgr = request.HTTPPasswordMgrWithDefaultRealm()
        password_mgr.add_password(None, url, username, password)
        digest_auth = request.HTTPDigestAuthHandler(password_mgr)
        basic_auth = request.HTTPBasicAuthHandler(password_mgr)
        opener = request.build_opener(digest_auth, basic_auth)
        request.install_opener(opener)
        self.server = url

    def query(self, q, verbose=False):
        """GET ``<server>/a<q>`` and return the decoded JSON body.

        *q* is the endpoint path plus query string, e.g. '/changes/?q=123'.
        The '/a' prefix selects Gerrit's authenticated endpoints. Network and
        HTTP errors from urlopen propagate to the caller.
        """
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json;charset=UTF-8"
        }

        # SECURITY NOTE: this disables HTTPS certificate verification for the
        # whole process (presumably to support self-signed Gerrit servers).
        # Older Pythons lack the helper, in which case nothing changes.
        unverified = getattr(ssl, '_create_unverified_context', None)
        if unverified is not None:
            ssl._create_default_https_context = unverified
        url = "%s/a%s" % (self.server, q)
        if verbose:
            Log("Query: %s" % url)
        req = request.Request(url, None, headers)
        response = request.urlopen(req, None, CONNECT_TIMEOUT)
        try:
            # Gerrit prefixes JSON responses with a ")]}'" line; skip it.
            response.readline()
            data = json.loads(response.read().decode("utf-8"))
        finally:
            # Release the connection even if reading or JSON decoding fails.
            response.close()
        if verbose:
            Log("Response: %s" % data)
        return data

def GetMachine(url):
    """Return the netloc of *url* with any ':<port>' suffix removed.

    This is used as the netrc machine name. Returns '' when *url* has no
    netloc (e.g. a bare word with no scheme).
    """
    parsed = urlparse(url)
    port = parsed.port
    if not port:
        return parsed.netloc
    return parsed.netloc.replace(":%s" % port, "")

def GetLoginInfoNetrc(url, path):
    """Look up ``(login, password)`` for the host of *url* in the netrc file *path*.

    *path* may start with '~'. Returns ``(None, None)``, after logging why, in
    any of these cases:
    - the file is missing, unreadable, or malformed;
    - *url* has no host part;
    - the host has no matching entry.
    """
    path = os.path.expanduser(path)
    if not os.path.isfile(path):
        Log("Can't find netrc path '%s'" % (path))
        return None, None

    try:
        handle = netrc.netrc(path)
    except (netrc.NetrcParseError, IOError) as e:
        # A broken netrc should not abort the run; the caller reports the
        # missing credentials instead.
        Log("Can't read netrc path '%s' (%s)" % (path, e))
        return None, None
    machine = GetMachine(url)
    if not machine:
        Log("Can't find machine name in '%s' for netrc" % (url))
        return None, None
    info = handle.authenticators(machine)
    if not info:
        Log("Can't find machine '%s' in '%s'" % (machine, path))
        return None, None
    # authenticators() yields (login, account, password); account is unused.
    login, _account, password = info
    return (login, password)

def ParseArguments(argv):
    """Parse the command line *argv* (argv[0] is the program name).

    Any credential not given with -u/-p is taken from the netrc file.
    Returns the options namespace. Returns None (after printing help or an
    error) when no change number or query was given, or when the user name or
    password is still missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('change_num', nargs='*',
                        help = 'ex. \'12345\', \'12345/1\'')
    parser.add_argument('-u', '--user', type=str,
                        help='gerrit user id')
    parser.add_argument('-p', '--password', type=str,
                        help='gerrit HTTP password')
    # https://review.openstack.org/Documentation/user-search.html
    parser.add_argument('-q', '--query', type=str,
                        help='query command ex. \'branch:master status:merged after:"2018-11-17 22:06:00"\'')
    parser.add_argument('--query-only', action='store_true', default=False,
                        help='do not install patch')
    parser.add_argument('-r', '--preview', type=str, default=PREVIEW,
                        help='preview command for changes ex. \'git log --oneline -1\' (default: %s)' % PREVIEW)
    parser.add_argument('-g', '--gerrit', type=str, default=DEFAULT_GERRIT_SERVER,
                        help='gerrit server url ex. \'https://gerrit.mycompany.com\' (default: %s)' % DEFAULT_GERRIT_SERVER)
    parser.add_argument('-d', '--dryrun', action='store_true', default=False,
                        help='show what would be done')
    parser.add_argument('-n', '--netrc-file', type=str, default=DEFAULT_NETRC_PATH,
                        help='netrc path (default: %s)' % DEFAULT_NETRC_PATH)
    parser.add_argument('-m', '--manifest', type=str, metavar='NAME.xml',
                        help='assign manifest file to resolve patch install path instead of using repo command')
    parser.add_argument('-i', '--install-path', type=str, default='',
                        help='assign patch install path instead of resolving path by repo command or manifest')
    parser.add_argument('-F', '--full-path', action='store_true', default=False,
                        help='display the full install path instead of the relative install path')
    parser.add_argument('-N', '--name-path', action='store_true', default=False,
                        help='display the project name instead of the relative install path')
    parser.add_argument('-x', '--exec', dest='exe', type=str,
                        help='append command after all changes installed in each project')
    parser.add_argument('-v', '--verbose', action='store_true', default=False,
                        help='show more logs')
    parser.add_argument('--version', action='version', version='1.0.2')
    opts = parser.parse_args(argv[1:])

    if not opts.change_num and not opts.query:
        parser.print_help()
        return None

    # Fill in whichever credential is missing from netrc. The netrc password is
    # used only when no user was given or the netrc login matches the given
    # user, so a command-line user is never paired with another account's
    # password.
    if not opts.user or not opts.password:
        (login, password) = GetLoginInfoNetrc(opts.gerrit, opts.netrc_file)
        if not opts.password and (not opts.user or opts.user == login):
            opts.password = password
        if not opts.user:
            opts.user = login
    # Both halves are required; a lone user or password cannot authenticate.
    if not opts.user or not opts.password:
        Log("Please assign user name and password")
        return None
    return opts

def ChangeIdToRefId(change_num, patchset_id):
    """Build the Gerrit ref for a change/patchset.

    Returns ``refs/changes/<NN>/<change>/<patchset>``, where NN is the last two
    digits of the change number, zero-padded.
    """
    bucket = int(change_num) % 100
    return "refs/changes/%02d/%s/%s" % (bucket, change_num, patchset_id)

def GetRemote(fetch_info):
    """Pick a fetch URL from a revision's ``fetch`` mapping.

    Protocols are tried in FETCH_PROTOCOL_ORDER. Returns the ``'url'`` of the
    first one present in *fetch_info*, or '' if none is.
    """
    for protocol in FETCH_PROTOCOL_ORDER.split(' '):
        if protocol in fetch_info:
            return fetch_info[protocol]['url']
    return ''

def ResponseToPatch(response, change_num, patchset_id):
    """Turn a Gerrit change-query *response* into a patch descriptor.

    Only the first change in *response* is used. Its current revision supplies
    the fetch remote, and its patchset number is used when *patchset_id* is
    empty.

    Returns a dict with keys 'project', 'remote', 'patch' ("<change>/<ps>") and
    'ref'. Returns None (after logging) if no usable fetch remote exists.
    """
    change = response[0]
    project = change['project']
    revision = change['revisions'][change['current_revision']]
    if not patchset_id:
        patchset_id = revision['_number']

    remote = GetRemote(revision['fetch'])
    if not remote:
        Loge("Can't find remote project for %s" % change_num)
        Log(revision['fetch'])
        return None

    return {
        'project': project,
        'remote': remote,
        'patch': "%s/%s" % (change_num, patchset_id),
        'ref': ChangeIdToRefId(change_num, patchset_id),
    }

def GetChangeIdInfo(change_num):
    """Split a user-supplied "CHANGE" or "CHANGE/PATCHSET" string.

    Returns ``(change, patchset)`` as strings, with patchset '' when omitted.
    Returns ``(None, None)`` when either part is not all digits or there is
    more than one '/'.
    """
    number, _sep, patchset = change_num.partition('/')
    # Any extra '/' lands in `patchset`, which then fails isdigit() below.
    if not number.isdigit():
        return None, None
    if patchset and not patchset.isdigit():
        return None, None
    return number, patchset

def QueryChanges(opts, rest):
    """Run the Gerrit search ``opts.query`` and queue the matching changes.

    Matching change numbers are appended to ``opts.change_num`` in reverse of
    the server's result order. Gerrit presumably lists the most recently
    updated changes first, so older changes would be installed first. A
    summary table is printed.

    Returns 0 on success, including when nothing matched, and 1 if the query
    itself failed.
    """
    query = '/changes/?q=%s' % quote(opts.query)
    try:
        response = rest.query(query, opts.verbose)
    except Exception as e:
        Loge("query fail: %s" % e)
        return 1
    if opts.verbose:
        Log(DELIMITER)
    if not response:
        # Not an error in itself; tell the user instead of returning silently.
        Log("No change matches the query")
        return 0
    opts.change_num.extend(str(res['_number']) for res in reversed(response))
    Log(DELIMITER)
    for res in response:
        Log("%8s - %s" % (res['_number'], res['subject']))
    Log("(Total: %d changes)" % len(response))
    Log(DELIMITER_END)
    return 0

def GetPatchList(opts, rest):
    """Resolve each entry of ``opts.change_num`` into patch dicts grouped by project.

    Returns a dict mapping Gerrit project name to a list of patch dicts (as
    built by ResponseToPatch plus 'input_num'), in input order. Returns None,
    after logging, on a malformed change number, a failed query, or an unknown
    change.

    An input resolving to a (project, ref) already queued is skipped. This
    happens, for example, when a change is given explicitly and also matched
    by --query; cherry-picking the same commit twice would fail.
    """
    patch_list = {}
    seen = set()
    for input_num in opts.change_num:
        change_num, patchset_id = GetChangeIdInfo(input_num)
        if not change_num:
            Loge("error: Unknown change number format '%s'" % input_num)
            return None

        query = '/changes/?q=%s&o=%s' % (change_num, "CURRENT_REVISION")
        try:
            response = rest.query(query, opts.verbose)
        except Exception as e:
            Loge("query fail: %s" % e)
            return None
        if opts.verbose:
            Log(DELIMITER)
        if not response:
            Loge ("Can't find change number '%s'" % change_num)
            return None

        patch = ResponseToPatch(response, change_num, patchset_id)
        if not patch:
            return None

        key = (patch['project'], patch['ref'])
        if key in seen:
            Log("Skip duplicate change '%s' (%s)" % (input_num, patch['ref']))
            continue
        seen.add(key)

        patch['input_num'] = input_num
        patch_list.setdefault(patch['project'], []).append(patch)
    return patch_list

def InstallPatch(opts, remote, ref_id):
    """Fetch *ref_id* from *remote* into the current repository and apply it.

    In dry-run mode the fetched commit is only shown with ``opts.preview``;
    otherwise it is cherry-picked. Returns 0 on success and non-zero on
    failure (a failed fetch returns 1).
    """
    if Run("git fetch %s %s" % (remote, ref_id)):
        return 1

    if opts.dryrun:
        command = "%s %s" % (opts.preview, "FETCH_HEAD")
    else:
        command = "git cherry-pick FETCH_HEAD"
    return Run(command)

def GetInputNums(patches, start=0):
    """Join the 'input_num' of each patch from index *start* to the end with spaces."""
    picked = [patches[idx]['input_num'] for idx in range(start, len(patches))]
    return " ".join(picked).rstrip()

def CheckInstallPath(full_path, project, patches, opts):
    """Return True if *full_path* exists.

    Otherwise print a warning that explains why the path is unusable, list the
    manual git commands for *patches*, and return False.
    """
    if full_path and os.path.exists(full_path):
        return True

    if full_path:
        reason = "The install path '%s' is not found." % full_path
    else:
        if opts.manifest:
            method = "manifest '%s'" % opts.manifest
        else:
            method = "command '%s list'" % REPO
        reason = "The install path can't be resolved by %s" % method

    template = '''\
WARNING: %s
    (use '-i INSTALL_PATH' for '%s' to assign the right path)
    (use the following commands manually in the right path)'''
    Loge("[Project: %s]" % project)
    Log(template % (reason, GetInputNums(patches)))
    Log(DELIMITER)
    manual_steps = ["git fetch %s %s && git cherry-pick FETCH_HEAD" % (p['remote'], p['ref'])
                    for p in patches]
    for step in manual_steps:
        Log(step)
    Log(DELIMITER_END)
    return False

def _ResolveInstallPath(opts, project, root):
    """Return ``(path, full_path)`` for *project*; either may be None when unresolved.

    Tries, in order:
    1. the explicit -i/--install-path;
    2. the parsed manifest *root* (when -m/--manifest was given);
    3. ``repo list``.
    """
    if opts.install_path:
        return opts.install_path, os.path.abspath(opts.install_path)
    if opts.manifest:
        ele = root.find("./project[@name='%s']" % project)
        if ele is None:
            return None, None
        # In repo manifests an omitted 'path' attribute defaults to the project name.
        path = ele.get('path') or project
        return path, os.path.abspath(path)
    path = RunLog("%s list -p %s" % (REPO, project), silent=True)
    full_path = RunLog("%s list -pf %s" % (REPO, project), silent=True)
    return path, full_path


def InstallPatches(opts, project, patches, root):
    """Fetch and apply *patches* (all for Gerrit *project*) inside its checkout.

    *root* is the parsed manifest root, used only when -m/--manifest was
    given. After all patches apply, ``opts.exe`` (if set) is run in the
    checkout. The working directory is always restored to CUR_PATH.

    Returns 0 on success. Returns 1 if the install path could not be resolved
    or a fetch/pick failed; in the latter case the remaining patches are
    skipped.
    """
    path, full_path = _ResolveInstallPath(opts, project, root)
    if not CheckInstallPath(full_path, project, patches, opts):
        return 1

    if opts.full_path:
        display_path = full_path
    elif opts.name_path:
        display_path = project
    else:
        # Defensive: the two `repo list` calls are independent, so `path` can
        # be None even though `full_path` resolved; avoid None.rstrip().
        display_path = path or full_path
    Logm("[%s]" % display_path.rstrip('/'))

    os.chdir(full_path)
    fail = 0
    try:
        for i, patch in enumerate(patches):
            remote = patch['remote']
            ref_id = patch['ref']

            if opts.dryrun:
                Log("(dryrun) Pick: %s %s" % (remote, ref_id))
            else:
                Log("Pick: %s %s" % (remote, ref_id))
            Log(DELIMITER)
            ret = InstallPatch(opts, remote, ref_id)
            Log(DELIMITER_END)
            if ret:
                Loge("error: pick fail (unfinished patches: %s)" % GetInputNums(patches, i))
                fail = 1
                break
        if not fail and opts.exe:
            exe_ret = RunShell(opts.exe)
            if exe_ret:
                Loge("error: exec '%s' fail (exit status %d)" % (opts.exe, exe_ret))
    finally:
        # Relative install paths are resolved with os.path.abspath against the
        # cwd, so every project must start from the original directory.
        os.chdir(CUR_PATH)
    return fail

def main(argv):
    """Parse *argv*, optionally run a query, then fetch and install the patches.

    Returns the process exit status: 0 on success, 1 if anything failed.
    """
    opts = ParseArguments(argv)
    if not opts:
        return 1

    rest = GerritRest(opts.gerrit, opts.user, opts.password)
    if opts.query:
        Log("Querying change numbers from '%s' ..." % opts.gerrit)
        Log("Search for '%s'" % opts.query)
        # QueryChanges appends the matching change numbers to opts.change_num.
        if QueryChanges(opts, rest):
            return 1

    if opts.query_only:
        return 0

    root = None
    if opts.manifest:
        try:
            root = etree.parse(opts.manifest).getroot()
        except Exception as e:
            Loge("Manifest '%s' parse error: %s" % (opts.manifest, e))
            return 1

    Log("Getting patches from '%s' ..." % opts.gerrit)
    patch_list = GetPatchList(opts, rest)
    if patch_list is None:
        # GetPatchList has already logged the reason.
        return 1
    if not patch_list:
        # E.g. the query matched nothing; keep the failing exit status but say why.
        Loge("error: no change to install")
        return 1

    Log("Installing patches ...")
    ret = 0
    for project, patches in patch_list.items():
        if InstallPatches(opts, project, patches, root):
            ret = 1
    return ret

if __name__ == '__main__':
    # Exit with main()'s return value as the process status.
    raise SystemExit(main(sys.argv))
