#!/usr/bin/env python
# Copyright 2021 University of Chicago
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
from pathlib import Path
import shlex
import subprocess
import sys
import time

from jupyter_core.paths import jupyter_runtime_dir

# Command-line interface of the launcher. ``--id`` names the per-kernel files in
# the Jupyter runtime directory; ``--kernel`` and ``--launcher`` decide what is
# started.
parser = argparse.ArgumentParser("hydra")
parser.add_argument("--id", required=True, help="The ID assigned to the kernel")
parser.add_argument(
    "--kernel", required=True, help="The kernel implementation to start"
)
# Parse the timeout as a number. Without ``type`` a value given on the command
# line stays a string, and comparing it against elapsed seconds raises TypeError.
parser.add_argument(
    "--timeout", type=float, default=10, help="A timeout on kernel startup"
)
parser.add_argument(
    "--launcher", default="hydra-subkernel", help="The kernel launch binary"
)
# ``--log-file`` / ``--no-log-file`` toggle the same destination; logging to a
# file is on unless explicitly disabled (see set_defaults below).
parser.add_argument("--log-file", dest="log_to_file", action="store_true")
parser.add_argument("--no-log-file", dest="log_to_file", action="store_false")
parser.add_argument("--debug", dest="debug", action="store_true", default=False)
parser.set_defaults(log_to_file=True)
args = parser.parse_args(sys.argv[1:])

runtime_dir = Path(jupyter_runtime_dir())
# ``parents=True`` so a not-yet-created parent directory does not abort the
# launch with FileNotFoundError.
runtime_dir.mkdir(parents=True, exist_ok=True)

# Per-kernel bookkeeping files, all keyed by the kernel ID.
pid_file = runtime_dir.joinpath(f"kernel-{args.id}.pid")
connection_file = runtime_dir.joinpath(f"kernel-{args.id}.json")
log_file = runtime_dir.joinpath(f"kernel-{args.id}.log")


def cleanup_kernel_files() -> None:
    """Remove this kernel's pid file and connection file if they are present.

    A file that is already gone is not an error. The removal is attempted
    directly instead of checking for existence first, so a file deleted by
    another process in between cannot crash the launcher.
    """
    for path in (pid_file, connection_file):
        try:
            path.unlink()
        except FileNotFoundError:
            # Already removed (or never written); nothing to clean up.
            pass


def write_subkernel_info(pid: int) -> None:
    """Record the kernel's pid and print its pid and connection info as JSON.

    Writes ``pid`` to the pid file, loads the connection file, and prints one
    JSON object of the form ``{"pid": <int>, "connection": {...}}`` to stdout.

    Args:
        pid: Process ID of the kernel. Anything ``int()`` accepts is allowed
            (e.g. the text read back from the pid file); it is normalized so
            the emitted ``"pid"`` is always a JSON number.

    Raises:
        ValueError: If ``pid`` cannot be converted to an integer.
        OSError: If the pid file cannot be written or the connection file
            cannot be read.
        json.JSONDecodeError: If the connection file does not contain valid JSON.
    """
    # Normalize first: the re-use path hands over pid-file text, which would
    # otherwise be emitted as a string (including any trailing whitespace).
    pid = int(pid)
    # UTF-8 explicitly, matching how the pid file is read back elsewhere.
    with pid_file.open("w", encoding="utf-8") as f:
        f.write(str(pid))
    subkernel = {"pid": pid}
    with connection_file.open("r", encoding="utf-8") as f:
        subkernel["connection"] = json.load(f)
    print(json.dumps(subkernel))


# Check if existing process / connection_file exists and attempt to re-use
if pid_file.exists() and connection_file.exists():
    existing_pid = pid_file.read_text(encoding="utf-8")
    try:
        # Parse once so the numeric pid, not the raw file text, is reported.
        live_pid = int(existing_pid)
        # Signal 0 only performs the existence/permission check; nothing is
        # actually delivered to the process.
        os.kill(live_pid, 0)
    except (ValueError, OverflowError, OSError):
        # Process is not running (or cannot be signalled), or pidfile malformed
        cleanup_kernel_files()
    else:
        write_subkernel_info(live_pid)
        sys.exit()
else:
    cleanup_kernel_files()

# Build the argv list directly instead of formatting one string and splitting
# it: paths or kernel names containing spaces, quotes or backslashes would be
# mangled by shlex. Only the launcher itself is split, so a multi-word launcher
# command keeps working.
cmd = shlex.split(args.launcher) + [
    f"--kernel={args.kernel}",
    f"--KernelManager.connection_file={connection_file}",
]
if args.log_to_file:
    cmd.append(f"--log-file={log_file}")
if args.debug:
    cmd.append("--debug")
# start_new_session detaches the kernel from this launcher's session.
kernel = subprocess.Popen(cmd, stderr=subprocess.STDOUT, start_new_session=True)

# Coerce here as well so the comparison below is numeric even if the parser
# delivered the timeout as a string.
timeout = float(args.timeout)
start = time.perf_counter()
# The kernel is considered started once its connection file appears.
while not connection_file.exists():
    if time.perf_counter() - start > timeout:
        ret = kernel.poll()
        if ret is not None:
            # The launcher process has already exited; report its status.
            msg = f"Kernel failed with status {ret}."
        else:
            msg = "Failed to find connection file, did the kernel fail to start?"
        if args.log_to_file:
            msg += f" See {log_file} for details."
        raise TimeoutError(msg)
    time.sleep(0.1)

write_subkernel_info(kernel.pid)
