import json
import time

from src.utils.input import parse_cli_arguments
from src.utils.utils import craft_input_prompt, dict_to_df, load_yaml_config, setup_logging
from src.utils.automate_requests import call_automate_requests
from src.utils.response_evaluator import call_response_evaluator
from src.models.openai_model import OpenAIModel
from tqdm import tqdm

# Path of the YAML configuration file that main() loads before applying CLI
# overrides. It is a relative path; presumably load_yaml_config resolves it
# against the current working directory — TODO confirm.
CONFIG_FILE = "src/config/config.yaml"

def _clean_prompts(samples, logger):
    """Return the usable prompts from *samples*, stripped of surrounding whitespace.

    An entry is usable when it is a ``str`` containing at least one
    non-whitespace character; every other entry is skipped and logged as a
    warning.

    Args:
        samples: Iterable of candidate prompts (elements of any type).
        logger: Logger used to report skipped entries.

    Returns:
        list[str]: The stripped, non-empty string prompts, in input order.
    """
    cleaned = []
    for sample in samples:
        if isinstance(sample, str) and sample.strip():
            cleaned.append(sample.strip())
        else:
            logger.warning(f"Skipping invalid prompt: {sample}")
    return cleaned


def _load_additional_prompts(path, logger):
    """Read extra prompts from the UTF-8 text file *path*, one per line.

    Blank lines are skipped and the remaining lines are stripped. A file that
    cannot be opened or decoded is logged and yields an empty list, so a bad
    optional file does not abort the run.

    Args:
        path: Path to the additional-prompts file.
        logger: Logger used to report read errors.

    Returns:
        list[str]: The stripped, non-blank lines (empty if reading failed).
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    except (OSError, ValueError) as e:
        # ValueError covers UnicodeDecodeError for files that are not UTF-8.
        logger.error(f"Error reading additional prompts file: {e}")
        return []


def main():
    """Generate attack prompts with the configured model and optionally evaluate them.

    Steps:
      1. Load ``CONFIG_FILE`` and merge it with the CLI arguments.
      2. For each entry of the configured attack info, ask the model for
         ``num_prompts`` prompts and keep the valid (non-empty string) ones.
      3. Append prompts from the optional additional-prompts file.
      4. Write all prompts, newline-separated, to ``evaluation_prompts.txt``.
      5. If ``method == "Generate and Evaluate"``, send every prompt through
         ``call_automate_requests``, save the raw responses as JSON to
         ``output_file``, classify them with ``call_response_evaluator`` and
         write the report selected by ``report_type`` (``"error"``,
         ``"findings"``, otherwise the combined responses) as CSV.

    If no prompts were generated or loaded, an error is logged and nothing is
    written.

    Raises:
        ModuleNotFoundError: If the configured model provider is not supported.
        RuntimeError: If the model client fails to initialize.
    """
    logger = setup_logging()

    yaml_config = load_yaml_config(CONFIG_FILE)
    config_dict, eval_dict = parse_cli_arguments(yaml_config)

    # NOTE(review): these unpackings rely on the insertion order of the dicts
    # returned by parse_cli_arguments — keep both sides in sync.
    (model_provider, selected_model, temperature, top_p, max_tokens, api_key,
     system_instruction, application, application_description, attack_info,
     num_prompts) = config_dict.values()
    (method, special_token, request_file, output_file, request_sleep,
     response_checker_file, report_type, response_analysis_file,
     additional_prompts_file) = eval_dict.values()

    if model_provider == "OpenAI":
        model = OpenAIModel(api_key, selected_model)
    else:
        raise ModuleNotFoundError(f"{model_provider} is not yet supported")

    if not model.initialize_client():
        raise RuntimeError("Failed to initialize the model client")

    output_prompts = []
    for attack_key, attack_value in attack_info.items():
        user_instruction = craft_input_prompt(
            application=application,
            application_description=application_description,
            attack_instruction=attack_value,
            num_prompts=num_prompts,
        )
        output = model.generate_prompt(
            system_instruction=system_instruction,
            user_instruction=user_instruction,
            num_prompts=num_prompts,
            temperature=temperature,
            top_p=top_p,
            max_completion_tokens=max_tokens,
        )

        if output["status"] == "success":
            # Use single quotes inside the f-string: reusing the enclosing
            # double quote is a SyntaxError before Python 3.12.
            logger.info(f"{output['message']} for {attack_key}")
            output_prompts.extend(_clean_prompts(output["prompts"], logger))
        else:
            logger.error(f"Error in generating prompts for {attack_key}: {output['message']}")
            if "prompts" in output:
                # Keep partial results, but validate them like the success
                # path so non-string entries cannot break the file write below.
                output_prompts.extend(_clean_prompts(output["prompts"] or [], logger))

    if additional_prompts_file is not None:
        output_prompts.extend(_load_additional_prompts(additional_prompts_file, logger))

    # The list always exists, so test emptiness rather than `is not None`.
    if not output_prompts:
        logger.error("No prompts generated - please retry.")
        return

    prompts_path = "evaluation_prompts.txt"
    with open(prompts_path, "w", encoding="utf-8") as f:
        f.write('\n'.join(output_prompts))
    logger.info(f"Prompts stored in {prompts_path}")

    if method != "Generate and Evaluate":
        return

    with open(response_checker_file, "r", encoding="utf-8") as f:
        # NOTE(review): each keyword keeps its trailing newline; confirm that
        # call_automate_requests strips them before matching.
        keywords = f.readlines()

    response_dict = {}
    for prompt in tqdm(output_prompts, desc="Processing request-response automation"):
        prompt = prompt.strip()
        logger.info(f"Sanitizing prompt: {prompt}")
        # JSON-encode the prompt (quoted, with special characters escaped);
        # the encoded form is sent and also used as the key in the log.
        prompt_str = json.dumps(prompt)

        response_data = call_automate_requests(
            request_file=request_file,
            prompt=prompt_str,
            special_token=special_token,
            keywords=keywords,
        )
        response_dict[prompt_str] = response_data

        logger.info(f"Sleeping for {request_sleep} seconds")
        time.sleep(request_sleep)

        # Rewrite the whole request-response log after every request so the
        # file holds all responses collected so far, even if the run stops early.
        with open(output_file, "w", encoding="utf-8") as file:
            json.dump(response_dict, file, indent=4)

    error_responses, _safe_responses, vulnerable_responses, combined_responses = (
        call_response_evaluator(json_data=response_dict).values()
    )

    try:
        if report_type == "error":
            df = dict_to_df(error_responses)
        elif report_type == "findings":
            df = dict_to_df(vulnerable_responses)
        else:
            df = dict_to_df(combined_responses)

        # Shift the index by one, presumably so report rows are numbered from 1
        # (assumes dict_to_df returns a default 0-based index — confirm).
        df.index = df.index + 1
        df.to_csv(f"{response_analysis_file}")
        logger.info(f"Data logged in {response_analysis_file}")
    except Exception as e:
        # Best-effort report writing: log the failure instead of crashing.
        logger.error(f"Error logging data: {str(e)}")

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()