#!/usr/bin/env python3
import asyncio
import aiohttp
import sys
import json
import logging
from typing import Dict, Any, List, Optional
from urllib.parse import urljoin
import uuid

# stdout carries the JSON-RPC stream, so every log record must go to stderr.
logging.basicConfig(
    format='[KamiwazaBridge] %(message)s',
    stream=sys.stderr,
    level=logging.INFO,
)
logger = logging.getLogger("KamiwazaBridge")

# ToolService is only importable inside a Kamiwaza installation; without it the
# bridge cannot discover tools, so exit immediately with status 1.
try:
    from kamiwaza.serving.garden.tool.tool_service import ToolService
except ImportError as e:
    # Include the underlying error: a broken transitive dependency also raises
    # ImportError here and would otherwise be misreported as a missing
    # Kamiwaza environment.
    logger.error(f"Could not import ToolService ({e}). Ensure you are running in the Kamiwaza environment.")
    sys.exit(1)

class MCPClient:
    """Manages a single SSE connection to a backend MCP tool server.

    The backend announces, in an SSE ``endpoint`` event, the URL that JSON-RPC
    messages must be POSTed to (see :meth:`send_request`). Every JSON object it
    later pushes over the stream -- responses as well as notifications -- is
    put on the queue given to :meth:`connect`; the bridge decides what to do
    with each one.
    """

    # Seconds to wait before reconnecting after a failure or a closed stream.
    RECONNECT_DELAY = 5

    def __init__(self, name: str, base_url: str, session: aiohttp.ClientSession):
        self.name = name
        # Expected to end with "/" so urljoin appends "sse" instead of
        # replacing the last path segment.
        self.base_url = base_url
        self.session = session
        self.sse_url = urljoin(base_url, "sse")
        self.post_endpoint: Optional[str] = None  # Set from the "endpoint" event.
        self.tools: List[Dict[str, Any]] = []
        self.resources: List[Dict[str, Any]] = []
        self.prompts: List[Dict[str, Any]] = []
        self.ready = False  # True once a POST endpoint is known for the live stream.

    async def connect(self, message_queue: asyncio.Queue):
        """Connect to the SSE endpoint and pump its events, reconnecting forever.

        After a failed connection, a non-200 status or the server closing the
        stream, the client is marked not ready, its POST endpoint is cleared
        (the endpoint announced by one SSE session is presumably not valid for
        the next) and it waits RECONNECT_DELAY seconds before retrying.

        Args:
            message_queue: Receives every JSON object the backend sends.
        """
        logger.info(f"Connecting to {self.name} at {self.sse_url}")
        while True:
            try:
                # NOTE(review): ssl=False disables certificate verification;
                # presumably intentional for in-cluster endpoints — confirm.
                async with self.session.get(self.sse_url, ssl=False) as response:
                    if response.status != 200:
                        logger.warning(f"Connection to {self.name} failed: {response.status}")
                    else:
                        logger.info(f"Connected to {self.name}")
                        await self._read_events(response, message_queue)
                        logger.warning(f"SSE stream from {self.name} closed")
            except Exception as e:
                logger.error(f"Error in {self.name} connection: {e}")
            # Reached on every failure *and* on a normal close: without the
            # reset and delay, a closing server would cause a hot reconnect
            # loop while requests still went to a stale endpoint.
            self.ready = False
            self.post_endpoint = None
            await asyncio.sleep(self.RECONNECT_DELAY)

    async def _read_events(self, response: aiohttp.ClientResponse, message_queue: asyncio.Queue):
        """Parse the SSE body of *response*, dispatching each complete event.

        Fields are accumulated until a blank line ends the event, as the SSE
        format requires; multi-line ``data`` fields are joined with newlines.
        Returns when the server closes the stream.
        """
        event_type = "message"
        data_lines: List[str] = []
        buffer = bytearray()
        # iter_any() yields raw chunks; line-by-line iteration of
        # response.content can raise ValueError for lines larger than
        # aiohttp's read buffer limit (e.g. a big tools/list result).
        async for chunk in response.content.iter_any():
            buffer.extend(chunk)
            while True:
                newline = buffer.find(b"\n")
                if newline < 0:
                    break
                line = buffer[:newline].decode("utf-8").rstrip("\r")
                del buffer[:newline + 1]
                if line:
                    field, value = self._parse_sse_field(line)
                    if field == "event":
                        event_type = value
                    elif field == "data":
                        data_lines.append(value)
                    # Comments (":...") and id/retry fields are ignored.
                    continue
                # Blank line: the event is complete.
                if data_lines:
                    await self._handle_event(event_type, "\n".join(data_lines), message_queue)
                event_type = "message"
                data_lines = []

    @staticmethod
    def _parse_sse_field(line: str) -> "tuple[str, str]":
        """Split one non-empty SSE line into ``(field, value)``.

        The first ``:`` separates field from value and one space directly after
        it is dropped. A line starting with ``:`` is a comment and yields an
        empty field name; a line without ``:`` is a field with an empty value.
        """
        field, sep, value = line.partition(":")
        if sep and value.startswith(" "):
            value = value[1:]
        return field, value

    async def _handle_event(self, event_type: str, data: str, message_queue: asyncio.Queue):
        """Record the POST endpoint, or queue the JSON-RPC message in *data*."""
        if event_type == "endpoint":
            # The endpoint may be a path or an absolute URL; urljoin handles both.
            self.post_endpoint = urljoin(self.base_url, data.strip())
            logger.info(f"{self.name} POST endpoint: {self.post_endpoint}")
            self.ready = True
            return

        try:
            message = json.loads(data)
        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON from {self.name}: {data}")
            return
        if not isinstance(message, dict):
            logger.warning(f"Unexpected non-object JSON from {self.name}: {data}")
            return
        # Forward everything, notifications *and* responses: the bridge matches
        # responses to its pending requests by "id", so dropping them here
        # would make every forwarded request time out.
        await message_queue.put(message)

    async def send_request(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST a JSON-RPC message to the backend.

        In the MCP SSE transport the POST is only acknowledged (200/202); the
        actual JSON-RPC response is delivered asynchronously over the SSE
        stream, so this returns a fixed acknowledgement.

        Raises:
            Exception: if no POST endpoint is known yet, or the POST returns a
                status other than 200/202. aiohttp client errors propagate.
        """
        endpoint = self.post_endpoint
        if not endpoint:
            raise Exception(f"Tool {self.name} not ready (no POST endpoint)")

        async with self.session.post(endpoint, json=payload, ssl=False) as response:
            if response.status not in (200, 202):
                raise Exception(f"POST to {self.name} failed: {response.status}")
            return {"status": "accepted"}

class KamiwazaBridge:
    """Expose every deployed Kamiwaza MCP tool server as one stdio MCP server.

    Newline-delimited JSON-RPC messages are read from stdin. ``initialize``,
    ``ping`` and the list methods are answered here; ``tools/call`` is routed
    to the backend that advertised the tool in the most recent ``tools/list``.
    Backend responses and notifications arrive over each backend's SSE stream
    (via ``out_queue``) and replies are written to stdout, one per line.
    """

    # asyncio.StreamReader defaults to a 64 KiB line limit, which large
    # tools/call arguments can exceed (readline() then raises ValueError).
    STDIN_LINE_LIMIT = 16 * 1024 * 1024

    def __init__(self, request_timeout: float = 10.0, discovery_interval: float = 10.0):
        """Create the bridge.

        Args:
            request_timeout: Seconds to wait for a backend's response to a
                forwarded request.
            discovery_interval: Seconds between ToolService polls.
        """
        self.tool_service = ToolService()
        self.clients: Dict[str, MCPClient] = {}
        self.session: Optional[aiohttp.ClientSession] = None
        # run() replaces this with a queue created inside the running loop:
        # on Python < 3.10 a Queue binds to the loop current at creation time.
        self.out_queue: asyncio.Queue = asyncio.Queue()
        self.request_map: Dict[Any, str] = {}  # Client request ID -> backend name (not used in this file)
        self.tool_routing: Dict[str, str] = {}  # Tool name -> backend (client) name
        # Internal request ID -> future resolved with the backend's response.
        self.pending_requests: Dict[str, asyncio.Future[Dict[str, Any]]] = {}
        self.request_timeout = request_timeout
        self.discovery_interval = discovery_interval
        # Strong references to background tasks: the event loop only keeps weak
        # ones, so an unreferenced task may be garbage-collected mid-run.
        self._background_tasks: set = set()

    def _spawn(self, coro) -> "asyncio.Task":
        """Start *coro* as a task and keep a reference to it until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def discover_tools(self):
        """Poll ToolService forever, starting an MCPClient per new live deployment.

        Only deployments with status DEPLOYED or RUNNING and a non-empty URL are
        used. Clients are keyed by deployment name and never removed.
        """
        while True:
            try:
                # NOTE(review): this ToolService call is synchronous and runs on
                # the event loop; if it performs network I/O it stalls SSE and
                # stdin handling meanwhile — consider run_in_executor once
                # ToolService is known to be thread-safe.
                deployments = self.tool_service.list_tool_deployments()
                logger.info(f"Found {len(deployments)} deployments")

                for d in deployments:
                    # Assumed "live" states; the ToolDeployment schema defaults to 'UNINITIALIZED'.
                    if d.status not in ("DEPLOYED", "RUNNING"):
                        continue
                    app_url = d.url
                    if not app_url:
                        continue
                    # Trailing slash so urljoin(app_url, "sse") appends rather than replaces.
                    if not app_url.endswith("/"):
                        app_url += "/"

                    name = d.name
                    if name not in self.clients:
                        logger.info(f"Discovered new tool: {name} at {app_url}")
                        client = MCPClient(name, app_url, self.session)
                        self.clients[name] = client
                        self._spawn(client.connect(self.out_queue))
            except Exception as e:
                logger.error(f"Discovery error: {e}")

            await asyncio.sleep(self.discovery_interval)

    async def handle_stdin(self):
        """Read newline-delimited JSON-RPC messages from stdin until EOF.

        Each message is processed in its own task so a slow backend call does
        not delay later requests; JSON-RPC matches replies by id, so they may
        be written out of order.
        """
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader(limit=self.STDIN_LINE_LIMIT)
        protocol = asyncio.StreamReaderProtocol(reader)
        await loop.connect_read_pipe(lambda: protocol, sys.stdin)

        while True:
            try:
                line = await reader.readline()
            except ValueError:
                # Line longer than STDIN_LINE_LIMIT; keep serving instead of
                # letting the exception end this task.
                logger.error("Line from stdin exceeds the size limit; dropped")
                continue
            if not line:
                break  # EOF: the client closed our stdin.

            try:
                request = json.loads(line)
            except json.JSONDecodeError:
                logger.error("Invalid JSON from stdin")
                continue
            if not isinstance(request, dict):
                # Batches (JSON arrays) and bare values are not supported.
                logger.error(f"Unsupported JSON-RPC message from stdin: {request!r}")
                continue

            self._spawn(self._process_safely(request))

    async def _process_safely(self, request: Dict[str, Any]):
        """Run process_request, turning an unexpected failure into a -32603 reply."""
        try:
            await self.process_request(request)
        except Exception as e:
            logger.exception(f"Failed to process {request.get('method')!r}")
            if "id" in request:
                self._write_error(request["id"], -32603, f"Internal error: {e}")

    async def send_request_to_client(self, client_name: str, method: str, params: Optional[Dict] = None,
                                     timeout: Optional[float] = None) -> Dict:
        """Send a JSON-RPC request to one backend and wait for its response.

        The request gets a fresh uuid4 id; the backend answers over its SSE
        stream and the bridge's output pump resolves the matching future.

        Args:
            client_name: Key of the backend in self.clients.
            method: JSON-RPC method name.
            params: Optional params object, omitted from the payload when None.
            timeout: Seconds to wait; defaults to self.request_timeout.

        Returns:
            The backend's complete JSON-RPC response ("result" or "error").

        Raises:
            Exception: if the backend is unknown or not ready, or the POST is
                rejected (aiohttp errors propagate as well).
            asyncio.TimeoutError: if no response arrives in time.
        """
        client = self.clients.get(client_name)
        if not client or not client.ready:
            raise Exception(f"Client {client_name} not available")

        req_id = str(uuid.uuid4())
        payload: Dict[str, Any] = {"jsonrpc": "2.0", "method": method, "id": req_id}
        if params is not None:
            payload["params"] = params

        future = asyncio.get_running_loop().create_future()
        self.pending_requests[req_id] = future
        wait = self.request_timeout if timeout is None else timeout
        try:
            await client.send_request(payload)
            try:
                return await asyncio.wait_for(future, timeout=wait)
            except asyncio.TimeoutError:
                # The default TimeoutError has an empty message, which would
                # otherwise reach the client as an empty error string.
                raise asyncio.TimeoutError(
                    f"No response from {client_name} to {method} within {wait}s") from None
        finally:
            # If the response arrived, the pump already removed the entry; on
            # POST failure, timeout or cancellation drop it here so it cannot leak.
            self.pending_requests.pop(req_id, None)

    async def process_request(self, request: Dict[str, Any]):
        """Handle one JSON-RPC message from the client, writing any reply to stdout.

        Requests for methods the bridge does not implement get a -32601 error;
        unknown messages without an "id" are notifications and get no reply.
        """
        method = request.get("method")
        req_id = request.get("id")

        if method == "initialize":
            self._write_result(req_id, {
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {"listChanged": True},
                    "resources": {"listChanged": True},
                    "prompts": {"listChanged": True}
                },
                "serverInfo": {"name": "KamiwazaBridge", "version": "1.0.0"}
            })
        elif method == "ping":
            self._write_result(req_id, {})
        elif method == "tools/list":
            await self._handle_tools_list(req_id)
        elif method == "tools/call":
            # "or {}" also covers an explicit "params": null.
            await self._handle_tools_call(req_id, request.get("params") or {})
        elif method in ("resources/list", "prompts/list"):
            # Advertised in initialize but not proxied to backends yet; answer
            # with an empty list rather than leaving the client waiting.
            self._write_result(req_id, {method.split("/")[0]: []})
        elif "id" in request:
            self._write_error(req_id, -32601, f"Method not found: {method}")
        # Otherwise it is a notification (e.g. notifications/initialized): no reply.

    async def _handle_tools_list(self, req_id: Any) -> None:
        """Aggregate tools/list from every ready backend and reply to the client.

        The routing table is rebuilt completely before being swapped in (no
        await in between), so a concurrent tools/call never sees it half-empty.
        Only the first page of each backend's result is read (nextCursor is
        not followed).
        """
        names = [name for name, client in self.clients.items() if client.ready]
        results: List[Any] = []
        if names:
            results = await asyncio.gather(
                *(self.send_request_to_client(name, "tools/list") for name in names),
                return_exceptions=True,
            )
        all_tools, routing = self._merge_tool_lists(names, results)
        self.tool_routing.clear()
        self.tool_routing.update(routing)
        self._write_result(req_id, {"tools": all_tools})

    @staticmethod
    def _merge_tool_lists(names: List[str], results: List[Any]) -> "tuple[List[Dict[str, Any]], Dict[str, str]]":
        """Combine per-backend tools/list responses into (tools, tool -> backend).

        On a name collision the first backend wins and the duplicate is left
        out, so the list shown to the client matches where calls are routed.
        Failed responses (exceptions or malformed results) are logged and skipped.
        """
        all_tools: List[Dict[str, Any]] = []
        routing: Dict[str, str] = {}
        for name, res in zip(names, results):
            result = res.get("result") if isinstance(res, dict) else None
            tools = result.get("tools") if isinstance(result, dict) else None
            if not isinstance(tools, list):
                logger.error(f"Failed to list tools from {name}: {res}")
                continue
            for tool in tools:
                t_name = tool.get("name") if isinstance(tool, dict) else None
                if not t_name:
                    logger.warning(f"Skipping malformed tool from {name}: {tool}")
                    continue
                if t_name in routing:
                    logger.warning(f"Tool collision: {t_name} in {name} and {routing[t_name]}")
                    continue
                routing[t_name] = name
                all_tools.append(tool)
        return all_tools, routing

    async def _handle_tools_call(self, req_id: Any, params: Dict[str, Any]) -> None:
        """Forward a tools/call to the backend that owns the tool and relay its reply."""
        tool_name = params.get("name")
        client_name = self.tool_routing.get(tool_name)
        if not client_name:
            # The MCP spec reports unknown tools as invalid params (-32602).
            self._write_error(req_id, -32602, f"Tool not found: {tool_name}")
            return

        try:
            res = await self.send_request_to_client(client_name, "tools/call", params)
        except Exception as e:
            self._write_error(req_id, -32000, str(e) or type(e).__name__)
            return
        # Re-address the backend's response to the client's original request id.
        res["id"] = req_id
        self._write(res)

    @staticmethod
    def _write(message: Dict[str, Any]) -> None:
        """Write one JSON-RPC message to stdout as a single line and flush."""
        print(json.dumps(message))
        sys.stdout.flush()

    def _write_result(self, req_id: Any, result: Any) -> None:
        """Write a successful JSON-RPC response."""
        self._write({"jsonrpc": "2.0", "id": req_id, "result": result})

    def _write_error(self, req_id: Any, code: int, message: str) -> None:
        """Write a JSON-RPC error response."""
        self._write({"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}})

    async def _pump_backend_messages(self):
        """Route backend messages: resolve pending requests, forward notifications."""
        while True:
            msg = await self.out_queue.get()
            if not isinstance(msg, dict):
                continue
            if "id" not in msg:
                # Notification from a backend: pass it straight to the client.
                self._write(msg)
                continue
            msg_id = msg["id"]
            # Outgoing ids are uuid4 strings; any other id cannot match (and
            # might not even be hashable).
            future = self.pending_requests.pop(msg_id, None) if isinstance(msg_id, str) else None
            if future is None:
                # A late (timed-out) response, or a backend-initiated request
                # that the bridge does not serve.
                logger.warning(f"Dropping unmatched backend message with id {msg_id!r}")
            elif not future.done():
                future.set_result(msg)

    async def run(self):
        """Serve until stdin reaches EOF, then cancel all background work."""
        self.pending_requests = {}  # Map internal_req_id -> Future
        self.out_queue = asyncio.Queue()

        async with aiohttp.ClientSession() as session:
            self.session = session
            self._spawn(self.discover_tools())
            self._spawn(self._pump_backend_messages())
            try:
                await self.handle_stdin()
            finally:
                # stdin closed: the client is gone, so stop discovery, SSE
                # connections, the output pump and in-flight requests.
                tasks = list(self._background_tasks)
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)


def _run_bridge() -> None:
    """Build the bridge and serve it on this process's stdin/stdout."""
    bridge = KamiwazaBridge()
    try:
        asyncio.run(bridge.run())
    except KeyboardInterrupt:
        pass  # Ctrl-C is an expected way to stop the bridge; exit quietly.


if __name__ == "__main__":
    _run_bridge()
