#!/usr/bin/env python3

import argparse
import hashlib
import logging
import os
import shutil
import sys
from typing import Dict, Optional

from tabulate import tabulate

# Quieten gRPC / glog / abseil start-up output. This runs before the kamiwaza
# imports below (hence their ``# noqa: E402``), presumably so the libraries pick
# these values up when they initialise.
os.environ.update({
    'GRPC_ENABLE_FORK_SUPPORT': '0',
    'GLOG_minloglevel': '2',
    'GRPC_VERBOSITY': 'ERROR',
    'GRPC_TRACE': 'none',
})
logging.getLogger('abseil').setLevel(logging.ERROR)

from kamiwaza.services.models.services import ModelService  # noqa: E402
from kamiwaza.services.models.config import settings  # noqa: E402
from kamiwaza.services.models.schemas.model import Model  # noqa: E402
from kamiwaza.services.models.models.model_file import DBModelFile  # noqa: E402
from kamiwaza.db.engine import DatabaseManager  # noqa: E402

def get_size_str(size_bytes: Optional[float]) -> str:
    """Convert a byte count to a human-readable string with two decimals.

    Units step by 1024 from B up to TB. Sizes of 1024 TB or more stay in TB
    (e.g. ``"2048.00 TB"``) instead of being divided a further time while still
    labelled TB. ``None`` (an unknown file size) is shown as ``"0.00 B"``.
    """
    size = size_bytes or 0
    units = ('B', 'KB', 'MB', 'GB', 'TB')
    # Only divide while there is a larger unit left to move into.
    for unit in units[:-1]:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} {units[-1]}"

def get_disk_space(path: str) -> Dict[str, int]:
    """Report total/used/free bytes for the filesystem holding ``path``.

    A path that does not exist yields all zeros. Any error while querying is
    printed, and all zeros are returned instead of raising.
    """
    empty = {"total": 0, "used": 0, "free": 0}
    try:
        if not os.path.exists(path):
            return empty
        usage = shutil.disk_usage(path)
        return {"total": usage.total, "used": usage.used, "free": usage.free}
    except Exception as err:
        print(f"Error getting disk space for {path}: {err}")
        return empty

def get_model_location(model: Model) -> Optional[str]:
    """Return the directory holding ``model``'s files, or None if it has no files.

    Only the first file is inspected, on the assumption that every file of a
    model shares one storage location.
    """
    files = model.m_files
    return files[0].storage_location if files else None

def get_model_size(model: Model) -> int:
    """Sum the byte sizes of ``model``'s files; files without a size count as 0."""
    total = 0
    for model_file in model.m_files:
        if model_file.size:
            total += model_file.size
    return total

def check_downloads_in_progress(model: Model) -> bool:
    """Return True if any of ``model``'s files reports ``is_downloading``.

    Files that lack an ``is_downloading`` attribute are treated as not downloading.
    """
    return any(getattr(model_file, 'is_downloading', False) for model_file in model.m_files)

def cleanup_empty_parents(path: str, locations=None):
    """Remove ``path`` and then each of its ancestors for as long as they are empty.

    The walk stops at the first directory that is not empty or cannot be
    removed, or on reaching a configured 'file' root. Configured roots are never
    removed, even when ``path`` is itself a root.

    Args:
        path: directory to start from; an empty string is a no-op.
        locations: (storage_type, path) pairs naming the protected roots;
            defaults to ``settings.models_file_locations``.
    """
    if not path:
        return
    if locations is None:
        locations = settings.models_file_locations
    # Normalise both sides so a trailing slash (or "a/./b") in the settings or in
    # the stored location cannot defeat the comparison and let the walk delete a
    # root and climb past it.
    roots = {
        os.path.normpath(base)
        for storage_type, base in locations
        if storage_type == 'file'
    }

    current = os.path.normpath(path)
    while current:
        # Check before touching the directory so a root itself is never removed.
        if current in roots:
            print(f"  - Reached configured root path: {current}")
            break
        if os.path.exists(current):
            try:
                if os.listdir(current):
                    print(f"  - Directory not empty, stopping cleanup at: {current}")
                    break
                os.rmdir(current)
                print(f"  - Removed empty directory: {current}")
            except OSError as e:
                print(f"  - Could not remove directory {current}: {e}")
                break
        parent = os.path.dirname(current)
        if parent == current:
            # Filesystem root: dirname() is a fixed point here, so stop.
            break
        current = parent

def list_models_and_space():
    """Print every model with its size and location, then disk usage per file root.

    Models come from ``ModelService().list_models(load_files=True)``. The roots are
    the ``'file'`` entries of ``settings.models_file_locations``. Any exception is
    printed and swallowed, so this function never raises.
    """
    try:
        all_models = ModelService().list_models(load_files=True)
        if not all_models:
            print("\nNo models found in the database.")
            return

        model_rows = []
        grand_total = 0
        for mdl in all_models:
            mdl_bytes = get_model_size(mdl)
            mdl_location = get_model_location(mdl)
            grand_total += mdl_bytes
            model_rows.append([
                mdl.name,
                str(mdl.id),  # render the UUID as plain text in the table
                get_size_str(mdl_bytes),
                mdl_location if mdl_location else "Unknown",
                "Yes" if check_downloads_in_progress(mdl) else "No",
            ])

        print("\nModels and their current locations:")
        model_headers = ["Name", "ID", "Size", "Location", "Downloads Active"]
        print(tabulate(model_rows, headers=model_headers, tablefmt="grid"))
        print(f"\nTotal space used by models: {get_size_str(grand_total)}")

        print("\nConfigured paths and available space:")
        root_rows = []
        for kind, root in settings.models_file_locations:
            if kind != 'file':
                continue
            usage = get_disk_space(root)
            root_rows.append([root] + [get_size_str(usage[key]) for key in ("total", "free", "used")])

        if not root_rows:
            print("\nNo file storage locations configured in settings.")
        else:
            root_headers = ["Path", "Total Space", "Free Space", "Used Space"]
            print(tabulate(root_rows, headers=root_headers, tablefmt="grid"))

    except Exception as err:
        print(f"Error listing models and space: {err}")

def validate_path(path: str, locations=None) -> bool:
    """Return True if ``path`` names one of the configured 'file' mount points.

    Both sides are compared after ``os.path.normpath``, so ``/mnt/tmp/`` and
    ``/mnt/tmp`` count as the same mount point.

    Args:
        path: candidate mount point supplied by the user; empty means "no".
        locations: (storage_type, path) pairs to check against; defaults to
            ``settings.models_file_locations``.
    """
    if not path:
        return False
    if locations is None:
        locations = settings.models_file_locations
    wanted = os.path.normpath(path)
    return any(
        storage_type == 'file' and os.path.normpath(base) == wanted
        for storage_type, base in locations
    )

def get_model_subpath(model: Model, locations=None) -> str:
    """Return the model's directory relative to the configured root it lives under.

    ``relocate_model`` joins the result onto the target mount point, so the model
    keeps the same layout after the move.

    Containment is checked per path component, so ``/mnt/models2/x`` is not
    treated as being under ``/mnt/models``; a plain prefix test would give
    ``../models2/x`` and escape the target mount. When several roots contain the
    location (nested roots), the most specific one is used.

    Args:
        model: model whose first file's ``storage_location`` is inspected.
        locations: (storage_type, path) pairs; defaults to
            ``settings.models_file_locations``.

    Returns:
        The relative sub-path, or ``repo_modelId`` with '/' mapped to ``os.sep``
        when the model has no location or it is not under any 'file' root.
    """
    if not model.m_files or not model.m_files[0].storage_location:
        return model.repo_modelId.replace('/', os.sep)

    if locations is None:
        locations = settings.models_file_locations

    current_path = os.path.normpath(model.m_files[0].storage_location)
    best_base = None
    for storage_type, base_path in locations:
        if storage_type != 'file':
            continue
        base = os.path.normpath(base_path)
        try:
            contained = os.path.commonpath([current_path, base]) == base
        except ValueError:
            # Mixed absolute/relative paths (or different drives) can't be related.
            contained = False
        if contained and (best_base is None or len(base) > len(best_base)):
            best_base = base

    if best_base is not None:
        return os.path.relpath(current_path, best_base)

    # Fallback to repo_modelId if we can't extract from current path
    return model.repo_modelId.replace('/', os.sep)

def calculate_md5(filepath, chunk_size: int = 1024 * 1024):
    """Return the hex MD5 digest of the file at ``filepath``.

    The digest only confirms that a copy is byte-identical to its source; it is
    not used for security. The file is streamed in ``chunk_size``-byte reads, so
    large model files are never fully loaded. The 1 MiB default keeps per-read
    Python overhead low compared with small reads.

    Raises:
        ValueError: if ``chunk_size`` is not positive. A zero-byte read would end
            the loop immediately and silently hash nothing.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    try:
        # Marking the hash as non-security keeps it usable on FIPS-restricted builds.
        md5_hash = hashlib.md5(usedforsecurity=False)  # nosec: not a security check
    except TypeError:
        # ``usedforsecurity`` only exists on Python 3.9+.
        md5_hash = hashlib.md5()  # nosec: not a security check
    with open(filepath, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()

def _verify_copy(source_path: str, new_path: str, file_name: str) -> None:
    """Raise if ``new_path`` is missing or differs from ``source_path`` in size or MD5."""
    if not os.path.exists(new_path):
        raise Exception(f"Copy of {file_name} not found at destination")
    if os.path.getsize(new_path) != os.path.getsize(source_path):
        raise Exception(f"Size mismatch for {file_name} after copy")
    if calculate_md5(source_path) != calculate_md5(new_path):
        raise Exception(f"Checksum mismatch for {file_name} after copy")


def _rollback_copies(copied_files) -> None:
    """Best-effort deletion of the destination half of each (source, destination) pair."""
    print("\nRolling back - removing copied files...")
    for _, new_path in copied_files:
        try:
            if os.path.exists(new_path):
                os.unlink(new_path)
                print(f"  - Removed {new_path}")
        except Exception as cleanup_error:
            print(f"Error removing {new_path} during rollback: {cleanup_error}")


def _remove_originals(copied_files) -> None:
    """Delete the source half of each (source, destination) pair, then prune emptied dirs."""
    print("\nRemoving original files...")
    source_dirs = set()  # unique parent directories of removed originals
    for source_path, _ in copied_files:
        try:
            os.unlink(source_path)
            print(f"  - Removed {source_path}")
            source_dirs.add(os.path.dirname(source_path))
        except Exception as remove_error:
            print(f"Warning: Could not remove {source_path}: {remove_error}")

    print("\nCleaning up empty directories...")
    # Deepest first, so a parent is examined only after its emptied children are gone.
    for source_dir in sorted(source_dirs, key=len, reverse=True):
        cleanup_empty_parents(source_dir)


def relocate_model(model_id: str, mount_point: str, dry_run: bool = False):
    """Relocate a model to a new mount point, preserving its subdirectory structure.

    The model's sub-directory (from ``get_model_subpath``) is recreated under
    ``mount_point``, and each file is copied and verified by size and MD5. Only
    after every file has been verified are the DB records repointed, and only
    after that commit are the originals deleted. If anything fails before the
    commit, the copies made so far are removed and the originals stay untouched.
    After the commit, the copies are never deleted, because the DB references them.

    Args:
        model_id: ID passed to ``ModelService.get_model``.
        mount_point: must be one of the configured 'file' locations (checked by
            ``validate_path``).
        dry_run: only report what would be moved.

    Returns:
        True on success (or a completed dry run), False otherwise. Errors are
        printed, not raised.
    """
    ms = ModelService()

    try:
        # Validate model
        model = ms.get_model(model_id, load_files=True)
        if model is None:
            # Guard in case get_model returns None for an unknown id.
            print(f"Error: Model {model_id} not found")
            return False
        if not model.m_files:
            print(f"Error: Model {model.name} has no files to relocate")
            return False

        # Validate mount point
        if not validate_path(mount_point):
            print(f"Error: Mount point {mount_point} is not in configured locations")
            print("Configured mount points are:")
            for storage_type, path in settings.models_file_locations:
                if storage_type == 'file':
                    print(f"  - {path}")
            return False

        # Check for downloads in progress
        if check_downloads_in_progress(model):
            print(f"Error: Model {model.name} has downloads in progress")
            return False

        # Calculate required space using the mount point
        total_size = get_model_size(model)
        mount_space = get_disk_space(mount_point)["free"]

        if total_size > mount_space:
            print(f"Error: Insufficient space at mount point {mount_point}. Need {get_size_str(total_size)}, " +
                  f"have {get_size_str(mount_space)}")
            return False

        # Determine the full target path
        model_subpath = get_model_subpath(model)
        target_path = os.path.join(mount_point, model_subpath)

        if dry_run:
            print(f"\nDry run: Would move {model.name} ({get_size_str(total_size)}) to {target_path}")
            print("Files that would be moved:")
            for model_file in model.m_files:
                # size may be unset; get_model_size counts that as 0 as well
                print(f"  - {model_file.name} ({get_size_str(model_file.size or 0)})")
            return True

        print(f"\nMoving {model.name} to {target_path}")
        os.makedirs(target_path, exist_ok=True)

        copied_files = []   # (source, destination) pairs for which a copy was written
        relocated_ids = []  # DB ids whose data is now verifiably under target_path
        committed = False
        try:
            for model_file in model.m_files:
                if not model_file.storage_location:
                    # Never copied, so its DB record must not be repointed either.
                    print(f"Warning: File {model_file.name} has no storage location")
                    continue

                source_path = os.path.join(model_file.storage_location, model_file.name)
                new_path = os.path.join(target_path, model_file.name)

                if not os.path.exists(source_path):
                    raise Exception(f"Source file {source_path} not found")

                if os.path.exists(new_path) and os.path.samefile(source_path, new_path):
                    # Already at the destination. Copying onto itself would fail, and
                    # it must not be queued for deletion as an "original".
                    print(f"  - {model_file.name} is already at {target_path}")
                    relocated_ids.append(model_file.id)
                    continue

                print(f"Copying {model_file.name}...")
                # In case a file name contains a sub-directory component.
                os.makedirs(os.path.dirname(new_path), exist_ok=True)
                shutil.copy2(source_path, new_path)
                # Record before verifying so a failed verification is still rolled back.
                copied_files.append((source_path, new_path))
                _verify_copy(source_path, new_path, model_file.name)
                relocated_ids.append(model_file.id)
                print(f"  - Verified copy of {model_file.name}")

            if not relocated_ids:
                print(f"Error: None of the files of {model.name} could be relocated")
                return False

            # The session is opened only now, so it is not held across long copies.
            print("\nAll files copied and verified. Updating database...")
            with DatabaseManager.get_session('main')() as db:
                for file_id in relocated_ids:
                    db_file = db.query(DBModelFile).filter(DBModelFile.id == file_id).first()
                    if db_file:
                        db_file.storage_location = target_path
                db.commit()
                committed = True
            print("Database updated successfully")

        except Exception as e:
            print(f"Error during relocation: {e}")
            if committed:
                # The DB now points at the copies; deleting them would lose data.
                print("Database already points to the new location; keeping the copied files.")
            else:
                _rollback_copies(copied_files)
            return False

        # Outside the rollback scope: a failure here must never delete the copies.
        _remove_originals(copied_files)
        print(f"Successfully relocated {model.name} to {target_path}")
        return True

    except Exception as e:
        print(f"Error: {e}")
        return False

def main():
    """Command-line entry point: list models or relocate one model.

    Exit status:
      * 0 when help is shown, a listing runs, or a relocation succeeds.
      * 1 when a relocation fails.
      * 2 (argparse usage error) when only one of ``--model-id`` / ``--target``
        is given.
    """
    parser = argparse.ArgumentParser(
        description="Relocate model files between configured paths",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("--list", action="store_true",
                        help="List all models and their locations")
    parser.add_argument("--model-id",
                        help="ID of the model to relocate")
    parser.add_argument("--target",
                        help="Target mount point (e.g., /mnt/tmp) - subdirectories will be created automatically")
    parser.add_argument("--dry-run", action="store_true",
                        help="Show what would be done without making changes")

    if len(sys.argv) == 1:
        parser.print_help()
        return

    args = parser.parse_args()

    if args.list:
        list_models_and_space()
        return

    if bool(args.model_id) != bool(args.target):
        # Prints usage plus the message to stderr and exits with status 2.
        parser.error("Both --model-id and --target must be provided for relocation")

    if not args.model_id:
        # e.g. only --dry-run was given: show usage instead of silently doing nothing.
        parser.print_help()
        return

    if not relocate_model(args.model_id, args.target, args.dry_run):
        sys.exit(1)

if __name__ == "__main__":
    main()
