#!python

import json
import os
import re
import sys
import argparse
from openai import OpenAI

# python-dotenv is optional: when it is installed, variables from a local
# .env file are loaded into the environment before the settings are read.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# Runtime settings, all overridable through environment variables. `or` is
# used instead of getenv's default argument so that a variable which is set
# but empty (e.g. `OPENAI_TEMPERATURE=` in a .env file) falls back to the
# default instead of crashing float()/int() with a ValueError at import time.
api_key = os.getenv('OPENAI_API_KEY')
model = os.getenv('OPENAI_MODEL') or 'gpt-3.5-turbo'
default_prompt = os.getenv('OPENAI_DEFAULT_PROMPT') or ''
# An empty OPENAI_LOGFILE disables logging (write_to_log skips a falsy path).
log_file = os.getenv('OPENAI_LOGFILE', 'gpt.log')
temperature = float(os.getenv('OPENAI_TEMPERATURE') or 0.0)
# A configured value of 0 is treated as unset, so the default applies.
max_tokens = int(os.getenv('OPENAI_MAX_TOKENS') or 2048) or 2048

client = OpenAI(api_key=api_key)

def write_to_log(display: str) -> None:
    """Append *display*, followed by a newline, to the configured log file.

    Does nothing when the module-level ``log_file`` setting is empty.

    Args:
        display: Text to append to the log.
    """
    if not log_file:
        return
    # Model output routinely contains non-ASCII text; pin UTF-8 so the write
    # does not depend on (or fail under) the platform's locale encoding.
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f"{display}\n")

def get_answer(prompt: str, response=None, choices=None, result=None) -> "str | None":
    """Send *prompt* to the chat completions API and return the reply text.

    The request uses the module-level ``model``, ``default_prompt`` (sent as
    the system message), ``temperature`` and ``max_tokens`` settings.

    Args:
        prompt: Text sent as the user message.
        response: Ignored; kept only so existing call sites keep working.
        choices: Ignored; kept only so existing call sites keep working.
        result: Ignored; kept only so existing call sites keep working.

    Returns:
        The first choice's message content with leading and trailing
        whitespace removed (an empty string if the content was whitespace
        only), or ``None`` when the response has no choices or the first
        choice has no content.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": default_prompt},
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if not completion.choices:
        return None

    content = completion.choices[0].message.content
    if not content:
        return None
    return content.strip()

def main():
    """Command-line entry point: send the input text to the model and print the reply.

    The prompt is read from stdin when stdin is piped or redirected and
    contains text; otherwise the positional arguments are joined with single
    spaces. With ``-p``/``--print`` the prompt is echoed together with the
    answer. Whatever is printed is also passed to ``write_to_log``.

    Exits with status 1, after writing a message to stderr, when no input
    was supplied or no answer could be retrieved.
    """
    parser = argparse.ArgumentParser(description='Process input for the model.')
    parser.add_argument('-p', '--print', action='store_true', help='print full output')
    parser.add_argument('input_data', nargs='*', help='input data for the model')
    args = parser.parse_args()

    # Piped or redirected stdin takes precedence over positional arguments.
    input_data = ''
    if not sys.stdin.isatty():
        input_data = sys.stdin.read().strip()
    # Use the arguments when stdin is a terminal, and also when stdin is
    # non-interactive but empty (cron, subprocess, `< /dev/null`); otherwise
    # arguments given in those environments would be silently discarded.
    if not input_data:
        input_data = ' '.join(args.input_data)

    if not input_data:
        print("No input provided.", file=sys.stderr)
        sys.exit(1)

    prompt = input_data
    answer = get_answer(prompt)
    if not answer:
        print("Error retrieving answer.", file=sys.stderr)
        sys.exit(1)

    if args.print:
        display = (
            "----------------------------------------------------------------------\n"
            "Prompt:\n"
            f"{prompt}\n\n"
            "Answer:\n"
            f"{answer}\n"
            "----------------------------------------------------------------------"
        )
    else:
        display = f'\n{answer}\n'

    print(display)
    write_to_log(display)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

