import asyncio
import os
import ray
import signal
import uuid
from datetime import datetime, timezone
from sqlalchemy.orm import sessionmaker
from kamiwaza.scheduler.config import settings  # Adjust import path as needed
from kamiwaza.services.models.services import ModelService
from kamiwaza.cluster.services import ClusterService
from kamiwaza.cluster.models.cluster import DBMeta  # Adjust import path as needed
from kamiwaza.util.netutil import node_hostip
from kamiwaza.db.handle import get_db_handle
from kamiwaza.util.config import RuntimeConfig
import logging
from logging.handlers import RotatingFileHandler
from kamiwaza.lib.logging.utils import get_log_directory

# Resolve the log directory and make sure it exists before any handler opens a file in it.
log_dir = get_log_directory()
os.makedirs(log_dir, exist_ok=True)

# Root logger: DEBUG level with the default basicConfig handler.
logging.basicConfig(level=logging.DEBUG)

# Line format shared by everything written to the scheduler log file.
formatter = logging.Formatter(
    fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Size-based rotation: 10MB per file, 5 backups kept next to the active file.
file_handler = RotatingFileHandler(
    os.path.join(log_dir, "kamiwaza-scheduler.log"),
    maxBytes=10 * 1024 * 1024,
    backupCount=5,
    encoding='utf-8',
)
file_handler.setFormatter(formatter)

# Module logger writes to the rotating file only: with propagate disabled its
# records are not passed on to the root logger's handlers, so nothing is
# emitted twice.
logger = logging.getLogger(__name__)
logger.propagate = False
logger.addHandler(file_handler)

# Plain-text file that _append_signal_log() appends signal bookkeeping lines to.
SCHEDULED_TASK_LOG = os.path.join(log_dir, "kamiwaza_scheduled_task.log")


def _append_signal_log(message: str) -> None:
    """
    Append ``message`` plus a trailing newline to ``SCHEDULED_TASK_LOG``.

    This is best-effort bookkeeping: any exception raised while opening or
    writing the file is swallowed so that signal handling is never disrupted
    by a logging failure.

    Args:
        message (str): The line to append (without trailing newline).
    """
    try:
        with open(SCHEDULED_TASK_LOG, mode="a", encoding="utf-8") as log_file:
            log_file.write(message + "\n")
    except Exception:
        # Deliberately ignored; see docstring.
        return


# Prefix of the RuntimeConfig keys drained by process_pending_post_downloads();
# the remainder of the key is either a model UUID or the word "global".
POST_DOWNLOAD_KEY_PREFIX = "post_download_required_"

# --- Shared state for rate-limited signal handling ---
# Set by the signal path to cut the scheduler's inter-cycle sleep short.
sleep_interrupt_event = asyncio.Event()
# UTC epoch seconds of the most recent signal-triggered wake-up (0 = none yet).
last_signal_execution_time = 0
# Task that will fire the batched wake-up, or None when none has been scheduled.
pending_signal_task = None

async def process_pending_post_downloads():
    """
    Handle every queued post-download request found in RuntimeConfig.

    Requests are config keys starting with ``POST_DOWNLOAD_KEY_PREFIX``; the
    rest of the key is a model UUID, or empty / ``"global"`` (any case) for a
    request that is not tied to one model. Keys are processed in sorted order.

    For each request ``ModelService.post_download`` is run in the default
    executor and, on success, the key is deleted. A key whose suffix is not a
    valid UUID is logged and deleted without being processed. If processing
    raises, the error is logged and the key is left in place.
    """
    runtime_config = RuntimeConfig()
    try:
        entries = runtime_config.get_config_dict()
    except Exception as read_error:
        logger.error(
            f"Failed to read RuntimeConfig for post-download requests: {read_error}",
            exc_info=True,
        )
        return

    request_keys = sorted(
        filter(lambda name: name.startswith(POST_DOWNLOAD_KEY_PREFIX), entries.keys())
    )
    if len(request_keys) == 0:
        return

    service = ModelService()
    event_loop = asyncio.get_running_loop()

    for key in request_keys:
        suffix = key.replace(POST_DOWNLOAD_KEY_PREFIX, "", 1)
        targets_specific_model = bool(suffix) and suffix.lower() != "global"

        model_id = None
        if targets_specific_model:
            try:
                model_id = uuid.UUID(suffix)
            except ValueError:
                # Malformed key: drop it so it is not retried on every cycle.
                logger.warning(f"Ignoring invalid post-download key '{key}'")
                try:
                    runtime_config.delete_config(key)
                except Exception as delete_error:
                    logger.error(f"Failed to delete invalid post-download key '{key}': {delete_error}")
                continue

        logger.info(f"Processing post-download request for model {model_id or 'global'}")
        try:
            # post_download is a plain callable; run it off the event loop thread.
            await event_loop.run_in_executor(None, service.post_download, model_id)
            runtime_config.delete_config(key)
        except Exception as processing_error:
            logger.error(
                f"Error during post-download processing for {model_id}: {processing_error}",
                exc_info=True,
            )

async def signal_handler(signum):
    """
    Rate-limited reaction to a wake-up signal.

    At most one wake-up of the scheduler loop is triggered per 3-second
    window. A signal arriving 3 or more seconds after the previous wake-up
    sets ``sleep_interrupt_event`` right away. A signal arriving sooner
    schedules ``delayed_signal_execution`` for the end of the window, unless
    such a task is already pending, in which case the signal is only logged.

    Every decision is recorded through ``_append_signal_log``.

    Args:
        signum (int): The signal number received.
    """
    global last_signal_execution_time, pending_signal_task

    received_at = datetime.now(timezone.utc).timestamp()
    elapsed = received_at - last_signal_execution_time

    stamp = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S %Z')
    _append_signal_log(
        f"{stamp} - Received signal {signum} - time since last execution: {elapsed:.2f}s"
    )

    # Outside the rate-limit window: wake the scheduler immediately.
    if elapsed >= 3.0:
        last_signal_execution_time = received_at
        _append_signal_log(f"{stamp} - Executing signal immediately (3s+ elapsed)")
        sleep_interrupt_event.set()
        return

    # Inside the window with a wake-up already scheduled: nothing more to do.
    wakeup_already_scheduled = (
        pending_signal_task is not None and not pending_signal_task.done()
    )
    if wakeup_already_scheduled:
        _append_signal_log(f"{stamp} - Signal batched with existing pending execution")
        return

    # Inside the window with nothing scheduled: defer to the window boundary.
    delay = 3.0 - elapsed
    _append_signal_log(f"{stamp} - Batching signal, will execute in {delay:.2f}s")
    pending_signal_task = asyncio.create_task(delayed_signal_execution(delay))

async def delayed_signal_execution(delay):
    """
    Fire a batched wake-up once the rate-limit window has elapsed.

    Sleeps for ``delay`` seconds, records the current UTC time as the last
    signal execution time, logs the event and sets ``sleep_interrupt_event``.

    Args:
        delay (float): Seconds to wait before waking the scheduler loop.
    """
    global last_signal_execution_time

    await asyncio.sleep(delay)

    # Stamp the wake-up first so later signals measure their window from now.
    last_signal_execution_time = datetime.now(timezone.utc).timestamp()

    stamp = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S %Z')
    _append_signal_log("{} - Executing batched signal after delay".format(stamp))

    sleep_interrupt_event.set()

async def store_periodic_pid():
    """
    Record this process's identity in the DBMeta table.

    One row per key is updated in place, or inserted when missing:

    - periodic_pid: current process ID (as a string)
    - periodic_host: Ray's node IP when Ray is initialized and the lookup
      succeeds, otherwise the value of ``node_hostip()``
    - periodic_ray: "true" / "false" depending on whether Ray is initialized
    - periodic_ray_node_id: Ray node ID, or "" when Ray is not initialized
    - periodic_ray_job_id: Ray job ID, or "" when Ray is not initialized

    Raises:
        Exception: Any error raised while collecting the values or writing
            them is logged and re-raised. The session is always closed.
    """
    engine = get_db_handle(settings.database_url)
    Session = sessionmaker(bind=engine)
    session = Session()

    try:
        pid = os.getpid()
        host = node_hostip()  # Fallback hostname

        # Sample Ray's state exactly once so that every Ray-derived value below
        # is based on the same answer and the runtime context is only
        # dereferenced when it was actually fetched.
        ray_initialized = ray.is_initialized()
        ray_node_id = ""
        ray_job_id = ""

        if ray_initialized:
            # Prefer the IP address Ray is using if available. This ensures that
            # the host value stored in DBMeta matches the node information
            # reported by ``ray.nodes()`` which is required for
            # ``run_on_named_node`` to work.
            try:
                host = ray.util.get_node_ip_address()
            except Exception as e:
                logger.debug(
                    f"store_periodic_pid: Failed to get node IP from Ray; using node_hostip(): {e}"
                )

            ray_context = ray.runtime_context.get_runtime_context()
            ray_node_id = ray_context.get_node_id()
            ray_job_id = ray_context.get_job_id()

        meta_keys = {
            "periodic_pid": str(pid),
            "periodic_host": host,
            "periodic_ray": "true" if ray_initialized else "false",
            "periodic_ray_node_id": ray_node_id,
            "periodic_ray_job_id": ray_job_id,
        }

        # Upsert each key: update the existing row or add a new one.
        for key, value in meta_keys.items():
            meta_entry = session.query(DBMeta).filter(DBMeta.key == key).first()
            if meta_entry:
                meta_entry.value = value
            else:
                session.add(DBMeta(key=key, value=value))

        session.commit()
        logger.info(f"Stored periodic process info: {meta_keys}")

    except Exception as e:
        logger.error(f"Failed to store periodic process info in database: {e}")
        raise
    finally:
        session.close()

async def periodic_tasks():
    """
    Run the scheduler's work cycle forever.

    On entry the current process info is recorded via ``store_periodic_pid()``.
    Then, in an endless loop, whenever at least 2.9 seconds have passed since
    the previous work pass finished, the following steps run in order:

    1. Process model downloads
    2. Update cluster nodes (rate limited to 15s)
    3. Import and refresh the model guide (first time 5 minutes after this
       coroutine started, afterwards every 12 hours)
    4. Deploy models that were waiting for a download to finish
    5. Process pending post-download requests
    6. Run deployment health checks (rate limited to 30s)

    Each step has its own ``try``/``except``: a failure is logged and the
    remaining steps still run. The rate-limit timestamps of steps 2, 3 and 6
    only advance when the step did not raise.

    After each pass the coroutine sleeps for whatever is left of
    ``settings.cycle_time``; the sleep ends early when
    ``sleep_interrupt_event`` is set.
    """
    await store_periodic_pid()

    last_run_time = 0
    last_node_update_time = 0
    last_health_check_time = 0  # Track last health check
    # Track child processes that need cleanup.
    # NOTE: nothing in this function adds to this set, so the cleanup loop
    # below currently has nothing to do.
    child_processes = set()

    # Track model guide refresh timing
    startup_time = datetime.now(timezone.utc).timestamp()
    last_model_guide_refresh = 0
    model_guide_refresh_interval = 12 * 60 * 60  # 12 hours in seconds
    model_guide_startup_delay = 5 * 60  # 5 minutes in seconds

    while True:
        current_time = datetime.now(timezone.utc).timestamp()

        if current_time - last_run_time >= 2.9:
            # Clean up any finished child processes
            finished_processes = set()
            for process in child_processes:
                if not process.is_alive():
                    process.join(timeout=5)  # Give each process 5 seconds to clean up
                    finished_processes.add(process)
            child_processes -= finished_processes

            logger.debug("Starting model downloads processing cycle")
            model_service = ModelService()
            try:
                await model_service.process_downloads()
                logger.debug("Successfully completed model downloads processing")
            except Exception as e:
                logger.error(f"Error processing model downloads: {e}", exc_info=True)

            # Check if it's time to update cluster nodes (rate limited to 15s)
            if current_time - last_node_update_time >= 15:
                logger.debug("Starting cluster node update")
                cluster_service = ClusterService()
                try:
                    await cluster_service.periodic_node_update()
                    logger.debug("Successfully completed cluster node update")
                    last_node_update_time = current_time
                except Exception as e:
                    logger.error(f"Error in cluster node update: {e}", exc_info=True)

            # Check if it's time to refresh model guide
            time_since_startup = current_time - startup_time
            time_since_last_refresh = current_time - last_model_guide_refresh

            should_refresh_guide = False
            if last_model_guide_refresh == 0:
                # First refresh after startup delay
                if time_since_startup >= model_guide_startup_delay:
                    should_refresh_guide = True
                    logger.info("Performing initial model guide refresh after startup delay")
            else:
                # Subsequent refreshes every 12 hours
                if time_since_last_refresh >= model_guide_refresh_interval:
                    should_refresh_guide = True
                    logger.info("Performing scheduled model guide refresh")

            if should_refresh_guide:
                logger.debug("Starting model guide refresh")
                model_service = ModelService()
                try:
                    # First try to import from local file if it exists
                    import_result = await model_service.import_model_guide()
                    if "error" not in import_result:
                        logger.info(f"Model guide import successful: {import_result}")

                    # Then try to refresh from external endpoint
                    refresh_result = await model_service.refresh_model_guide()
                    if "error" not in refresh_result:
                        logger.info(f"Model guide refresh successful: {refresh_result}")
                    else:
                        logger.warning(f"Model guide refresh failed: {refresh_result.get('error')}")

                    # Reached whenever neither call raised, even if the refresh
                    # reported an "error" entry in its result.
                    last_model_guide_refresh = current_time
                except Exception as e:
                    logger.error(f"Error refreshing model guide: {e}", exc_info=True)

            # Check for models waiting to be deployed after download
            try:
                logger.info("[DEPLOYMENT_DEBUG] Scheduler checking for waiting deployments...")
                model_service = ModelService()
                deployment_result = await model_service.deploy_all_waiting_downloaded()
                logger.info(f"[DEPLOYMENT_DEBUG] Scheduler deployment check result: {deployment_result}")
                if deployment_result.get("checked", 0) > 0:
                    logger.info(f"Deployment check: {deployment_result.get('message')}")
                else:
                    logger.debug(f"[DEPLOYMENT_DEBUG] No deployments to check: {deployment_result.get('message')}")
            except Exception as e:
                logger.error(f"[DEPLOYMENT_DEBUG] Error checking waiting deployments: {e}", exc_info=True)

            # React to any post-download signals from worker processes
            try:
                await process_pending_post_downloads()
            except Exception as e:
                logger.error(f"Error processing post-download requests: {e}", exc_info=True)

            # Check deployment health status every 30 seconds
            if current_time - last_health_check_time >= 30:
                try:
                    logger.debug("Starting deployment health checks...")
                    from kamiwaza.serving.services import ServingService
                    serving_service = ServingService()
                    health_results = await serving_service.health_check()

                    if health_results:
                        ready_count = len([h for h in health_results if h.get('status') == 'READY'])
                        initializing_count = len([h for h in health_results if h.get('status') == 'INITIALIZING'])
                        error_count = len([h for h in health_results if h.get('status') == 'ERROR'])

                        logger.info(f"Health check completed: {ready_count} ready, {initializing_count} initializing, {error_count} errors")
                    else:
                        logger.debug("No deployments found for health checking")

                    last_health_check_time = current_time
                except Exception as e:
                    logger.error(f"Error during deployment health checks: {e}", exc_info=True)

            last_run_time = datetime.now(timezone.utc).timestamp()

        # Sleep for the rest of the cycle, counting the time the pass itself took.
        elapsed_time = datetime.now(timezone.utc).timestamp() - current_time
        sleep_time = max(settings.cycle_time - elapsed_time, 0)

        sleep_task = asyncio.create_task(asyncio.sleep(sleep_time))
        interrupt_task = asyncio.create_task(sleep_interrupt_event.wait())
        waiters = [sleep_task, interrupt_task]

        try:
            await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # asyncio.wait() does not cancel the tasks that are still pending
            # when it returns. Cancel the loser explicitly; otherwise every
            # cycle that ends by timeout would leave one more pending
            # Event.wait() task behind until the next signal arrives.
            for waiter in waiters:
                if not waiter.done():
                    waiter.cancel()
            sleep_interrupt_event.clear()  # Reset the event for the next cycle

def handle_child_process(signum, frame): # noqa
    """
    SIGCHLD handler: reap every already-terminated child so none stays a zombie.

    Calls ``os.waitpid(-1, os.WNOHANG)`` repeatedly. The loop ends when it
    reports pid 0 (children exist but none has exited) or raises
    ``ChildProcessError`` (no children at all).

    Args:
        signum: Signal number (unused).
        frame: Current stack frame (unused).
    """
    reaped_pid = -1
    try:
        while reaped_pid != 0:
            # WNOHANG keeps the call non-blocking.
            reaped_pid, _exit_status = os.waitpid(-1, os.WNOHANG)
    except ChildProcessError:
        return

async def main():
    """
    Install the signal handlers and keep the periodic task loop running.

    - SIGUSR1 is registered on the running event loop; each delivery schedules
      ``signal_handler`` as a task. This is the same in Ray and local mode;
      only the way the signal reaches the process differs.
    - SIGCHLD is routed to ``handle_child_process`` so exited children are reaped.
    - ``periodic_tasks()`` is awaited in an endless loop; if it raises, the
      error is logged and it is started again after a 5 second pause.
    """
    event_loop = asyncio.get_running_loop()

    def _schedule_sigusr1_handling():
        # Loop signal callbacks must be plain callables, so hand the real work
        # to a task.
        return asyncio.create_task(signal_handler(signal.SIGUSR1))

    event_loop.add_signal_handler(signal.SIGUSR1, _schedule_sigusr1_handling)

    # Handle child processes cleanup in both modes
    signal.signal(signal.SIGCHLD, handle_child_process)

    while True:
        try:
            await periodic_tasks()
        except Exception as failure:
            logger.error(f"Exception in periodic tasks, restarting: {failure}", exc_info=True)
            # Pause before restarting to avoid a tight error loop.
            await asyncio.sleep(5)

if __name__ == "__main__":
    # Entrypoint for the scheduler runner.
    #
    # Steps:
    #   1. Build the Ray runtime_env from this process's environment.
    #   2. Connect to Ray (log_to_driver disabled, re-init tolerated).
    #   3. Try to become the "watchdog_leader"; the winner starts the lock watchdog.
    #   4. Run main() until it fails or the process is stopped, then clean up.
    #
    # NOTE(review): the environment handling is meant to mirror main.py /
    # start_ray according to the original note here; that code is not visible
    # from this file, so confirm when changing either side.
    import traceback

    # Gather environment variables for Ray runtime_env
    env_vars = {"KAMIWAZA_ENV_INIT": "yes"}
    runtime_env = {"env_vars": env_vars}

    # Include KAMIWAZA_ROOT and KAMIWAZA_LIB_ROOT when set to a non-empty value
    for name in ("KAMIWAZA_ROOT", "KAMIWAZA_LIB_ROOT"):
        configured = os.getenv(name)
        if configured:
            env_vars[name] = configured

    # Forward PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION, defaulting to "python"
    # when the variable is not set in this process.
    env_vars["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = os.getenv(
        "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python"
    )

    # Forward these verbatim whenever they are present
    for name in ("PYTHONPATH", "HF_TOKEN", "KAMIWAZA_COMMUNITY", "PYTHONDONTWRITEBYTECODE"):
        if name in os.environ:
            env_vars[name] = os.environ[name]

    # Sweep the environment for everything else that should follow us into Ray:
    # any name containing "KAMIWAZA", any HF_* or FORWARDAUTH_* variable, and
    # PYTHONDONTWRITEBYTECODE. Values chosen above are never overwritten.
    for key, value in os.environ.items():
        if key in env_vars:
            continue
        if (
            "KAMIWAZA" in key
            or key.startswith(("HF_", "FORWARDAUTH_"))
            or key == "PYTHONDONTWRITEBYTECODE"
        ):
            env_vars[key] = value

    # Set KAMIWAZA_RUNNING to true, both for Ray workers and for this process
    env_vars["KAMIWAZA_RUNNING"] = "true"
    os.environ["KAMIWAZA_RUNNING"] = "true"

    # Get Ray init address from env
    ray_init_address = os.getenv("KAMIWAZA_RAY_INIT_ADDRESS")

    # Force log_to_driver=False to avoid stealing from main process
    # Always set ignore_reinit_error=True for graceful reinitialization
    ray_params = {
        "runtime_env": runtime_env,
        "address": ray_init_address if ray_init_address else "auto",
        "log_to_driver": False,
        "ignore_reinit_error": True,
    }

    # Initialize Ray with error handling and stack trace logging
    try:
        ray.init(**ray_params)
    except Exception:
        # Log the connection settings but only the NAMES of the forwarded
        # environment variables: their values include secrets such as HF_TOKEN.
        loggable_params = {k: v for k, v in ray_params.items() if k != "runtime_env"}
        loggable_params["runtime_env_vars"] = sorted(env_vars)
        logger.error(
            "Failed to initialize Ray with params: %s", loggable_params, exc_info=True
        )
        traceback.print_exc()
        raise

    from kamiwaza.util.locks import DistributedLock
    from kamiwaza.util.locks import start_lock_watchdog

    # Pick the locks you want to watch
    WATCHED_LOCKS = ["health_check"]  # add others as you introduce them
    WATCHDOG_TTL = 30                 # should match the lock TTL for those keys
    CHECK_EVERY = 30                  # how often to sweep for stale/corrupt values

    rc = RuntimeConfig()

    # Try to be the leader. If someone else is, we noop.
    _watchdog_leader = DistributedLock(rc, "watchdog_leader", ttl_seconds=90)
    _watchdog = None
    _is_watchdog_leader = bool(_watchdog_leader.acquire(timeout=0.2, max_retry=0))
    if _is_watchdog_leader:
        _watchdog = start_lock_watchdog(
            rc,
            lock_names=WATCHED_LOCKS,
            ttl_seconds=WATCHDOG_TTL,
            check_interval=CHECK_EVERY,
            stale_factor=10.0,
        )
        if not _watchdog:
            logger.error("Failed to start lock watchdog")
    # else: another scheduler instance is already the leader; that's fine.

    try:
        asyncio.run(main())
    except Exception:
        # exc_info puts the traceback into the scheduler log file; print_exc
        # keeps it on stderr as well.
        logger.error("Exception in scheduler main loop:", exc_info=True)
        traceback.print_exc()
        raise
    finally:
        if _watchdog:
            _watchdog.stop()
        # Only release the leader lock if this process actually acquired it, so
        # a non-leader never touches a lock held by another instance.
        if _is_watchdog_leader:
            try:
                _watchdog_leader.release()
            except Exception:
                # Best effort during shutdown: record it, never mask the real exit reason.
                logger.debug("Failed to release watchdog leader lock", exc_info=True)
