#!python

import json, os, re, sys
from openai import OpenAI

# Populate os.environ from a .env file via python-dotenv when it is installed.
# The package is optional: only its absence is ignored. Errors raised while
# actually loading the file are no longer swallowed by the ImportError handler.
try:
	from dotenv import load_dotenv
except ImportError:
	pass
else:
	load_dotenv()


# Raw settings from the environment; each is None when its variable is unset.
# Defaults and numeric conversion are applied further down the script.
api_key, model, default_prompt, log_file, temperature, max_tokens = (
	os.environ.get(var_name)
	for var_name in (
		'OPENAI_API_KEY',
		'OPENAI_MODEL',
		'OPENAI_DEFAULT_PROMPT',
		'OPENAI_LOGFILE',
		'OPENAI_TEMPERATURE',
		'OPENAI_MAX_TOKENS',
	)
)
# When True, the output frames both the prompt and the answer (enabled by -p).
full_output = False

# Build the prompt text.
#   script.py [-p] word word ...   -> the words, joined with single spaces
#   script.py [-p] < file          -> everything read from standard input
# A first argument of exactly "-p" switches on the framed prompt/answer output.
# It is matched exactly, so prompt words that merely contain "-p" are kept.
cli_args = sys.argv[1:]
if cli_args and cli_args[0] == '-p':
	full_output = True
	cli_args = cli_args[1:]
if cli_args:
	# Keep the spaces the shell split on, so "what is this" stays intact.
	input_data = ' '.join(cli_args)
else:
	# No prompt words (with or without -p): read the prompt from stdin.
	input_data = sys.stdin.read()


# Substitute built-in defaults for settings the environment left empty, and
# convert the numeric settings from their string form.
model = model or 'gpt-3.5-turbo'
default_prompt = default_prompt or ''
temperature = float(temperature) if temperature else 0.0
max_tokens = int(max_tokens) if max_tokens else 2048

# Shared API client used for every request made by this script.
client = OpenAI(api_key=api_key)

def write_to_log(display: str, path=None):
	"""Append *display* plus a trailing newline to the log file.

	Args:
		display: Text to record. This is the same text printed to the terminal.
		path: File to append to. Defaults to the module-level ``log_file``
			setting (from OPENAI_LOGFILE).
	"""
	target = log_file if path is None else path
	# Use an explicit encoding: model answers often contain non-ASCII text, and
	# the platform default encoding (e.g. cp1252) may not be able to encode it.
	with open(target, 'a', encoding='utf-8') as f:
		f.write(f"{display}\n")

def get_answer(prompt: str, response=None, choices=None, result=None) -> "str | None":
	"""Send *prompt* to the chat completions API and return the trimmed reply.

	The request uses the module-level settings: ``default_prompt`` as the
	system message, plus ``model``, ``temperature`` and ``max_tokens``.

	Args:
		prompt: Text sent as the user message.
		response, choices, result: Ignored. They are kept only so existing
			call sites that pass them keep working.

	Returns:
		The first choice's message content with leading and trailing
		whitespace removed. Returns None when the API returned no choices or
		the content is empty/None. An all-whitespace reply yields "".
	"""
	response = client.chat.completions.create(
		model=model,
		messages=[
			{"role": "system", "content": default_prompt},
			{"role": "user", "content": prompt},
		],
		temperature=temperature,
		max_tokens=max_tokens,
	)
	choices = response.choices
	# Guard before indexing: previously an empty choice list left `res`
	# unbound and raised UnboundLocalError instead of returning None.
	if not choices:
		return None
	res = choices[0].message.content
	if not res:
		return None
	return res.strip()


# Script entry point: ask the model, print the answer, and optionally log it.
if __name__ == '__main__':
	prompt = input_data
	answer = get_answer(prompt)
	if not answer:
		# Report on stderr and exit with status 1, so shell pipelines and
		# callers can detect the failure. A bare sys.exit() exited with 0.
		sys.exit("Error retrieving answer.")
	if full_output:
		# Framed output that shows both the prompt and the answer (-p flag).
		display = (
			"----------------------------------------------------------------------\n"
			"Prompt:\n"
			f"{prompt}\n\n"
			"Answer:\n"
			f"{answer}\n"
			"----------------------------------------------------------------------"
		)
	else:
		display = f'\n{answer}\n'
	print(display)
	if log_file:
		write_to_log(display)
