#! /usr/bin/env python3
import argparse
from glob import glob
import json
import os
import subprocess


# Connection settings for the ClickHouse client, loaded once at import time from
# the "clickhouse-auth.json" file located next to this script. The rest of the
# file reads the keys "host", "port", "username", "password", "database" and
# "http_scheme" from it.
# JSON documents are UTF-8 encoded, so decode explicitly instead of relying on
# the platform's locale encoding (which would garble non-ASCII credentials).
with open(
    os.path.join(os.path.dirname(__file__), "clickhouse-auth.json"),
    encoding="utf-8",
) as f:
    CLICKHOUSE_AUTH = json.load(f)


def main():
    """Parse the command line and dispatch to the selected sub-command.

    Three sub-commands are registered:

    * ``client``: run a ClickHouse client (``command_client``).
    * ``createuser``: create a user, optionally restricted to some courses
      and/or organizations (``command_create_user``).
    * ``migrate``: apply the migration files found in a directory
      (``command_migrate``).

    Each sub-parser stores its handler in the ``func`` default; when no
    sub-command is given there is no ``func`` attribute and the help text is
    printed instead.
    """
    # The first positional parameter of ArgumentParser is `prog`, not the
    # description: passing the summary positionally would replace the program
    # name in the usage line. Pass it explicitly as the description.
    parser = argparse.ArgumentParser(description="Manage your Clickhouse instance")
    subparsers = parser.add_subparsers()

    # Run a clickhouse client
    parser_client = subparsers.add_parser("client")
    parser_client.set_defaults(func=command_client)

    # Create user
    parser_createuser = subparsers.add_parser("createuser")
    parser_createuser.add_argument(
        "-c",
        "--course-id",
        action="append",
        help="Restrict user to access data only from these courses.",
    )
    parser_createuser.add_argument(
        "-o",
        "--org-id",
        action="append",
        help="Restrict user to access data only from these organizations.",
    )
    parser_createuser.add_argument("username")
    parser_createuser.set_defaults(func=command_create_user)

    # Apply migrations
    parser_migrate = subparsers.add_parser("migrate")
    parser_migrate.add_argument(
        "-p",
        "--path",
        default="/etc/clickhouse-server/migrations.d/",
        help="Run migrations from this directory.",
    )
    parser_migrate.add_argument(
        "-d",
        "--dry-run",
        action="store_true",
        help="Don't actually apply migrations",
    )
    parser_migrate.add_argument(
        "-f",
        "--fake",
        action="store_true",
        help="Mark the migrations as applied, but don't actually run them.",
    )
    parser_migrate.set_defaults(func=command_migrate)

    args = parser.parse_args()
    if hasattr(args, "func"):
        args.func(args)
    else:
        # No sub-command was selected.
        parser.print_help()


def command_client(args):
    """Handler for the ``client`` sub-command.

    Starts the ClickHouse client with the configured connection options and
    database, passing no extra arguments. ``args`` (the parsed command line)
    is accepted so that all handlers share one signature; it is not used.

    Raises ``subprocess.CalledProcessError`` if the client exits with a
    non-zero status.
    """
    client_argv = get_client_command()
    subprocess.check_call(client_argv)


def command_create_user(args):
    """Handler for the ``createuser`` sub-command.

    Creates the user ``args.username`` (if it does not exist yet), then, for
    every table whose name does not start with "_", grants SELECT on it and
    (re)creates a restrictive row policy named after the user.

    The row policy condition is the OR of:

    * ``course_id = '<id>'`` for each value in ``args.course_id``;
    * ``course_id LIKE 'course-v1:<org>+%'`` for each value in ``args.org_id``.

    When neither option is given the condition is ``1``, i.e. no row
    restriction. Each generated GRANT/POLICY statement is printed before it
    is executed.
    """

    def sql_string(value):
        # Backslash and single quote must be escaped inside a single-quoted
        # ClickHouse string literal; otherwise a value containing either one
        # would break (or alter) the generated statement.
        escaped = value.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    conditions = [
        f"course_id = {sql_string(course_id)}" for course_id in args.course_id or []
    ]
    # NOTE(review): "%" and "_" inside an org id still act as LIKE wildcards,
    # exactly as before; confirm whether that is intended.
    conditions += [
        f"course_id LIKE {sql_string('course-v1:' + org_id + '+%')}"
        for org_id in args.org_id or []
    ]
    condition = " OR ".join(conditions) if conditions else "1"
    username = args.username
    # Note that the "CREATE TEMPORARY TABLE" grant is required to make use of "numbers()" functions.
    run_query(
        f"""CREATE USER IF NOT EXISTS {username};
GRANT CREATE TEMPORARY TABLE ON *.* TO {username};"""
    )
    # Find the list of tables to which the user should have access: all tables that do not start with "_".
    # Empty lines are dropped: when the database has no table at all, the
    # output is empty and must not be treated as a table with an empty name.
    tables = [
        table
        for table in run_query("SHOW TABLES").splitlines()
        if table and not table.startswith("_")
    ]
    for table in tables:
        query = f"""GRANT SELECT ON {table} TO {username};
CREATE ROW POLICY OR REPLACE {username} ON {table} AS RESTRICTIVE FOR SELECT USING {condition} TO {username};"""
        print(query)
        run_query(query)


def command_migrate(args):
    """Handler for the ``migrate`` sub-command.

    Ensures the configured database and the ``_migrations`` bookkeeping table
    exist (this happens even with ``--dry-run``), then walks the entries of
    ``args.path`` in sorted order. An entry whose base name is already
    recorded in ``_migrations`` is skipped; otherwise:

    * by default, the file is executed and its name is recorded;
    * with ``args.fake``, the name is recorded but the file is not executed;
    * with ``args.dry_run``, the file is neither executed nor recorded.
    """
    # Create database
    query = f"""CREATE DATABASE IF NOT EXISTS {CLICKHOUSE_AUTH["database"]}"""
    subprocess.check_call(get_client_command_no_db("--query", query))
    # Create migrations table
    run_query(
        "CREATE TABLE IF NOT EXISTS _migrations (name String) ENGINE = MergeTree PRIMARY KEY(name) ORDER BY name"
    )

    # Apply migrations
    migrations = sorted(glob(os.path.join(args.path, "*")))
    for path in migrations:
        migration_name = os.path.basename(path)
        # Escape backslashes and single quotes so that any file name is a
        # valid single-quoted SQL string literal.
        escaped_name = migration_name.replace("\\", "\\\\").replace("'", "\\'")
        # Flush so that the progress line is visible while the migration runs
        # (it does not end with a newline).
        print(f"Applying migration {migration_name}... ", end=" ", flush=True)
        # LIMIT 1: a MergeTree table does not deduplicate rows, so a name
        # recorded more than once would otherwise yield several output lines,
        # fail the equality check below and cause the migration to be re-run.
        query = (
            f"SELECT 'applied' FROM _migrations WHERE name='{escaped_name}' LIMIT 1"
        )
        is_applied = run_query(query)
        if is_applied == "applied":
            print("SKIP")
            continue
        if not args.dry_run and not args.fake:
            run_command("--queries-file", path)
        if not args.dry_run:
            run_query(f"INSERT INTO _migrations (name) VALUES ('{escaped_name}')")
        print_suffix = " (dry run)" if args.dry_run else ""
        print_suffix += " (fake)" if args.fake else ""
        print(f"OK{print_suffix}")


def run_query(query):
    """Execute ``query`` via the client's ``--query`` option.

    Returns whatever ``run_command`` returns for that invocation.
    """
    output = run_command("--query", query)
    return output


def run_command(*args):
    """Run the ClickHouse client with extra arguments and capture its output.

    ``args`` are appended to the base client command line. The captured
    standard output is decoded and returned with surrounding whitespace
    removed. ``subprocess.CalledProcessError`` is raised when the client
    exits with a non-zero status.
    """
    completed = subprocess.run(
        get_client_command(*args),
        stdout=subprocess.PIPE,
        check=True,
    )
    text = completed.stdout.decode()
    return text.strip()


def get_client_command(*args):
    """Build the client command line bound to the configured database.

    The ``--database`` option (taken from ``CLICKHOUSE_AUTH``) is placed
    before any caller-supplied ``args``.
    """
    database_option = ("--database", CLICKHOUSE_AUTH["database"])
    return get_client_command_no_db(*database_option, *args)


def get_client_command_no_db(*args):
    """Build the base ``clickhouse client`` command line as a list.

    Connection options (host, port, user, password) come from
    ``CLICKHOUSE_AUTH``; ``--secure`` is added when its "http_scheme" is
    "https". No database is selected here. Caller-supplied ``args`` are
    appended last, in order.
    """
    auth = CLICKHOUSE_AUTH
    command = ["clickhouse", "client"]
    command.extend(["--host", auth["host"]])
    # The port is converted to text because the command is an argv list.
    command.extend(["--port", str(auth["port"])])
    command.extend(["--user", auth["username"]])
    command.extend(["--password", auth["password"]])
    command.extend(
        [
            "--multiline",
            "--multiquery",
            "--history_file=/tmp/.clickhouse-client-history",
        ]
    )
    use_tls = auth["http_scheme"] == "https"
    if use_tls:
        command.append("--secure")
    command.extend(args)
    return command


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
