#!python
import openai
import os
import termcolor
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator, ValidationError
from prompt_toolkit.styles import Style
import click



# Built-in conversation presets, keyed by the name shown in the selection menu.
#
# Each preset has:
#   "message": instruction text placed at the very start of every prompt sent
#              to the model. Non-empty messages end with "\n" so the first
#              turn label starts on its own line instead of being glued to
#              the instruction text.
#   "inject":  {"state": bool, "start": str, "end": str}
#              When "state" is True the session keeps a running transcript;
#              "end" is the label put in front of the user's input and
#              "start" is the label put in front of the model's reply.
#              "start"/"end" are only present when "state" is True.
presets = {
    "Chat": {
        "message": "The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.\n",
        "inject": {
            "state": True,
            "start": "AI:",
            "end": "Human:",
        },
    },
    "Q&A": {
        "message": "I am a highly intelligent question answering bot. If you ask me a question that is rooted in truth, I will give you the answer. If you ask me a question that is nonsense, trickery, or has no clear answer, I will respond with 'Unknown'.\n",
        "inject": {
            "state": True,
            "start": "A:",
            "end": "Q:",
        },
    },
    "Grammar Correction": {
        "message": "Correct this to standard English:\n",
        "inject": {
            "state": False,
        },
    },
    "Eli5": {
        "message": "Summarize this for a second-grade student:\n",
        "inject": {
            "state": False,
        },
    },
    # Free-form mode: no instruction text and no transcript.
    "Custom": {
        "message": "",
        "inject": {
            "state": False,
        },
    },
}

@click.command()
@click.option('--api_key', help='Openai API key. If not provided, will prompt for it or use the environment variable OPENAI_API_KEY.')
@click.option('--model', default='text-davinci-003', help='Model to use for text generation | (default: text-davinci-003)')
@click.option('--temperature', default=0.9, type=click.FLOAT, help='Temperature for text generation | (default: 0.9)')
@click.help_option('--help', '-h')
def generate_text(api_key, model, temperature):
    """
    Generates text using the OpenAI API.

    \b
    Examples:
        gpt-chatbot-cli
        gpt-chatbot-cli --api_key=YOUR_API_KEY
        gpt-chatbot-cli --api_key=YOUR_API_KEY --model=text-davinci-002 --temperature=0.7
    """
    # NOTE: the docstring above is rendered by click as the `--help` text, so
    # implementation notes are kept in comments.
    #
    # Parameters (all supplied by click):
    #   api_key     -- value of --api_key, or None when the option is omitted.
    #   model       -- engine name passed to openai.Completion.create.
    #   temperature -- sampling temperature passed to the completion call.
    #
    # Flow: resolve and validate the API key, let the user pick a preset,
    # then loop reading input and printing completions until the user types
    # "exit"/"exitgpt" or interrupts. Always terminates via SystemExit on
    # error or interrupt (status 1 for errors, 0 for user-initiated exits).

    style = Style.from_dict({
        'api-key-helper': '#fc802d bg:#282828 bold',
        'chat-prompt-helper': "#504945 bg:#fbf0c9 bold"
    })

    def api_key_helper():
        """Return the bottom-toolbar text shown while asking for the API key."""
        return [('class:api-key-helper', 'Set the environment variable OPENAI_API_KEY to avoid further prompts. ')]

    def chat_prompt_helper(on, message):
        """Return the bottom-toolbar text showing the active preset and its message."""
        return [('class:chat-prompt-helper', "Mode: "+on+"\n"+message)]

    def check_api_key_validity(candidate_key, from_env):
        """Install *candidate_key* on the openai module and verify it with a cheap call.

        Exits the process with status 1 when the API rejects the request.
        *from_env* is True only when the key was read from OPENAI_API_KEY.
        """
        if from_env:
            print("Found env variable OPENAI_API_KEY")
        print("Checking for validity")
        try:
            openai.api_key = candidate_key
            openai.Model.list()
        except openai.OpenAIError as e:
            print(termcolor.colored("Invalid API key", 'light_red', attrs=["bold"]) + "\nGrab your API key from: "+termcolor.colored("https://beta.openai.com/account/api-keys", 'light_blue', attrs=["underline"]))
            # OpenAIError also covers connectivity and rate-limit failures;
            # show the underlying reason so those are not mistaken for a bad key.
            print(f"Reason: {e}")
            raise SystemExit(1)
        print(termcolor.colored("API key is valid", 'light_green', attrs=["bold"]))

    # Resolve the API key. An explicit --api_key wins over the environment
    # variable, matching the option's help text; prompt only as a last resort.
    env_api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        check_api_key_validity(api_key, from_env=False)
    elif env_api_key:
        check_api_key_validity(env_api_key, from_env=True)
    else:
        try:
            entered_key = prompt("Please enter your OpenAI API key: ", bottom_toolbar=api_key_helper, style=style)
            check_api_key_validity(entered_key, from_env=False)
        except (KeyboardInterrupt, EOFError):
            # prompt_toolkit raises KeyboardInterrupt on Ctrl-C and EOFError on Ctrl-D.
            print("Exiting...")
            raise SystemExit(0)

    options = list(presets)

    class NumberValidator(Validator):
        """Accept only an integer in the range 1..len(options)."""

        def validate(self, document):
            try:
                value = int(document.text)
                if value not in range(1, len(options) + 1):
                    raise ValueError()
            except ValueError:
                raise ValidationError(
                    message='Please enter a number between 1 and {}'.format(len(options)),
                    cursor_position=len(document.text))

    try:
        # Ask the user to pick a preset by its 1-based menu number.
        print("Please choose a preset:")
        print("Please select an option by number:")
        for number, key in enumerate(options, start=1):
            print(f"{number}: {key}")

        selected_num = int(prompt('Option: ', validator=NumberValidator()))
        # The validator guarantees the index is in range, so the lookup cannot miss.
        chosen_preset = options[selected_num - 1]
        preset = presets[chosen_preset]
        preset_message = preset["message"]

        # The preset message opens every prompt. Make sure it is terminated by
        # a newline so the first turn label is not glued to the instruction text.
        conversation_history = preset_message
        if conversation_history and not conversation_history.endswith("\n"):
            conversation_history += "\n"

        # Presets with inject.state enabled keep a running transcript and
        # label each turn; the others are stateless and use a plain ">" marker.
        inject = preset.get("inject", {})
        keep_history = bool(inject.get("state"))
        if keep_history:
            end_string = inject["end"]
            start_string = inject["start"]
        else:
            end_string = ">"
            start_string = ">"

        # The toolbar content does not change during the session; build it once.
        toolbar = chat_prompt_helper(chosen_preset, preset_message)

        while True:
            user_input = prompt(end_string, bottom_toolbar=toolbar, style=style)

            command = user_input.strip()
            if command.lower() in ("exitgpt", "exit"):
                break
            if not command:
                # Nothing to send; do not spend an API call on blank input.
                continue

            # "<user label><input>\n<model label>" -- the model continues from here.
            turn = end_string + user_input + "\n" + start_string
            response = openai.Completion.create(
                engine=model,
                prompt=conversation_history + turn,
                temperature=temperature,
                max_tokens=1024,
            )
            reply = response.choices[0].text

            if keep_history:
                # Store the turn exactly as it was sent, including the model
                # label, so earlier replies stay attributed in later prompts.
                conversation_history += turn + reply + "\n"

            print(termcolor.colored(f"{start_string}{reply}", 'light_yellow'))

    except (KeyboardInterrupt, EOFError):
        # Must precede `except Exception`: EOFError (Ctrl-D) is an Exception
        # subclass and would otherwise be reported as an error.
        print(termcolor.colored("Keyboard Interrupt, Exiting...", 'light_red'))
        raise SystemExit(0)
    except Exception as e:
        # Top-level boundary of the CLI: report the failure and exit non-zero.
        print(termcolor.colored(f"Error: {e}", 'light_red'))
        raise SystemExit(1)


# Script entry point: run the click command only when this file is executed
# directly. It is called without arguments because click parses the command
# line and supplies api_key, model and temperature itself.
if __name__ == '__main__':
    generate_text()
