import json
import os
import pandas as pd
import streamlit as st
import time
import tempfile
from typing import Any, Dict, List, Optional

from src.models import BaseModel, OpenAIModel
from src.dashboard.sidebar import render_sidebar
from src.utils import (
    call_automate_requests,
    call_response_evaluator,
    craft_input_prompt,
    dict_to_df,
    load_attack_config,
    load_yaml_config,
    render_main_form,
    return_processing_overlay_style,
    reset_output_states,
    setup_logging
)

# Locations of bundled resources. These are relative paths, so they are
# resolved against the working directory the app is launched from.
CONFIG_PATH = "src/config/config.yaml"
ATTACK_FOLDER_PATH = "src/config/Attacks"
RESPONSE_CHECKER_FILE = "src/config/response_checker.txt"
LOGO_PATH = "src/dashboard/logo.png"

# Module-wide logger obtained from the project's setup_logging helper.
logger = setup_logging()

def initialize_session_state():
    """Configure the page, then seed ``st.session_state`` with defaults.

    Only keys that are not yet present are written, so values stored by an
    earlier run of the script are preserved across Streamlit reruns.
    """
    st.set_page_config(
        page_title="Blueinfy PenTestPrompt",
        page_icon=":computer:",
        layout="wide",
    )

    initial_values = dict(
        is_processing=False,
        processing_complete=False,
        output_prompts=[],
        current_request=0,
        total_requests=0,
        automated_response_results=None,
        attack_prompts={},
        selected_attacks=[],
        error_message=None,
        form_data=None,
        automated_evaluation_results={},
    )

    for state_key, initial in initial_values.items():
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = initial

def show_processing_overlay():
    """Apply the UI-blocking overlay style while a run is in progress; otherwise do nothing."""
    if not st.session_state.is_processing:
        return
    return_processing_overlay_style()

class MainView:
    """Main-page controller that turns the selected attack types into test prompts."""

    def __init__(self, model_instance: BaseModel):
        # May be None when no supported provider was selected; process_prompts
        # reports that case as an error instead of failing.
        self.model = model_instance

    def process_prompts(self, **kwargs) -> Dict[str, Any]:
        """
        Generate prompts for every attack type in ``attack_prompt_dic``.

        Keyword arguments read by this method: ``application``, ``description``,
        ``system_instruction``, ``attack_prompt_dic`` (attack name -> attack
        instruction), ``num_prompts``, ``temperature``, ``top_p`` and
        ``max_tokens``.

        Attacks with an empty instruction are skipped with a warning. A failure
        for one attack does not stop the others; all failure messages are joined
        into ``st.session_state.error_message`` so the page can surface them even
        when other attacks succeeded.

        Returns:
            ``{"status": "success", "prompts": [...]}`` when at least one prompt
            was produced or no attacks were requested; otherwise
            ``{"status": "error", "message": <reason>}`` where the message is
            never ``None``.
        """
        if not self.model:
            return {"status": "error", "message": "Model client not initialized."}

        attack_prompt_dic = kwargs.get('attack_prompt_dic')
        if not attack_prompt_dic:
            return {"status": "success", "prompts": []}

        output_prompts: List[str] = []
        errors: List[str] = []

        for attack_key, attack_value in attack_prompt_dic.items():
            if not attack_value:
                st.warning(f"Skipping generation for '{attack_key}' as instruction is empty.")
                continue

            user_instruction = craft_input_prompt(
                application=kwargs['application'],
                application_description=kwargs['description'],
                attack_instruction=attack_value,
                num_prompts=kwargs['num_prompts']
            )

            try:
                output = self.model.generate_prompt(
                    system_instruction=kwargs['system_instruction'],
                    user_instruction=user_instruction,
                    num_prompts=kwargs['num_prompts'],
                    temperature=kwargs['temperature'],
                    top_p=kwargs['top_p'],
                    max_completion_tokens=kwargs['max_tokens']
                )

                if output["status"] == "success":
                    st.success(f"{output['message']} for {attack_key}")
                    output_prompts.extend(output['prompts'])
                else:
                    errors.append(f"Error in generating prompts for {attack_key}: {output['message']}")

            except Exception as e:
                # Keep going: one failing attack type should not block the rest.
                errors.append(f"Exception during generation for {attack_key}: {str(e)}")

        if errors:
            # Report every failure, not just the last one.
            st.session_state.error_message = "; ".join(errors)

        if output_prompts:
            return {"status": "success", "prompts": output_prompts}

        # Nothing usable came back. Always give a reason. Previously this returned
        # the session-state message, which is None when every instruction was
        # skipped or the model returned no prompts.
        message = "; ".join(errors) if errors else "No prompts were generated for the selected attack types."
        return {"status": "error", "message": message}


def process_results(request_file, prompts: List[str], special_token: str, request_sleep: int) -> Dict[str, Any]:
    """
    Send each prompt through the uploaded sample request and collect the responses.

    The uploaded request is copied to a temporary file and the response keywords
    are loaded from ``RESPONSE_CHECKER_FILE``. Every prompt is then passed to
    ``call_automate_requests``. A results table and a progress bar are updated
    live, with a ``request_sleep``-second pause between consecutive requests.

    Args:
        request_file: Uploaded sample request; must expose ``getvalue()``.
        prompts: Prompts to send. Empty or non-string entries are skipped.
        special_token: Forwarded to ``call_automate_requests`` (presumably the
            marker in the request that the prompt replaces; TODO confirm).
        request_sleep: Seconds to wait between consecutive requests.

    Returns:
        ``{"status": "success", "result": {json_encoded_prompt: response_data}}``
        on success, otherwise ``{"status": "error", "message": ...}``. Every
        error path also reports itself with ``st.error``.
    """
    temp_files = []
    response_dict = {}

    if not isinstance(prompts, list):
        st.error("Internal Error: Prompts provided to process_results must be a list.")
        return {"status": "error", "message": "Invalid prompt format for processing."}

    if not prompts:
        st.warning("No prompts provided to evaluate.")
        return {"status": "success", "result": {}}

    if not request_file:
        st.error("Sample request file is missing for evaluation.")
        return {"status": "error", "message": "Missing sample request file."}

    st.session_state.is_processing = True
    results_placeholder = st.empty()
    progress_bar = st.progress(0)
    status_text = st.empty()
    # Collect rows and build the table from them. Repeatedly concatenating
    # single-row frames onto an empty DataFrame is slower and noisier.
    result_rows = []
    try:
        with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.txt') as req_tmp:
            # Register for cleanup before anything can raise. Otherwise an
            # empty upload (ValueError below) leaks the temp file, which is
            # never deleted automatically because delete=False.
            request_path = req_tmp.name
            temp_files.append(request_path)
            content = request_file.getvalue()
            if not content:
                raise ValueError("Empty request file uploaded.")
            req_tmp.write(content)

        try:
            with open(RESPONSE_CHECKER_FILE, "r", encoding="utf-8") as f:
                keywords = [line.strip() for line in f if line.strip()]
        except FileNotFoundError:
            st.error(f"Response checker file not found at: {RESPONSE_CHECKER_FILE}")
            return {"status": "error", "message": "Response checker configuration missing."}
        except Exception as e:
            st.error(f"Error reading response checker file: {e}")
            return {"status": "error", "message": f"Error reading response checker file: {e}"}

        total = len(prompts)
        status_text.text(f"Starting evaluation for {total} prompts...")

        for idx, prompt in enumerate(prompts, 1):
            progress_bar.progress(float(idx) / total)

            # Validate before slicing for the status line. A non-string entry
            # would otherwise raise here and abort the whole evaluation.
            if not isinstance(prompt, str) or not prompt.strip():
                st.warning(f"Skipping empty or invalid prompt at index {idx-1}.")
                continue

            prompt = prompt.strip()
            status_text.text(f"Processing request {idx} of {total}: {prompt[:50]}...")
            logger.info(f"Sanitizing prompt: {prompt}")
            # JSON-encode the prompt (quoted and escaped). Presumably this lets
            # call_automate_requests splice it into a JSON request body; TODO confirm.
            prompt_str = json.dumps(prompt)

            response_data = call_automate_requests(
                request_file=request_path,
                prompt=prompt_str,
                special_token=special_token,
                keywords=keywords
            )
            response_dict[prompt_str] = response_data
            result_rows.append(response_data)

            with results_placeholder.container():
                display_df = pd.DataFrame(result_rows)
                display_df.index = display_df.index + 1  # 1-based row numbers for display
                st.dataframe(display_df, use_container_width=True)

            # Throttle only between requests; there is no reason to wait after the last one.
            if idx < total:
                time.sleep(request_sleep)

        progress_bar.progress(1.0)
        status_text.text(f"Fetched responses for {total} prompts!")

        return {"status": "success", "result": response_dict}

    except ValueError as ve:
        st.error(f"Input Error: {ve}")
        return {"status": "error", "message": f"Input Error: {ve}"}
    except Exception as e:
        st.error(f"An unexpected error occurred during evaluation: {e}")
        return {"status": "error", "message": f"Unexpected evaluation error: {e}"}

    finally:
        st.session_state.is_processing = False
        for tmp_file in temp_files:
            try:
                if os.path.exists(tmp_file):
                    os.unlink(tmp_file)
            except OSError as e:
                st.warning(f"Failed to cleanup temporary file {tmp_file}: {e}")


def validate_inputs(form_data: Dict, api_key: Optional[str]) -> Optional[str]:
    """
    Check the submitted form against the requirements of the chosen mode.

    Generation is required in "Generate Prompts" mode. It is also required in
    "Generate Prompts & Evaluate Results" mode whenever any generation detail
    (application, description or attack types) was filled in. Evaluation needs
    a sample request plus either uploaded prompts or generation details.

    Returns:
        A Streamlit ``:red[...]`` message describing the first problem found,
        or ``None`` when the inputs are acceptable.
    """
    mode = form_data['mode']
    app_type = form_data['application']
    app_context = form_data['description']
    attacks = form_data['attack_prompt_dic']
    extra_prompts = form_data['additional_prompts']
    request_sample = form_data['sample_request']

    wants_evaluation = mode == "Generate Prompts & Evaluate Results"
    has_generation_details = bool(app_type or app_context or attacks)
    wants_generation = (mode == "Generate Prompts") or (wants_evaluation and has_generation_details)

    if wants_generation:
        if not api_key:
            return ":red[API Key is required in the sidebar to generate prompts.]"

        required_fields = (
            (app_type, ":red[Please select or specify the Application Type for generation.]"),
            (app_context, ":red[Please describe the Application Context for generation.]"),
            (attacks, ":red[Please select at least one Attack Type for generation.]"),
        )
        for field_value, problem in required_fields:
            if not field_value:
                return problem

        if "Custom Prompt" in attacks:
            custom_instruction = attacks["Custom Prompt"]
            # An untouched placeholder counts as "not entered".
            if not custom_instruction or custom_instruction == "<insert instruction>":
                return ":red[Please enter your custom prompt for generation]"
            if not custom_instruction.strip():
                return ":red[Please provide your custom attack prompt instruction, or deselect 'Custom Prompt'.]"

    if wants_evaluation:
        if not request_sample:
            return ":red[Please upload the Sample Request File for evaluation.]"
        if extra_prompts is None and not has_generation_details:
            return ":red[For evaluation, please either upload additional prompts or provide details for prompt generation (Application, Description, Attack Type).]"

    return None


def _render_output_prompts():
    """Render the prompts from the last completed run as text plus a download button."""
    submitted = st.session_state.form_data
    # Choose the heading from the submitted mode. main() resets
    # automated_response_results to {} on every submit, so checking it for None
    # could never produce the "Generated Prompts" heading.
    evaluated = submitted is not None and submitted['mode'] == "Generate Prompts & Evaluate Results"
    st.subheader("Prompts for Evaluation" if evaluated else "Generated Prompts")

    output_prompts_str = '\n\n'.join(st.session_state.output_prompts)
    st.text_area(
        label=" ",
        value=output_prompts_str,
        height=200,
        key="output_prompts_display"
    )

    st.download_button(
        label="Download Prompts",
        data=output_prompts_str,
        file_name="prompts.txt",
        mime="text/plain"
    )


def _render_evaluation_report():
    """Render the categorised evaluation results with a findings filter and a CSV download."""
    results = st.session_state['automated_evaluation_results']
    st.subheader("Evaluation Report")

    has_vulnerable = bool(results.get('vulnerable'))
    if not results.get('combined'):
        st.info("No evaluation results to display.")
        return

    # "Show only findings" is offered, and pre-selected as the first option,
    # only when there is at least one vulnerable response.
    if has_vulnerable:
        report_options = ["Show only findings", "Show all results"]
    else:
        report_options = ["Show all results"]

    evaluation_report_type = st.radio(
        "Filter Report:",
        report_options,
        index=0,
        horizontal=True,
        key="report_type_radio",
    )

    if evaluation_report_type == "Show only findings" and has_vulnerable:
        data_to_display = dict_to_df(results['vulnerable'])
        file_name = "vulnerable_findings_report.csv"
    else:
        data_to_display = dict_to_df(results['combined'])
        file_name = "full_evaluation_report.csv"

    if data_to_display is not None and not data_to_display.empty:
        display_df = data_to_display.copy()
        display_df.index = display_df.index + 1  # 1-based row numbers for display
        st.dataframe(display_df, use_container_width=True)

        st.download_button(
            label="Download Report",
            data=data_to_display.to_csv(index=False).encode('utf-8'),
            file_name=file_name,
            mime="text/csv",
            help="Downloads the displayed evaluation results as a CSV file."
        )
    elif evaluation_report_type == "Show only findings":
        st.info("No vulnerable findings were identified in this evaluation.")


def main():
    """
    Render the PenTestPrompt page and, on submit, run generation and evaluation.

    Steps:
    1. Verify that the required files exist, initialise session state and load
       the configuration.
    2. Render the sidebar and the main form.
    3. When the form is submitted: validate the inputs; generate prompts if
       requested; merge in uploaded prompts; and, in evaluate mode, run them
       against the sample request.
    4. Store the outcome in ``st.session_state`` and rerun. Results are always
       rendered from session state.
    """
    required_files = {
        'config': CONFIG_PATH,
        'logo': LOGO_PATH,
        'response keywords': RESPONSE_CHECKER_FILE
    }
    missing_files = [name for name, path in required_files.items() if not os.path.exists(path)]
    if missing_files:
        st.error(f"Fatal Error: Missing required file(s): {', '.join(missing_files)}. Please ensure the application structure is correct.")

        for name in missing_files:
            st.error(f"Expected path for {name}: {required_files[name]}")
        return

    initialize_session_state()
    try:
        config = load_yaml_config(CONFIG_PATH)
        attack_config = load_attack_config(attack_folder_path=ATTACK_FOLDER_PATH)
        # Store the loaded attack definitions under Configuration.attack_prompt in the config.
        config.setdefault('Configuration', {})['attack_prompt'] = attack_config
    except Exception as e:
        st.error(f"Fatal Error: Failed to load configuration files: {e}")
        return

    model_info = config.get('Model Providers', {})

    model_provider, selected_model, api_key, temperature, top_p, max_tokens = render_sidebar(
        logo_path=LOGO_PATH,
        model_info=model_info
    )
    # Passed to process_results as the pause between consecutive requests.
    request_sleep = config.get('request_sleep', 10)

    st.title("PenTestPrompt")

    # Initializing model
    model = None
    if model_provider == "OpenAI":
        model = OpenAIModel(api_key, selected_model)
    # Add support for other Model providers if necessary

    view = MainView(model)
    form_data = render_main_form(config)

    if form_data['submit_button']:

        validation_error = validate_inputs(form_data, api_key)
        if validation_error:
            st.warning(validation_error)
            st.stop()

        reset_output_states()

        st.session_state.error_message = None
        st.session_state.form_data = form_data
        st.session_state.is_processing = True
        st.session_state.automated_response_results = {}

        show_processing_overlay()

        all_prompts_for_evaluation: List[str] = []

        try:
            # Generate prompts (if required)
            should_generate = (
                form_data['mode'] == "Generate Prompts"
                or (form_data['mode'] == "Generate Prompts & Evaluate Results" and form_data['attack_prompt_dic'])
            )

            if should_generate:
                if not model:
                    raise ValueError(f"Model provider '{model_provider}' is not configured or selected.")
                if not model.initialize_client():
                    raise ValueError(f"Failed to initialize {model_provider} client. Check API key and configuration.")

                with st.spinner('Generating prompts...'):
                    generation_result = view.process_prompts(
                        application=form_data['application'],
                        description=form_data['description'],
                        system_instruction=form_data['system_instruction'],
                        attack_prompt_dic=form_data['attack_prompt_dic'],
                        num_prompts=form_data['num_prompts'],
                        temperature=temperature,
                        top_p=top_p,
                        max_tokens=max_tokens
                    )

                if generation_result["status"] == "error":
                    st.session_state.error_message = f"Prompt Generation Failed: {generation_result.get('message', 'Unknown error')}"

                    if not generation_result.get("prompts"):
                        raise ValueError(st.session_state.error_message)
                    else:
                        st.warning("Partial prompts generated despite errors. Continuing evaluation with available prompts.")
                        all_prompts_for_evaluation.extend(generation_result.get("prompts", []))

                else:
                    generated_prompts = generation_result.get("prompts", [])
                    if not generated_prompts and form_data['attack_prompt_dic']:
                        st.warning("Generation finished, but no prompts were produced. Check attack instructions and model response.")
                        st.session_state.processing_complete = True
                    else:
                        for sample in generated_prompts:
                            if isinstance(sample, str) and sample.strip():
                                all_prompts_for_evaluation.append(sample.strip())
                            else:
                                st.warning(f"Skipping invalid prompt: {sample}")

            # Add uploaded prompts to the evaluation list, skipping duplicates
            # of prompts already in the list and within the file itself.
            if form_data['additional_prompts']:
                try:
                    # getvalue() returns the whole upload regardless of the
                    # stream position, matching how process_results reads the
                    # sample request.
                    uploaded_text = form_data['additional_prompts'].getvalue().decode('utf-8')
                    uploaded_prompts_list = [line.strip() for line in uploaded_text.splitlines() if line.strip()]
                    if uploaded_prompts_list:
                        seen_prompts = set(all_prompts_for_evaluation)
                        added = 0
                        for p in uploaded_prompts_list:
                            if p not in seen_prompts:
                                seen_prompts.add(p)
                                all_prompts_for_evaluation.append(p)
                                added += 1
                        st.info(f"Adding {added} prompts from uploaded file.")
                    else:
                        st.warning("Uploaded prompts file is empty or contains only whitespace.")
                except Exception as e:
                    upload_error = f"Error reading uploaded prompts file: {e}"
                    st.error(upload_error)
                    # Persist it: the st.rerun() below discards messages emitted
                    # during this run, and only session state is carried over.
                    st.session_state.error_message = upload_error

            st.session_state.output_prompts = all_prompts_for_evaluation

            # Evaluate results
            if form_data['mode'] == "Generate Prompts & Evaluate Results":
                if not st.session_state.output_prompts:
                    st.warning("No prompts available for evaluation (neither generated nor uploaded).")

                else:
                    st.info(f"Starting evaluation with {len(st.session_state.output_prompts)} total prompts.")
                    evaluation_result = process_results(
                        request_file=form_data['sample_request'],
                        prompts=st.session_state.output_prompts,
                        special_token=form_data['special_token'],
                        request_sleep=request_sleep,
                    )

                    if evaluation_result["status"] == "error":
                        st.session_state.error_message = evaluation_result.get("message", "Evaluation failed.")

                    elif evaluation_result["status"] == "success":
                        st.session_state.automated_response_results = evaluation_result.get("result", {})

                        if st.session_state.automated_response_results:
                            # NOTE(review): this relies on the evaluator returning exactly four
                            # entries in error/safe/vulnerable/combined order; confirm with
                            # call_response_evaluator.
                            error_responses, safe_responses, vulnerable_responses, combined_responses = call_response_evaluator(json_data=st.session_state.automated_response_results).values()

                            st.session_state['automated_evaluation_results']['error'] = error_responses
                            st.session_state['automated_evaluation_results']['safe'] = safe_responses
                            st.session_state['automated_evaluation_results']['vulnerable'] = vulnerable_responses
                            st.session_state['automated_evaluation_results']['combined'] = combined_responses

                        else:
                            st.info("Evaluation completed, but no responses were recorded.")
                            st.session_state['automated_evaluation_results'] = {'error': {}, 'safe': {}, 'vulnerable': {}, 'combined': {}}

            st.session_state.processing_complete = True

        except Exception as e:
            st.session_state.error_message = f"An unexpected error occurred: {str(e)}"
            st.session_state.processing_complete = True

        finally:
            st.session_state.is_processing = False
            # Re-run the script so the page is redrawn purely from session state.
            st.rerun()

    # Show (once) any error stored by the last run
    if st.session_state.error_message:
        st.error(st.session_state.error_message)
        st.session_state.error_message = None

    if st.session_state.processing_complete and st.session_state.output_prompts:
        _render_output_prompts()

    if st.session_state.processing_complete and st.session_state.automated_evaluation_results:
        _render_evaluation_report()


if __name__ == "__main__":
    main()