#!/usr/bin/env python3
"""
Standalone VRAM/KV cache benchmarking tool for Kamiwaza models.
Tests various context lengths and measures actual VRAM requirements.
"""

import asyncio
import json
import logging
import os
import re
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
import urllib3
import requests

# Suppress SSL warnings for local testing: certificate verification is off by
# default (see VERIFY_SSL below), and each unverified HTTPS request would
# otherwise emit an InsecureRequestWarning.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Configuration
# Base URL of the Kamiwaza API; the KAMIWAZA_API_URI environment variable
# overrides the localhost default.
KAMIWAZA_API_URI = os.environ.get("KAMIWAZA_API_URI", "https://localhost/api")
# Default for KamiwazaVRAMBenchmark(verify_ssl=...). The --verify-ssl CLI flag
# switches verification on.
VERIFY_SSL = False

# Context lengths to test (in tokens): the powers of two from 1,024 (2**10)
# through 262,144 (2**18).
CONTEXT_LENGTHS = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144]

# Logging setup: root logger at INFO with timestamped "time - level - message"
# records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger.
logger = logging.getLogger(__name__)


@dataclass
class BenchmarkResult:
    """Result from a single benchmark run.

    One instance is built per deployment attempt, i.e. per combination of
    model, requested context length and GPU memory utilization.
    """
    # Identifier and display name of the benchmarked model.
    model_id: str
    model_name: str
    # Context length (max_model_len) requested for the deployment, in tokens.
    context_length: int
    # Success flag as reported by the log parser.
    success: bool
    # KV cache size in GiB as extracted by the log parser; 0.0 when the logs
    # yielded no figure.
    gb_required: float
    # Multiplier taken from the "Maximum concurrency ... : N.NNx" log line.
    max_concurrency: Optional[float] = None
    # Raw log line describing the failure, if one was found.
    error_message: Optional[str] = None
    # gpu_memory_utilization value the deployment was requested with.
    gpu_memory_utilization: Optional[float] = None
    # Token count reported in the logs: either the confirmed length from the
    # concurrency line or the estimated maximum from the error line.
    actual_max_model_len: Optional[int] = None
    # Hardware description copied from the benchmark object at run time.
    accelerator: Optional[str] = None
    accelerator_capability: Optional[float] = None
    # Feature flags derived from the accelerator type and capability.
    supports_flash_attention: Optional[bool] = None
    cuda_graphs_enabled: Optional[bool] = None


class KamiwazaVRAMBenchmark:
    """Benchmarking tool for VRAM/KV cache requirements"""
    
    def __init__(self, api_uri: str = KAMIWAZA_API_URI, verify_ssl: bool = VERIFY_SSL):
        """Create the HTTP session and initialise empty benchmark state.

        Args:
            api_uri: Base URL of the Kamiwaza API; trailing slashes are removed.
            verify_ssl: Whether the shared session verifies TLS certificates.
        """
        self.api_uri = api_uri.rstrip('/')
        self.verify_ssl = verify_ssl

        # A single session carries the TLS verification setting for every call.
        http_session = requests.Session()
        http_session.verify = verify_ssl
        self.session = http_session

        # Benchmark results accumulated across models.
        self.results: List[BenchmarkResult] = []

        # Hardware details are unknown at construction time.
        self.hardware_info = None
        self.accelerator = None
        self.accelerator_capability = None
        
    def _api_call(self, method: str, endpoint: str, **kwargs) -> requests.Response:
        """Send one HTTP request to the Kamiwaza API and return the response.

        Args:
            method: HTTP verb such as "GET", "POST" or "DELETE".
            endpoint: Path relative to the API base; a leading slash is optional.
            **kwargs: Forwarded to ``requests.Session.request`` (``json``,
                ``params``, ...). A ``timeout`` supplied here overrides the
                default applied below.

        Returns:
            The response, after ``raise_for_status()`` has accepted it.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts and HTTP error statuses. The error is logged and
                re-raised unchanged.
        """
        url = f"{self.api_uri}/{endpoint.lstrip('/')}"
        # requests waits indefinitely unless a timeout is given, so a stalled
        # server would hang the whole benchmark. (connect, read) in seconds;
        # the read timeout is generous because deployment calls can be slow.
        kwargs.setdefault("timeout", (10, 300))
        try:
            response = self.session.request(method, url, **kwargs)
            response.raise_for_status()
            return response
        except requests.exceptions.RequestException as e:
            logger.error(f"API call failed: {e}")
            raise
    
    def get_hardware_info(self) -> Dict:
        """Get hardware information from the cluster"""
        try:
            response = self._api_call("GET", "/cluster/nodes")
            nodes = response.json()
            
            # Extract hardware capabilities
            for node in nodes:
                gpus = node.get("gpus", [])
                if gpus:
                    # Get first GPU info
                    gpu = gpus[0]
                    self.accelerator = gpu.get("accelerator", "cuda")
                    
                    # Determine capability based on GPU model
                    gpu_name = gpu.get("name", "").lower()
                    if "a100" in gpu_name or "a6000" in gpu_name:
                        self.accelerator_capability = 8.0
                    elif "v100" in gpu_name:
                        self.accelerator_capability = 7.0
                    elif "t4" in gpu_name:
                        self.accelerator_capability = 7.5
                    elif "a10" in gpu_name:
                        self.accelerator_capability = 8.6
                    elif "h100" in gpu_name:
                        self.accelerator_capability = 9.0
                    elif "l4" in gpu_name:
                        self.accelerator_capability = 8.9
                    else:
                        # Try to extract from compute capability if available
                        self.accelerator_capability = gpu.get("compute_capability", 8.0)
                    
                    return {
                        "accelerator": self.accelerator,
                        "accelerator_capability": self.accelerator_capability,
                        "gpu_name": gpu.get("name"),
                        "gpu_memory": gpu.get("memory_bytes")
                    }
        except Exception as e:
            logger.warning(f"Failed to get hardware info: {e}")
            # Default to common values
            self.accelerator = "cuda"
            self.accelerator_capability = 8.0
        
        return {
            "accelerator": self.accelerator,
            "accelerator_capability": self.accelerator_capability
        }
    
    def get_guide_models(self, platform: Optional[str] = None) -> List[Dict]:
        """Fetch the model list from the guide endpoint.

        Args:
            platform: Optional platform name appended to the path; when it is
                empty or None the unfiltered ``/models/guide`` is requested.

        Returns:
            The decoded JSON body of the response.
        """
        path = f"/models/guide/{platform}" if platform else "/models/guide"
        return self._api_call("GET", path).json()
    
    def get_model_config(self, model_id: str) -> Dict:
        """Fetch the stored configuration of one model.

        Args:
            model_id: Identifier placed in the ``/models/<id>`` path.

        Returns:
            The decoded JSON body of the response.
        """
        endpoint = f"/models/{model_id}"
        reply = self._api_call("GET", endpoint)
        return reply.json()
    
    def deploy_model(self, model_id: str, max_model_len: int, 
                    gpu_memory_utilization: float = 0.90) -> str:
        """Deploy a model with specific context length"""
        payload = {
            "model_id": model_id,
            "engine": "vllm",
            "max_model_len": max_model_len,
            "gpu_memory_utilization": gpu_memory_utilization,
            "batch_size": 1  # Testing with bs=1 as per requirement
        }
        
        response = self._api_call("POST", "/serving/deploy_model", json=payload)
        deployment = response.json()
        return deployment.get("id")
    
    def get_deployment_logs(self, deployment_id: str, lines: int = 100) -> List[str]:
        """Fetch log lines for a deployment.

        Args:
            deployment_id: Identifier placed in the logs endpoint path.
            lines: Sent to the API as the ``lines`` query parameter.

        Returns:
            The ``logs`` entry of the JSON response, or an empty list when the
            response has no such key.
        """
        endpoint = f"/serving/deployments/{deployment_id}/logs"
        body = self._api_call("GET", endpoint, params={"lines": lines}).json()
        return body.get("logs", [])
    
    def wait_for_deployment(self, deployment_id: str, timeout: int = 120) -> Dict:
        """Wait for deployment to be ready or fail"""
        start_time = time.time()
        
        while time.time() - start_time < timeout:
            response = self._api_call("GET", f"/serving/deployments/{deployment_id}")
            deployment = response.json()
            status = deployment.get("status")
            
            if status == "DEPLOYED":
                return {"success": True, "deployment": deployment}
            elif status in ["FAILED", "ERROR"]:
                return {"success": False, "deployment": deployment}
            
            time.sleep(2)
        
        return {"success": False, "deployment": None, "error": "Timeout"}
    
    def parse_vllm_logs(self, logs: List[str]) -> Dict:
        """Parse VLLM logs for memory information"""
        result = {
            "success": False,
            "gb_required": 0.0,
            "max_concurrency": None,
            "error_message": None,
            "actual_max_model_len": None
        }
        
        for log in logs:
            # Success pattern: "Maximum concurrency for X tokens per request: Y"
            concurrency_match = re.search(
                r"Maximum concurrency for ([\d,]+) tokens per request: ([\d.]+)x",
                log
            )
            if concurrency_match:
                result["success"] = True
                tokens = int(concurrency_match.group(1).replace(",", ""))
                result["max_concurrency"] = float(concurrency_match.group(2))
                result["actual_max_model_len"] = tokens
            
            # Error pattern: KV cache memory error
            error_match = re.search(
                r"(\d+\.?\d*) GiB KV cache is needed.*available KV cache memory \((\d+\.?\d*) GiB\)",
                log
            )
            if error_match:
                result["gb_required"] = float(error_match.group(1))
                result["error_message"] = log
            
            # Alternative error pattern
            if "ValueError" in log and "max seq len" in log:
                result["error_message"] = log
                # Try to extract estimated max model length
                est_match = re.search(r"estimated maximum model length is (\d+)", log)
                if est_match:
                    result["actual_max_model_len"] = int(est_match.group(1))
        
        return result
    
    def undeploy_model(self, deployment_id: str):
        """Delete a deployment on a best-effort basis.

        Any exception raised by the DELETE request is logged as a warning and
        suppressed, so callers can use this for cleanup without guarding it.
        """
        try:
            endpoint = f"/serving/deployments/{deployment_id}"
            self._api_call("DELETE", endpoint)
        except Exception as exc:
            logger.warning(f"Failed to undeploy {deployment_id}: {exc}")
    
    async def benchmark_model(self, model: Dict) -> List[BenchmarkResult]:
        """Benchmark a single model across different context lengths.

        For every entry of ``CONTEXT_LENGTHS`` that does not exceed the
        model's ``max_position_embeddings``, the model is deployed with
        increasing ``gpu_memory_utilization`` values until the parsed logs
        report success. Once a context length still fails at the highest
        utilization, larger lengths are skipped.

        Args:
            model: Guide entry; ``id`` is required, ``name`` is optional.

        Returns:
            One ``BenchmarkResult`` per deployment attempt whose logs could be
            fetched. Attempts that raised are logged and omitted.
        """
        model_id = model.get("id")
        model_name = model.get("name", model_id)
        results = []

        if model_id is None:
            # Without an id neither the config nor a deployment can be requested.
            logger.error(f"Skipping model without id: {model_name}")
            return results

        logger.info(f"Starting benchmark for model: {model_name}")

        # Get model's maximum context length from config
        try:
            config = self.get_model_config(model_id)
            # "or" also covers an explicit null, which would otherwise break
            # the integer comparison below.
            max_ctx = config.get("max_position_embeddings") or 262144
        except Exception as e:
            logger.error(f"Failed to get config for {model_name}: {e}")
            max_ctx = 262144

        # The feature flags depend only on the detected hardware, so compute
        # them once. A missing or non-numeric capability counts as
        # "unsupported" rather than raising inside the deployment loop.
        try:
            capability = float(self.accelerator_capability)
        except (TypeError, ValueError):
            capability = None

        if self.accelerator == "cuda":
            supports_flash_attention = capability is not None and capability >= 8.0
            cuda_graphs_enabled = supports_flash_attention
        elif self.accelerator == "rocm":
            supports_flash_attention = True
            cuda_graphs_enabled = False  # HIP graphs default off
        elif self.accelerator == "gaudi":
            supports_flash_attention = False
            cuda_graphs_enabled = False
        else:  # metal/cpu
            supports_flash_attention = True
            cuda_graphs_enabled = False

        # Test each context length up to model's maximum
        for ctx_len in CONTEXT_LENGTHS:
            if ctx_len > max_ctx:
                logger.info(f"Skipping {ctx_len} (exceeds model max {max_ctx})")
                continue

            logger.info(f"Testing {model_name} with context length {ctx_len}")

            # Try with increasing GPU memory utilization
            for gpu_util in [0.85, 0.90, 0.92, 0.95]:
                try:
                    deployment_id = self.deploy_model(
                        model_id,
                        ctx_len,
                        gpu_memory_utilization=gpu_util
                    )

                    # Always tear the deployment down, even when waiting or
                    # log retrieval fails; a leaked deployment keeps GPU
                    # memory occupied and skews every later attempt.
                    try:
                        self.wait_for_deployment(deployment_id)
                        logs = self.get_deployment_logs(deployment_id)
                    finally:
                        self.undeploy_model(deployment_id)

                    log_result = self.parse_vllm_logs(logs)

                    result = BenchmarkResult(
                        model_id=model_id,
                        model_name=model_name,
                        context_length=ctx_len,
                        success=log_result["success"],
                        gb_required=log_result.get("gb_required", 0.0),
                        max_concurrency=log_result.get("max_concurrency"),
                        error_message=log_result.get("error_message"),
                        gpu_memory_utilization=gpu_util,
                        actual_max_model_len=log_result.get("actual_max_model_len"),
                        accelerator=self.accelerator,
                        accelerator_capability=self.accelerator_capability,
                        supports_flash_attention=supports_flash_attention,
                        cuda_graphs_enabled=cuda_graphs_enabled
                    )
                    results.append(result)

                    # If successful, move to next context length
                    if result.success:
                        logger.info(f" Success at {ctx_len} tokens with gpu_util={gpu_util}")
                        break

                    logger.warning(f" Failed at {ctx_len} tokens with gpu_util={gpu_util}")
                    # If we can't even fit with 0.95, skip larger contexts
                    if gpu_util >= 0.95:
                        logger.info(f"Cannot fit {ctx_len} tokens, skipping larger contexts")
                        return results

                except Exception as e:
                    logger.error(f"Error testing {model_name} at {ctx_len}: {e}")
                    continue

            # Small delay between tests
            await asyncio.sleep(2)

        return results
    
    def extrapolate_missing_values(self, results: List[BenchmarkResult]) -> List[Dict]:
        """Extrapolate VRAM requirements for missing context lengths"""
        output = []
        
        # Group by model
        by_model = {}
        for r in results:
            if r.model_id not in by_model:
                by_model[r.model_id] = []
            by_model[r.model_id].append(r)
        
        for model_id, model_results in by_model.items():
            # Sort by context length
            model_results.sort(key=lambda x: x.context_length)
            
            # Find successful runs to establish scaling
            successful = [r for r in model_results if r.success]
            
            # Get hardware info from first result
            first_result = model_results[0]
            hardware_info = {
                "accelerator": first_result.accelerator,
                "accelerator_capability": first_result.accelerator_capability,
                "supports_flash_attention": first_result.supports_flash_attention,
                "cuda_graphs_enabled": first_result.cuda_graphs_enabled
            }
            
            if len(successful) >= 2:
                # Calculate scaling factor (should be ~linear for KV cache)
                ctx1, gb1 = successful[0].context_length, successful[0].gb_required
                ctx2, gb2 = successful[-1].context_length, successful[-1].gb_required
                
                if ctx2 > ctx1 and gb2 > gb1:
                    # GB per token
                    gb_per_token = (gb2 - gb1) / (ctx2 - ctx1)
                    base_gb = gb1 - (gb_per_token * ctx1)
                    
                    # Extrapolate for all context lengths
                    for ctx_len in CONTEXT_LENGTHS:
                        estimated_gb = base_gb + (gb_per_token * ctx_len)
                        output.append({
                            "model_id": model_id,
                            "model_name": model_results[0].model_name,
                            "max_model_len": ctx_len,
                            "gb_required": round(estimated_gb, 3),
                            "measured": any(r.context_length == ctx_len and r.success 
                                         for r in model_results),
                            **hardware_info
                        })
            else:
                # Not enough data to extrapolate
                for r in model_results:
                    if r.success:
                        output.append({
                            "model_id": r.model_id,
                            "model_name": r.model_name,
                            "max_model_len": r.context_length,
                            "gb_required": round(r.gb_required, 3),
                            "measured": True,
                            **hardware_info
                        })
        
        return output
    
    async def run_benchmark(self, platform: Optional[str] = None):
        """Run the full benchmark suite.

        Detects the hardware, benchmarks every model returned by the guide
        endpoint, writes the extrapolated results to
        ``vram_benchmark_results_<unix time>.json`` in the working directory
        and prints a summary.

        Args:
            platform: Optional platform filter passed to the guide endpoint.

        Returns:
            The dict that was written to disk, with ``hardware`` and
            ``benchmarks`` entries.
        """
        # Get hardware info first
        logger.info("Detecting hardware capabilities...")
        hw_info = self.get_hardware_info()
        logger.info(f"Hardware: {hw_info.get('accelerator', 'unknown')} "
                   f"capability={hw_info.get('accelerator_capability', 'unknown')}")

        # Get models from guide
        logger.info(f"Fetching models from guide (platform={platform or 'all'})")
        models = self.get_guide_models(platform)
        logger.info(f"Found {len(models)} models to benchmark")

        # Benchmark each model
        all_results = []
        for model in models:
            results = await self.benchmark_model(model)
            all_results.extend(results)
            self.results.extend(results)

        # Generate output
        output = self.extrapolate_missing_values(all_results)

        # Derive the feature flags with the same table the per-run results
        # use. A missing or non-numeric capability must not raise here, or
        # the finished benchmark would be lost before it is saved.
        try:
            capability = float(self.accelerator_capability)
        except (TypeError, ValueError):
            capability = None

        if self.accelerator == "cuda":
            supports_flash_attention = capability is not None and capability >= 8.0
            cuda_graphs_enabled = supports_flash_attention
        elif self.accelerator == "gaudi":
            supports_flash_attention = False
            cuda_graphs_enabled = False
        else:  # rocm, metal, cpu
            supports_flash_attention = True
            cuda_graphs_enabled = False

        # Save results with hardware info
        results_data = {
            "hardware": {
                "accelerator": self.accelerator,
                "accelerator_capability": self.accelerator_capability,
                "supports_flash_attention": supports_flash_attention,
                "cuda_graphs_enabled": cuda_graphs_enabled
            },
            "benchmarks": output
        }

        output_file = f"vram_benchmark_results_{int(time.time())}.json"
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results_data, f, indent=2)

        logger.info(f"Results saved to {output_file}")

        # Print summary
        print("\n" + "="*80)
        print("VRAM Benchmark Results Summary")
        print("="*80)
        print(f"Hardware: {self.accelerator} (capability {self.accelerator_capability})")
        print(f"Flash Attention: {'✓' if results_data['hardware']['supports_flash_attention'] else '✗'}")
        print(f"CUDA Graphs: {'✓' if results_data['hardware']['cuda_graphs_enabled'] else '✗'}")
        print("="*80)

        for item in output:
            # "✓" marks a measured value, "~" an extrapolated one. The
            # previous markers were an empty string and a corrupted character.
            measured = "✓" if item["measured"] else "~"
            print(f"{measured} {item['model_name']}: "
                  f"{item['max_model_len']} tokens = {item['gb_required']} GB")

        return results_data


async def main():
    """Parse the command line and run the benchmark once."""
    import argparse

    cli = argparse.ArgumentParser(description="Benchmark VRAM/KV cache requirements")
    cli.add_argument(
        "--api-uri",
        default=KAMIWAZA_API_URI,
        help="Kamiwaza API URI",
    )
    cli.add_argument(
        "--platform",
        choices=["linux", "windows", "mac"],
        help="Platform to test models for",
    )
    cli.add_argument(
        "--verify-ssl",
        action="store_true",
        help="Verify SSL certificates",
    )
    options = cli.parse_args()

    # Certificate verification stays off unless --verify-ssl was given.
    runner = KamiwazaVRAMBenchmark(api_uri=options.api_uri, verify_ssl=options.verify_ssl)
    await runner.run_benchmark(platform=options.platform)


# Run the async entry point only when the file is executed as a script, not
# when it is imported.
if __name__ == "__main__":
    asyncio.run(main())