#!/usr/bin/python3
import argparse
import asyncio
import os
import sys
import logging
import tempfile
from rich.logging import RichHandler


def parse_args():
	"""Build an argparse parser from CLI_ARGS and parse the process arguments.

	Each CLI_ARGS entry maps an argument name to add_argument() keyword
	arguments. Entries carrying a truthy "positional" key become positional
	arguments (the marker is stripped before reaching argparse); every other
	entry becomes a "--<name>" option.

	:returns: the argparse.Namespace produced by parse_args().
	"""
	parser = argparse.ArgumentParser()
	for name, options in CLI_ARGS.items():
		if options.get("positional"):
			# "positional" is our own marker, not an argparse keyword -- drop it.
			argparse_opts = {key: value for key, value in options.items() if key != "positional"}
			parser.add_argument(name, **argparse_opts)
		else:
			parser.add_argument(f"--{name}", **options)
	return parser.parse_args()

# Detect whether this script is running from a git checkout: a ".git" directory
# one level above the script, with the "funtoo_ramdisk" package alongside it.
# If so, use the checkout's package and support files instead of installed ones.
from_git = False

f_path = os.path.dirname(os.path.realpath(__file__))
proj_path = os.path.normpath(os.path.join(f_path, "../.git"))
if os.path.isdir(proj_path):
	mod_path = os.path.normpath(os.path.join(proj_path, "../funtoo_ramdisk"))
	if os.path.exists(mod_path):
		from_git = True
		# sys.path must hold the directory *containing* the funtoo_ramdisk package
		# (the repository root) for "import funtoo_ramdisk" to resolve to the
		# checkout; inserting the package directory itself does not achieve that.
		sys.path.insert(0, os.path.dirname(mod_path))
		support_root = os.path.normpath(os.path.join(f_path, "../support"))

if not from_git:
	# Import the installed package only to locate its support files, which sit
	# next to the package directory (two levels up from its __init__.py).
	import funtoo_ramdisk
	support_root = os.path.normpath(os.path.join(funtoo_ramdisk.__file__, "../../support"))

from funtoo_ramdisk.initramfs import InitialRamDisk
try:
	from funtoo_ramdisk.version import __version__
except ImportError:
	# NOTE(review): version.py is presumably generated at install/release time,
	# so a bare checkout may not have it -- fall back to a placeholder.
	__version__ = "git"

# Command-line interface consumed by parse_args(). Keys are argument names;
# values are argparse add_argument() keyword arguments, plus our own
# "positional" marker for arguments that are not "--" options.
CLI_ARGS = {
	"kernel": {"default": None, "action": "store", "nargs": "?", "help": "Kernel version to build initramfs for (default: derived from the /usr/src/linux symlink)."},
	"destination": {"default": None, "action": "store", "positional": True, "help": "The output initramfs filename to create."},
	"debug": {"default": False, "action": "store_true", "help": "Enable debug logging."},
	"compression": {"default": "xz", "action": "store", "help": "Compression method. One of: xz, zstd"},
	"backtrace": {"default": False, "action": "store_true", "help": "Show full backtrace of any exception."},
	"force": {"default": False, "action": "store_true", "help": "Force overwrite of initramfs destination if it exists."},
	"version": {"action": "version", "version": f"funtoo-ramdisk {__version__}"},
}

args = parse_args()

# Send this tool's log records (logger "ramdisk") through rich for terminal output.
log = logging.getLogger("ramdisk")
log.setLevel(logging.DEBUG if args.debug else logging.INFO)
handler = RichHandler(show_path=False, show_time=False)
log.addHandler(handler)

# Announce non-default run modes up front so they are visible in the output.
if args.debug:
	log.warning("DEBUG enabled")
if from_git:
	repo_root = os.path.dirname(proj_path)
	log.warning(f"Running from git repository {repo_root}")


def _read_kernel_version(makefile_path):
	"""Return the kernel version assembled from a kernel Makefile's header variables.

	Reads the first VERSION, PATCHLEVEL, SUBLEVEL and EXTRAVERSION assignments
	and joins them as "{VERSION}.{PATCHLEVEL}.{SUBLEVEL}{EXTRAVERSION}".

	:param makefile_path: path of the kernel Makefile to read.
	:returns: the assembled kernel version string.
	:raises ValueError: if any of the four variables is not found.
	"""
	datums = ["VERSION", "PATCHLEVEL", "SUBLEVEL", "EXTRAVERSION"]
	got_datums = {}
	with open(makefile_path, "r") as mkf:
		for line in mkf:
			for datum in datums:
				if datum not in got_datums and line.startswith(f"{datum} ="):
					# Split on the first "=" only, so a value containing "=" stays intact.
					got_datums[datum] = line.split("=", 1)[1].strip()
			if len(got_datums) == len(datums):
				# All variables found; no need to read the rest of the Makefile.
				break
	if len(got_datums) != len(datums):
		raise ValueError(f"Could not extract: {datums} from /usr/src/linux/Makefile.")
	return "{VERSION}.{PATCHLEVEL}.{SUBLEVEL}{EXTRAVERSION}".format(**got_datums)


async def main_thread():
	"""Build the initramfs described by the parsed command-line ``args``.

	When no --kernel is given, the kernel version is derived from the Makefile
	behind the /usr/src/linux symlink, and /lib/modules/<version> must exist.

	:returns: True once the initramfs has been created.
	:raises FileExistsError: if the destination exists and --force was not given.
	:raises FileNotFoundError: if /usr/src/linux is not a symlink, or the
		expected kernel module directory is missing.
	:raises ValueError: if the kernel version cannot be read from the Makefile.
	"""
	output_initramfs = os.path.normpath(args.destination)
	if os.path.exists(output_initramfs) and not args.force:
		raise FileExistsError("Specified destination initramfs already exists -- use --force to overwrite.")
	output_dir = os.path.dirname(output_initramfs)
	# A bare filename has no directory part, and os.makedirs("") would raise.
	if output_dir:
		os.makedirs(output_dir, exist_ok=True)
	if args.kernel is None:
		log.info("No kernel specified, so going to try to use /usr/src/linux symlink.")
		if not os.path.islink("/usr/src/linux"):
			raise FileNotFoundError("/usr/src/linux does not exist or is not a symlink.")
		link_target = os.readlink("/usr/src/linux")
		# A relative link target is resolved against /usr/src; os.path.join keeps an absolute one as-is.
		makefile_path = os.path.join(os.path.dirname("/usr/src/linux"), link_target, "Makefile")
		# NOTE(review): CONFIG_LOCALVERSION / localversion* files are not considered here;
		# the /lib/modules existence check below is what catches a mismatch.
		kernel_version = _read_kernel_version(makefile_path)
		module_path = f"/lib/modules/{kernel_version}"
		if not os.path.exists(module_path):
			raise FileNotFoundError(f"Expected kernel module path {module_path} not found.")
	else:
		kernel_version = args.kernel
	with tempfile.TemporaryDirectory(prefix="ramdisk-", dir="/var/tmp") as temp_dir:
		ramdisk = InitialRamDisk(
			temp_root=temp_dir,
			support_root=support_root,
			kernel_version=kernel_version,
			compression=args.compression
		)
		ramdisk.create_ramdisk(output_initramfs)
	# Return an explicit success flag; the caller uses the return value to pick the exit status.
	return True


if __name__ == "__main__":
	try:
		result = asyncio.run(main_thread())
	except Exception as e:
		# Report the error concisely; the full traceback is shown only on request.
		log.error(f"{e.__class__.__name__}: {e}")
		if args.backtrace:
			log.critical("Backtrace", exc_info=e)
		sys.exit(1)
	# main_thread() signals failure by raising; an explicit False return is also
	# treated as failure. Any other completion (including None) is a success.
	if result is False:
		sys.exit(1)
