import argparse
import logging

from typing import Dict, Any
from src.utils import load_attack_config

# Module-level logger; parse_cli_arguments reports the selected application and attacks through it.
logger = logging.getLogger(__name__)

def validate_float_range(value: str, range_values: list, param_name: str) -> float:
    """
        Validate that a value converts to a float within an inclusive range.

        Args:
            value: Value to validate; anything accepted by ``float()`` (a CLI
                string or an already-parsed number).
            range_values: Two-element sequence ``[min_val, max_val]`` giving the
                inclusive bounds.
            param_name: Parameter name used in the error message.

        Returns:
            The value converted to ``float``.

        Raises:
            argparse.ArgumentTypeError: If the value cannot be converted to a
                float or lies outside ``[min_val, max_val]`` (NaN is rejected).
            ValueError: If ``range_values`` does not hold exactly two bounds.
    """
    # Unpack the bounds before converting, so the error message below can always
    # reference them. Previously a non-numeric value raised before the bounds were
    # bound, producing an UnboundLocalError instead of a clean argparse error.
    min_val, max_val = range_values
    try:
        float_val = float(value)
    except (TypeError, ValueError):
        float_val = None
    if float_val is not None and min_val <= float_val <= max_val:
        return float_val
    raise argparse.ArgumentTypeError(
        f"{param_name} must be between {min_val} and {max_val}, got: {value}"
    )

def validate_int_range(value: str, range_values: list, param_name: str) -> int:
    """
        Validate that a value converts to an int within an inclusive range.

        Args:
            value: Value to validate; anything accepted by ``int()`` (a CLI
                string or an already-parsed number; floats are truncated by
                ``int()``).
            range_values: Two-element sequence ``[min_val, max_val]`` giving the
                inclusive bounds.
            param_name: Parameter name used in the error message.

        Returns:
            The value converted to ``int``.

        Raises:
            argparse.ArgumentTypeError: If the value cannot be converted to an
                int or lies outside ``[min_val, max_val]``.
            ValueError: If ``range_values`` does not hold exactly two bounds.
    """
    # Unpack the bounds before converting, so the error message below can always
    # reference them. Previously a non-numeric value raised before the bounds were
    # bound, producing an UnboundLocalError instead of a clean argparse error.
    min_val, max_val = range_values
    try:
        int_val = int(value)
    except (TypeError, ValueError):
        int_val = None
    if int_val is not None and min_val <= int_val <= max_val:
        return int_val
    raise argparse.ArgumentTypeError(
        f"{param_name} must be between {min_val} and {max_val}, got: {value}"
    )

def validate_int(value: str, param_name: str = "num_prompts") -> int:
    """
        Validate that a value is a strictly positive integer.

        Intended as an argparse ``type=`` callable (argparse passes only the
        raw string, so ``param_name`` keeps its default there).

        Args:
            value: Value to validate; anything accepted by ``int()``.
            param_name: Parameter name used in the error message.

        Returns:
            The value converted to ``int``.

        Raises:
            argparse.ArgumentTypeError: If the value is not an integer or is <= 0.
    """
    try:
        int_val = int(value)
    except (TypeError, ValueError):
        int_val = None
    if int_val is None or int_val <= 0:
        # Report the offending value, consistent with the range validators above.
        raise argparse.ArgumentTypeError(
            f"{param_name} must be a positive integer, got: {value}"
        )
    return int_val

def validate_attack(attack_value_options: list):
    """
    Creates an argparse validator for comma-separated attack indices.

    This function returns a validation function designed to be used as the 'type' for an argparse argument. The returned validator checks if the input string:
    1. Represents a single integer or a comma-separated list of integers (whitespace around each integer is ignored).
    2. Contains no empty elements anywhere (e.g. '', '1,,2', ',1', '1, ,2' and the trailing-comma form '1,2,' are all invalid).
    3. Contains only integers that fall within the valid range [1, N], where N is the number of items in `attack_value_options`.

    Args:
        attack_value_options: Sized collection of attack options (a list or a dict keyed by attack name).

    Returns:
        A callable ``validator(value: str) -> list[int]`` that returns the selected
        indices converted to 0-based, or raises ``argparse.ArgumentTypeError``.
    """
    def validator(value: str) -> list:
        elements = [part.strip() for part in value.split(',')]

        # Reject every empty element, not only '',,' and a trailing comma: the
        # previous filter silently dropped blanks such as '1, ,2' or ',1'.
        if any(not part for part in elements):
            raise argparse.ArgumentTypeError(
                f"Invalid format - expected a comma-separated list of integers "
                f"with no empty elements (e.g. '1,3'), got: {value!r}"
            )

        max_value = len(attack_value_options)
        attack_types = []
        for part in elements:
            try:
                attack_type = int(part)
            except ValueError:
                raise argparse.ArgumentTypeError(
                    f"Invalid attack index {part!r} - must be an integer"
                ) from None
            if not 1 <= attack_type <= max_value:
                raise argparse.ArgumentTypeError(f"Values must be between 1 and {max_value}")
            attack_types.append(attack_type)

        # Convert the user-facing 1-based indices to 0-based list indices.
        return [attack_type - 1 for attack_type in attack_types]
    return validator

def create_parser(config: Dict[str, Any]) -> argparse.ArgumentParser:
    """
        Create argument parser based on YAML configuration.

        Choices and help texts for --provider, --model, --application and
        --attack are derived from ``config``; the --num_prompts default comes
        from ``config['Configuration']['num_prompts']`` (50 if absent).

        Args:
            config (Dict): Configuration dictionary loaded from the backend YAML file.
                Sections read: 'Model Providers' (provider -> {'models': [...], ...}),
                'Application' -> 'values', 'Configuration' -> 'attack_prompt'
                and 'num_prompts'. Missing sections fall back to empty values.

        Returns:
            argparse.ArgumentParser: The configured (not yet invoked) parser.
    """
    parser = argparse.ArgumentParser(
        description="Welcome to Blueinfy PenTestPrompt CLI Tool",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # --- Model provider selection ---
    model_providers = config.get('Model Providers', {})
    providers = list(model_providers.keys())
    # Union of all providers' models, de-duplicated but kept in config order
    # (a plain set gave a nondeterministic order in --help and choice errors).
    # Per-provider availability is checked after parsing by validate_args.
    models = list(dict.fromkeys(
        model
        for pconfig in model_providers.values()
        for model in pconfig.get('models', [])
    ))

    parser.add_argument(
        "--provider",
        type=str,
        choices=providers,
        required=True,
        help=f"Select the AI model provider: {', '.join(providers)}"
    )

    parser.add_argument(
        "--model",
        type=str,
        choices=models,
        required=True,
        help=f"Select the model you want to use from: {', '.join(map(str, models))}"
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Provide the temperature for the model: It controls the randomness and creativity of the responses",
    )

    parser.add_argument(
        "--top_p",
        type=float,
        default=0.7,
        help="Provide the top_p for the model: It controls the randomness of the output by determining the possible words to consider when generating the next word.",
    )

    parser.add_argument(
        "--max_tokens",
        type=int,
        default=4096,
        help="Select the max tokens limit of the model",
    )

    parser.add_argument(
        "--api-key",
        type=str,
        required=True,
        help="Add the API key for the selected provider"
    )

    # --- Target application ---
    app_values = config.get('Application', {}).get('values', [])
    parser.add_argument(
        "--application",
        type=int,
        choices=range(1, len(app_values)+1),
        required=True,
        help=f"Application type (1-{len(app_values)}): " +
             ", ".join(f"{i+1}: {app}" for i, app in enumerate(app_values))
    )

    parser.add_argument(
        "--application-description",
        type=str,
        required=True,
        help="Describe the application for which you want to generate prompts.",
    )

    # --- Attack selection (1-based indices, converted to 0-based by validate_attack) ---
    attack_values_options = config.get('Configuration', {}).get('attack_prompt', [])
    parser.add_argument(
        "--attack",
        type=validate_attack(attack_values_options),
        required=True,
        help=f"Enter a comma-separated list of attack types\n\n Attack type (1-{len(attack_values_options)}): " +
        ", ".join(f"{i+1}: {attack}" for i, attack in enumerate(attack_values_options))
    )

    num_prompts = config.get('Configuration', {}).get('num_prompts', 50)
    parser.add_argument(
        "--num_prompts",
        type=validate_int,
        # Optional on purpose: the configured value is used when the flag is
        # omitted. (required=True previously made this default unreachable.)
        default=num_prompts,
        help=f"Number of prompts to generate (default: {num_prompts}). Note - the higher the number of prompts, the more performance may degrade, so it is recommended to keep the number of prompts below 1000."
    )

    # --- Evaluation against a target endpoint ---
    parser.add_argument(
        '--request_file',
        type=str,
        help='input txt file which contains the sample request. \nNOTE-1: The special token ### must be present in the body where you want to replace and test the generated prompts. \nNOTE-2: Required format is <METHOD> <FULL_URL> <HTTP/VERSION> \nExample: POST https://target.com/endpoint HTTP/2'
    )

    parser.add_argument(
        '--special_token',
        type=str,
        default="###",
        help='a special token in the body which will be replaced with the prompts'
    )

    parser.add_argument(
        '--output_file',
        type=str,
        default="data.json",
        help='output JSON file where the request-response logs are saved'
    )

    parser.add_argument(
        '--response_checker_file',
        type=str,
        default="src/config/response_checker.txt",
        help='the path to the txt file which contains the keywords based on which we evaluate responses'
    )

    parser.add_argument(
        "--report_type",
        type=str,
        choices=["findings", "errors", "combined"],
        default="combined",
        help="Choose the report_type that you wish to generate - whether only findings, only errors or the combined logs"
    )

    parser.add_argument(
        '--response_analysis_file',
        type=str,
        help='output csv file where the response analysis is saved'
    )

    parser.add_argument(
        '--additional_prompts_file',
        type=str,
        help="Additional prompts that you want to test the application with"
    )

    return parser

def validate_args(args, config):
    """
        Check parsed CLI arguments against the selected provider's limits.

        Args:
            args: Namespace returned by ``parser.parse_args()``.
            config: Configuration dictionary whose 'Model Providers' entry for
                ``args.provider`` supplies 'models', 'temperature_range',
                'top_p_range' and 'max_tokens_range'.

        Raises:
            ValueError: If the model is not offered by the chosen provider.
            argparse.ArgumentTypeError: If temperature, top_p or max_tokens fall
                outside the provider's configured range.
    """
    limits = config['Model Providers'][args.provider]

    if args.model not in limits['models']:
        raise ValueError(f"Model {args.model} not available for provider {args.provider}")

    # Float-valued sampling parameters, checked in the same order as before.
    float_checks = (
        (args.temperature, 'temperature_range', 'temperature'),
        (args.top_p, 'top_p_range', 'top_p'),
    )
    for arg_value, range_key, label in float_checks:
        validate_float_range(arg_value, limits[range_key], label)

    validate_int_range(args.max_tokens, limits['max_tokens_range'], 'max_tokens')

def _build_attack_info(attack_prompts: dict, selected_attacks: list) -> dict:
    """Map each selected attack name to its (stripped) instruction text.

    "Custom Prompt" is resolved interactively via input() when its configured
    instruction is still the "<insert instruction>" placeholder.
    """
    attack_info_dict = {}
    for selected_attack in selected_attacks:
        if selected_attack in attack_info_dict:
            # Same index given twice (e.g. "2,2"); don't prompt the user twice.
            continue
        instruction = attack_prompts[selected_attack].strip()
        if selected_attack == "Custom Prompt" and instruction == "<insert instruction>":
            # User selected custom prompt but did not specify an instruction in
            # the config file, so take it as runtime input.
            instruction = input("Please enter the custom attack prompt:\t")
        attack_info_dict[selected_attack] = instruction
    return attack_info_dict


def parse_cli_arguments(config: dict):
    """
    Fetch the parameters from the user input and return the configuration dictionaries.

    Side effects: replaces ``config['Configuration']['attack_prompt']`` with the
    attacks loaded from ``src/config/Attacks``, parses ``sys.argv``, and may
    prompt on stdin for a custom application domain ("Other") or a custom
    attack prompt.

    Args:
        config (Dict): configurations loaded from YAML file

    Returns:
        config_dict (Dict): configurations dict created via user CLI input
        eval_dict (Dict): dictionary for automating the testing of generated prompts

    Raises:
        SystemExit: from argparse on missing/invalid arguments.
        ValueError, argparse.ArgumentTypeError: from ``validate_args`` when the
            model or sampling parameters don't fit the chosen provider.
    """
    # 1st, update the attack config in main config file so the parser's
    # --attack options and the lookups below share the same mapping.
    attack_config = load_attack_config(attack_folder_path="src/config/Attacks")
    config['Configuration']['attack_prompt'] = attack_config

    parser = create_parser(config)
    args = parser.parse_args()
    validate_args(args, config)

    # Fetching application information (--application is a 1-based index).
    app_values = config['Application']['values']
    selected_application = app_values[args.application - 1]
    if selected_application == "Other":
        selected_application = input("Please enter the custom domain of your application:\t")
    application_description = args.application_description
    logger.info(f"Selected application is: {selected_application}\n")

    # Fetching attack information (--attack is already converted to 0-based indices).
    attack_prompts = config['Configuration']['attack_prompt']
    attack_names = list(attack_prompts.keys())
    selected_attacks = [attack_names[index] for index in args.attack]
    logger.info(f"Selected Attacks are: {selected_attacks}\n")

    # Every selected attack is included. Previously, selecting "Custom Prompt"
    # alongside other attacks silently dropped all the others.
    attack_info_dict = _build_attack_info(attack_prompts, selected_attacks)

    config_dict = {
        "model_provider": args.provider,
        "selected_model": args.model,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_tokens": args.max_tokens,
        "api_key": args.api_key,
        "system_instruction": config['Configuration']['system_prompt'],
        "application": selected_application,
        "application_description": application_description,
        "attack_info": attack_info_dict,
        "num_prompts": args.num_prompts,
    }

    # A request file means the generated prompts are also sent to a target and evaluated.
    method = "Generate Prompts" if args.request_file is None else "Generate and Evaluate"

    response_analysis_file = args.response_analysis_file
    if response_analysis_file is None and method == "Generate and Evaluate":
        response_analysis_file = f"{args.report_type}_analysis.csv"

    eval_dict = {
        "method": method,
        "special_token": args.special_token,
        "request_file": args.request_file,
        "output_file": args.output_file,
        "request_sleep": config.get('request_sleep', 3),
        "response_checker_file": args.response_checker_file,
        "report_type": args.report_type,
        "response_analysis_file": response_analysis_file,
        "additional_prompts_file": args.additional_prompts_file
    }

    return config_dict, eval_dict