#!/usr/bin/env python

import re
from io import StringIO
from os import environ
from os.path import join, isfile, realpath, expanduser
from gnureadline import parse_and_bind, set_completer, read_history_file, write_history_file
from argparse import ArgumentParser
from contextlib import contextmanager
from signal import signal, SIGINT, SIGALRM, setitimer, ITIMER_REAL
from textwrap import dedent
from glob import glob
from functools import partial

from lark import Lark
from lark.visitors import Interpreter

from sm_aicli import *

# Source revision of the running code. The first lookup that succeeds wins:
#   1. the AICLI_REVISION environment variable;
#   2. `git rev-parse HEAD` executed in the directory named by AICLI_ROOT;
#   3. the REVISION constant of the `sm_aicli.revision` module.
# REVISION is None when all three lookups fail.
REVISION:str|None
try:
  REVISION=environ["AICLI_REVISION"]
except Exception:
  try:
    # Imported lazily: this line is reached only when the lookup above failed.
    from subprocess import check_output, DEVNULL
    # A missing AICLI_ROOT (KeyError) or a failing git call both land in the
    # `except` below; git's stderr is discarded.
    REVISION=check_output(['git', 'rev-parse', 'HEAD'],
                          cwd=environ['AICLI_ROOT'],
                          stderr=DEVNULL).decode().strip()
  except Exception:
    try:
      from sm_aicli.revision import REVISION as __rv__
      REVISION = __rv__
    except ImportError:
      REVISION = None

VERSION:str|None
try:
  VERSION=check_output(['python3', 'setup.py', '--version'],
                       cwd=environ['AICLI_ROOT'],
                       stderr=DEVNULL).decode().strip()
except Exception:
  try:
    from sm_aicli.revision import VERSION as __ver__
    VERSION = __ver__
  except ImportError:
    VERSION = None

def _get_cmd_prefix():
  """ Return the first character shared by every non-empty entry of COMMANDS.

  Asserts that exactly one distinct first character exists.
  """
  first_chars = set()
  for name in COMMANDS:
    if len(name)>0:
      first_chars.add(name[0])
  candidates = list(first_chars)
  assert len(candidates) == 1
  return candidates[0]

# The character that introduces every command
CMDPREFIX = _get_cmd_prefix()

# Default name of the readline history file
HISTORY_DEF="_sm_aicli_history"

# Parser of the command-line options; options are registered in the order in
# which they should appear in the `--help` output.
ARG_PARSER = ArgumentParser(description="Command-line arguments")

# Model selection and tuning
ARG_PARSER.add_argument("--model-dir", default=None, type=str,
  help="Model directory to prepend to model file names")
ARG_PARSER.add_argument("--model", "-m", default=None, type=str,
  metavar="[STR1:]STR2",
  help="Model to use. STR1 is 'gpt4all' (the default) or 'openai'. STR2 is the model name")
ARG_PARSER.add_argument("--num-threads", "-t", default=None, type=int,
  help="Number of threads to use")
ARG_PARSER.add_argument("--model-apikey", default=None, type=str,
  metavar="STR",
  help="Model provider-specific API key")
ARG_PARSER.add_argument("--model-temperature", default=None, type=float,
  help="Temperature parameter of the model")
ARG_PARSER.add_argument("--device", "-d", default=None, type=str,
  help="Device to use for chatbot, e.g. gpu, amd, nvidia, intel. Defaults to CPU")

# Line-editing behaviour
ARG_PARSER.add_argument("--readline-key-send", default="\\C-k", type=str,
  help="Terminal code to treat as Ctrl+Enter (default: \\C-k)")
ARG_PARSER.add_argument("--readline-prompt", default=">>> ", type=str,
  help="Input prompt (default: >>>)")
ARG_PARSER.add_argument("--readline-history", default=HISTORY_DEF, type=str,
  metavar="FILE",
  help=f"History file name (default is '{HISTORY_DEF}'; set empty to disable)")

# Diagnostics; the verbosity is kept as a string and converted where it is used
ARG_PARSER.add_argument("--verbose", type=str, metavar="NUM",
  help="Set the verbosity level 0-no,1-full")
ARG_PARSER.add_argument("--revision", action="store_true",
  help="Print the revision")
ARG_PARSER.add_argument("--version", action="store_true",
  help="Print the version")



@contextmanager
def with_sigint(_handler):
  """ Install `_handler` as the SIGINT handler for the duration of the `with` body.

  The handler that was active before is put back on exit, also when the body
  raises. Not fully correct: signals are not masked while handlers are switched.
  """
  saved_handler = signal(SIGINT, _handler)
  try:
    yield
  finally:
    # Put the previous handler back no matter how the body finished
    signal(SIGINT, saved_handler)

def as_float(val:str, default:float|None)->float|None:
  return float(val) if val not in {None,"def","default"} else default
def as_int(val:str, default:int|None)->int|None:
  return int(val) if val not in {None,"def","default"} else default

def ask1(args, model:ModelProvider, message:str) -> None:
  """ Send `message` to `model` and print the streamed answer token by token.

  The temperature and verbosity are taken from `args`. While tokens are being
  printed, SIGINT interrupts the model instead of terminating the program.
  A newline is always printed at the end, also when an error occurs.
  """
  def _on_interrupt(signum,frame):
    model.interrupt()
    print("\n<Keyboard interrupt>")

  options = Options(verbose=as_int(args.verbose, 0))
  try:
    model.set_temperature(as_float(args.model_temperature,None))
    # The stream is created before the SIGINT handler is swapped in
    tokens = model.stream(message, opt=options)
    with with_sigint(_on_interrupt):
      for token in tokens:
        print(token, end='', flush=True)
  finally:
    print()

def print_help():
  """ Print the names of all available commands on a single line. """
  command_list = ' '.join(COMMANDS)
  print(f"Commands: {command_list}")

def print_aux(args, s:str)->None:
  """ Print the informational message `s` to stdout; `args` is not used. """
  line = f"INFO: {s}"
  print(line)

def print_aux_err(args, s:str)->None:
  """ Print the error message `s` to stdout; `args` is not used. """
  line = f"ERROR: {s}"
  print(line)

@contextmanager
def with_chat_session(model:ModelProvider):
  """ Run the `with` body inside a chat session of `model`.

  When `model` is None the body runs without any session.
  """
  if model is not None:
    with model.with_chat_session():
      yield
    return
  # No model: nothing to open or close around the body
  yield

def model_locations(args, model)->list[str]:
  """ Yield the resolved paths of existing files matching the `model` pattern.

  The pattern is tried as given and then, when `args.model_dir` is set, joined
  to that directory. `~` is expanded and glob wildcards are honoured. Despite
  the annotation this is a generator; a path matched twice is yielded twice.
  """
  patterns = [model]
  if args.model_dir:
    patterns.append(join(args.model_dir,model))
  for pattern in patterns:
    yield from map(realpath, glob(expanduser(pattern)))

def ensure_quoted(s:str)->str:
  """ Return `s` wrapped in double quotes, adding only the quotes that are missing.

  A string that already starts and ends with two separate `"` characters is
  returned unchanged. Empty input and a lone `"` both produce `""`: a single
  character cannot be the opening and the closing quote at the same time.
  Quotes inside the string are not escaped.
  """
  if not s.startswith('"'):
    s = '"' + s
  # Require at least two characters so that the opening quote is never
  # mistaken for the closing one.
  if len(s)<2 or not s.endswith('"'):
    s = s + '"'
  return s

def parse_apikey(args)->str|None:
  if args.model_apikey is None:
    return None
  else:
    schema,val = args.model_apikey
    if schema=="verbatim":
      return val
    elif schema=="file":
      try:
        with open(expanduser(val)) as f:
          return f.read().strip()
      except Exception as err:
        raise ValueError(str(err)) from err
    else:
      raise ValueError(f"Unsupported APIKEY schema '{schema}'")

def parse_model(args) -> ModelProvider | None:
  """ Create the model provider described by `args.model`.

  The attribute is either None (the result is None) or a `(provider, name)`
  pair with provider "gpt4all", "openai" or "dummy". For gpt4all the name is
  replaced by the first existing file reported by `model_locations`, if any.

  Raises ValueError for an unknown provider, for an OpenAI model without an
  API key, and for the API key errors reported by `parse_apikey`.
  """
  if args.model is None:
    return None
  mprov,mname = args.model
  if mprov == 'gpt4all':
    # Prefer a model file found on disk; otherwise pass the name through as is
    existing = (path for path in model_locations(args, mname) if isfile(path))
    found = next(existing, None)
    if found is not None:
      mname = found
    return GPT4AllModelProvider(mname, device=args.device)
  if mprov == 'openai':
    apikey = parse_apikey(args)
    if apikey is None:
      raise ValueError("OpenAI models require apikey")
    return OpenAIModelProvider(model=mname, apikey=apikey)
  if mprov == 'dummy':
    return DummyModelProvider(model=mname, apikey=parse_apikey(args))
  raise ValueError(f"Invalid model provider '{mprov}'")

def main(cmdline=None):
  """ Run the interactive command-line REPL.

  `cmdline` is the list of command-line arguments handed to `ARG_PARSER`
  (argparse falls back to the process arguments when it is None).

  With `--revision` or `--version` the value is printed and 0 is returned.
  Otherwise input lines are read and interpreted until an exit command or
  end-of-file; plain text is accumulated in a message buffer which the ask
  command sends to the active model. ValueError and RuntimeError raised while
  interpreting input or while initializing a model are reported and do not
  terminate the session.
  """
  header = StringIO()
  args = ARG_PARSER.parse_args(cmdline)
  if args.revision:
    print(REVISION)
    return 0
  if args.version:
    print(VERSION)
    return 0
  # Model settings given on the command line are replayed as REPL commands so
  # that they pass through the same parser as interactive input.
  if args.model_apikey is not None:
    header.write(f"/apikey {ensure_quoted(args.model_apikey)}\n")
  if args.model is not None:
    header.write(f"/model {ensure_quoted(args.model)}\n")
  parse_and_bind('tab: complete')
  # Pressing the "send" key types the ask command followed by Enter
  parse_and_bind(f'"{args.readline_key_send}": "{CMD_ASK}\n"')
  hint = args.readline_key_send.replace('\\', '')

  print(f"Type /help or a question followed by the /ask command (or by pressing "
        f"`{hint}` key).")
  old_model = None
  model = None

  def _apply():
    """ Bring the active model in line with the settings stored in `args`. """
    nonlocal model, old_model
    if old_model != args.model:
      if args.model is not None:
        try:
          model = parse_model(args)
        except (ValueError,RuntimeError) as err:
          # A bad request (e.g. a missing API key) must not kill the REPL:
          # report it and restore the previous request so that the failing one
          # is not retried on every pass. The active model is left untouched.
          print_aux_err(args, err)
          args.model = old_model
        else:
          print_aux(args, "Model is now initialized")
      else:
        if model:
          print_aux(args, "Model is now destroyed")
        # Dropping the reference is all that is needed to release the model
        model = None
      old_model = args.model
    if model is not None:
      num_threads = as_int(args.num_threads, None)
      old_num_threads = model.get_thread_count()
      if num_threads is not None:
        model.set_thread_count(num_threads)
        new_num_threads = model.get_thread_count()
        if old_num_threads != new_num_threads:
          print_aux(args, f"Num threads is now {new_num_threads}")

  class Repl(Interpreter):
    """ Interprets parsed input; its methods are named after the grammar rules. """
    def __init__(self):
      self._reset()
    def _reset(self):
      """ Put the interpreter state back to its initial values. """
      self.in_echo = False
      self.message = ""
      self.exit_request = False
      self.reset_request = False
    def reset(self):
      """ Reset the state, reporting when a non-empty message buffer is dropped. """
      old_message = self.message
      self._reset()
      if len(old_message)>0:
        print_aux(args, "Message buffer is now empty")
    def _finish_echo(self):
      """ End the line started by a preceding echo command, if there was one. """
      if self.in_echo:
        print()
      self.in_echo = False
    def as_verbatim(self, tree):
      return "verbatim"
    def as_file(self, tree):
      return "file"
    def apikey_value(self, tree):
      return tree.children[0].value
    def apikey_string(self, tree):
      """ Return a (schema, value) pair; the schema defaults to "verbatim". """
      val = self.visit_children(tree)
      return tuple(val) if len(val)==2 else ("verbatim",val[0])
    def mp_gpt4all(self, tree):
      return "gpt4all"
    def mp_openai(self, tree):
      return "openai"
    def mp_dummy(self, tree):
      return "dummy"
    def model_provider(self, tree):
      return tree.children[0].value
    def model_name(self, tree):
      return tree.children[0].value
    def model_string(self, tree):
      """ Return a (provider, name) pair; the provider defaults to "gpt4all". """
      val = self.visit_children(tree)
      return tuple(val) if len(val)==2 else ("gpt4all",val[0])
    def command(self, tree):
      """ Execute one command; settings are stored in `args` for `_apply`. """
      nonlocal model
      self._finish_echo()
      command = tree.children[0].value
      if command == CMD_ECHO:
        self.in_echo = True
      elif command == CMD_ASK:
        if model is None:
          raise RuntimeError("No model is active, use /model first")
        if len(self.message.strip()) == 0:
          raise RuntimeError("Empty message buffer, write something first")
        else:
          ask1(args, model, self.message)
          self.message = ""
      elif command == CMD_HELP:
        print_help()
      elif command == CMD_EXIT:
        self.exit_request = True
      elif command == CMD_MODEL:
        res = self.visit_children(tree)
        # NOTE(review): the guard tests for more than two children but index 3
        # is read; presumably the grammar never yields exactly three - confirm.
        if len(res)>2:
          mprov,mname = res[3]
          args.model = (mprov,mname)
          print_aux(args, f"Will use provider '{mprov}' name '{mname}'")
        else:
          args.model = None
          print_aux(args, f"Will deinitialize model ")
        self.reset_request = True
      elif command == CMD_NTHREADS:
        args.num_threads = tree.children[2].children[0].value
        print_aux(args, f"Num threads will be set to '{args.num_threads}'")
        self.reset_request = True
      elif command == CMD_TEMP:
        args.model_temperature = tree.children[2].children[0].value
        print_aux(args, f"Temperature will be set to '{args.model_temperature}'")
      elif command == CMD_RESET:
        print_aux(args, "Message buffer will be cleared")
        self.reset_request = True
      elif command == CMD_APIKEY:
        res = self.visit_children(tree)
        # NOTE(review): same guard/index mismatch as in the model command above
        if len(res)<3:
          raise ValueError("API key should not be empty")
        schema,arg = res[3]
        print_aux(args, f"Model API schema \"{schema}\" value \"{arg}\"")
        args.model_apikey = (schema,arg)
        self.reset_request = True
      elif command == CMD_VERBOSE:
        args.verbose = tree.children[2].children[0].value
        print_aux(args, f"Verbosity will be set to '{args.verbose}'")
      else:
        raise ValueError(f"Unknown command: {command}")
    def text(self, tree):
      """ Echo plain text or append it to the message buffer. """
      text = tree.children[0].value
      if self.in_echo:
        print(text, end='')
      else:
        # A command name inside plain text is most likely a typing mistake
        for cmd in COMMANDS:
          if cmd in text:
            print_aux(args, f"Warning: '{cmd}' was parsed as a text")
        self.message += text
    def escape(self, tree):
      """ Handle an escaped token: its first character is dropped. """
      text = tree.children[0].value[1:]
      if self.in_echo:
        print(text, end='')
      else:
        self.message += text

  if args.readline_history:
    try:
      read_history_file(args.readline_history)
      print_aux(args, "History file loaded")
    except FileNotFoundError:
      print_aux(args, "History file not loaded")
  else:
    print_aux(args, "History file is not used")

  repl = Repl()
  repl.visit(PARSER.parse(header.getvalue()))
  # Outer loop: one pass per (re)configuration of the model.
  # Inner loop: one pass per input line within a single chat session.
  while not repl.exit_request:
    _apply()
    repl.reset()
    with with_chat_session(model):
      try:
        while not (repl.exit_request or repl.reset_request):
          repl.visit(PARSER.parse(input(args.readline_prompt)))
          repl._finish_echo()
          if args.readline_history:
            write_history_file(args.readline_history)
      except (ValueError,RuntimeError) as err:
        print_aux_err(args, err)
      except EOFError:
        # End of input terminates the program
        print()
        break

# Start the REPL only when the file is executed as a script, not when imported
if __name__ == "__main__":
  main()
