#!/usr/bin/env python3
"""Generate a local CA + server/SP certificates for Traefik and SAML."""
from __future__ import annotations

import argparse
import ipaddress
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from urllib.parse import urlparse

# Directory two levels above this file's resolved path (presumably the
# repository root, since the script lives one directory below it).
ROOT = Path(__file__).resolve().parents[1]
# Where generated secrets live; all TLS material goes in the "tls" subfolder.
SECRETS_DIR = ROOT.joinpath("runtime", "secrets")
TLS_DIR = SECRETS_DIR.joinpath("tls")

# Local certificate authority (private key + self-signed certificate).
CA_KEY = TLS_DIR.joinpath("kz-local-ca.key")
CA_CRT = TLS_DIR.joinpath("kz-local-ca.crt")
# Leaf certificate served by Traefik.
SERVER_KEY = TLS_DIR.joinpath("kz-local-server.key")
SERVER_CRT = TLS_DIR.joinpath("kz-local-server.crt")
# Leaf certificate for the SAML service provider.
SAML_KEY = TLS_DIR.joinpath("saml-sp.key")
SAML_CRT = TLS_DIR.joinpath("saml-sp.crt")

# One "certs" directory per Traefik deployment architecture (amd64, arm64).
TRAEFIK_CERT_TARGETS = [
    ROOT.joinpath("deployment", "kamiwaza-traefik", "amd64", "certs"),
    ROOT.joinpath("deployment", "kamiwaza-traefik", "arm64", "certs"),
]


def log(msg: str) -> None:
    """Print *msg* to stdout, prefixed with the script tag ``[generate_local_ca]``."""
    tag = "[generate_local_ca]"
    print(tag, msg)


def require_tool(tool: str) -> None:
    """Exit the process with status 1 unless *tool* resolves via ``shutil.which``.

    Logs an error message naming the missing tool before exiting.
    """
    if shutil.which(tool) is not None:
        return
    log(f"ERROR: {tool} not found on PATH. Install it and re-run this script.")
    sys.exit(1)


def run(cmd: list[str]) -> None:
    """Execute *cmd* as an argv list (no shell).

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    subprocess.run(args=cmd, check=True)


def set_perms(path: Path, mode: int = 0o640) -> None:
    """chmod *path* to *mode* if it exists; a missing path is silently ignored.

    The default 0o640 grants owner read/write, group read, and nothing to others.
    """
    if not path.exists():
        return
    path.chmod(mode)


def write_san_config(common_name: str, hosts: list[str], usages: str) -> Path:
    """Write a temporary OpenSSL config for a leaf certificate and return its path.

    The file contains a ``[req]`` section (non-interactive, subject
    ``CN = common_name``) for ``openssl req`` and a ``[v3_req]`` extension
    section (keyUsage, extendedKeyUsage, subjectAltName) for use as
    ``openssl x509 -extfile``.

    Args:
        common_name: Subject common name.
        hosts: Entries for the subjectAltName extension. Duplicates are dropped
            and entries are emitted in sorted order. Anything that parses as an
            IPv4 or IPv6 address becomes an ``IP.n`` entry; everything else
            becomes a ``DNS.n`` entry.
        usages: Raw value for ``extendedKeyUsage`` (e.g. ``"serverAuth"``).

    Returns:
        Path to the ``.cnf`` file. It is created with ``delete=False``, so the
        caller is responsible for removing it.
    """
    alt_lines: list[str] = []
    dns_idx = 1
    ip_idx = 1
    for host in sorted(set(hosts)):
        # ipaddress recognizes IPv6 literals such as "::1" and rejects dotted
        # digit strings that are not addresses (e.g. "1.2"); a "strip dots and
        # isdigit()" check gets both of those wrong.
        try:
            ipaddress.ip_address(host)
        except ValueError:
            alt_lines.append(f"DNS.{dns_idx} = {host}")
            dns_idx += 1
        else:
            alt_lines.append(f"IP.{ip_idx} = {host}")
            ip_idx += 1
    # Text mode already maps "\n" to the platform newline; joining with
    # os.linesep would produce "\r\r\n" on Windows.
    alt_names = "\n".join(alt_lines)
    template = f"""
[req]
distinguished_name = req_distinguished_name
req_extensions = v3_req
prompt = no

[req_distinguished_name]
CN = {common_name}

[v3_req]
keyUsage = digitalSignature, keyEncipherment
extendedKeyUsage = {usages}
subjectAltName = @alt_names

[alt_names]
{alt_names}
"""
    # The context manager guarantees the handle is closed even if write() fails.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".cnf", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write(template)
    return Path(tmp.name)


def parse_external_host() -> str:
    """Return the host part of ``KAMIWAZA_EXTERNAL_URL`` (default ``"localhost"``).

    Accepts either a full URL (``https://kamiwaza.example.com:8443/app``) or
    a bare ``host`` / ``host:port`` value without a scheme. The result is the
    lowercased host name as reported by ``urlparse(...).hostname``; IPv6
    literals are returned without brackets. Falls back to ``"localhost"``
    when the variable is blank or contains no host part.
    """
    raw = os.environ.get("KAMIWAZA_EXTERNAL_URL", "https://localhost").strip()
    # Without "//", urlparse treats "host" as a path (and "host:port" as a
    # scheme), so .hostname would be None and we would silently fall back to
    # "localhost". A leading "//" makes the value parse as a network location.
    if raw and "//" not in raw:
        raw = f"//{raw}"
    return urlparse(raw).hostname or "localhost"


def ensure_ca(force: bool) -> None:
    """Create the local CA key and self-signed certificate if needed.

    Nothing happens when both CA_KEY and CA_CRT already exist and *force* is
    false. Otherwise TLS_DIR is created, a 4096-bit RSA key is generated, and
    a self-signed SHA-256 certificate (subject "/CN=Kamiwaza Local CA",
    3650 days) is written; both files are then passed to set_perms.

    Raises:
        subprocess.CalledProcessError: if an openssl command fails (via run).
    """
    already_present = CA_KEY.exists() and CA_CRT.exists()
    if already_present and not force:
        return

    TLS_DIR.mkdir(parents=True, exist_ok=True)
    log("Generating local CA")
    run(["openssl", "genrsa", "-out", str(CA_KEY), "4096"])

    self_signed = ["openssl", "req", "-x509", "-new", "-nodes"]
    self_signed += ["-key", str(CA_KEY), "-sha256", "-days", "3650"]
    self_signed += ["-subj", "/CN=Kamiwaza Local CA", "-out", str(CA_CRT)]
    run(self_signed)

    for created in (CA_KEY, CA_CRT):
        set_perms(created)


def issue_cert(
    common_name: str, hosts: list[str], key_path: Path, crt_path: Path, usages: str
) -> None:
    """Generate a key and a CA-signed leaf certificate for *common_name*.

    Steps: write a temporary SAN config (write_san_config), create a 4096-bit
    RSA key at *key_path*, build a CSR next to it (``.csr`` suffix), and sign
    it with CA_KEY/CA_CRT for 825 days using the config's ``v3_req``
    extensions (``-CAcreateserial`` lets openssl create the CA serial file on
    first use). The temporary config and CSR are always removed. The key and
    certificate are passed to set_perms only if every command succeeded.

    Raises:
        subprocess.CalledProcessError: if an openssl command fails (via run).
    """
    config_path = write_san_config(common_name, hosts, usages)
    request_path = key_path.with_suffix(".csr")

    keygen = ["openssl", "genrsa", "-out", str(key_path), "4096"]
    make_csr = [
        "openssl", "req", "-new",
        "-key", str(key_path),
        "-out", str(request_path),
        "-config", str(config_path),
    ]
    sign = [
        "openssl", "x509", "-req",
        "-in", str(request_path),
        "-CA", str(CA_CRT),
        "-CAkey", str(CA_KEY),
        "-CAcreateserial",
        "-out", str(crt_path),
        "-days", "825",
        "-sha256",
        "-extensions", "v3_req",
        "-extfile", str(config_path),
    ]

    try:
        for command in (keygen, make_csr, sign):
            run(command)
    finally:
        # Temporary artifacts never outlive this call, success or failure.
        os.unlink(config_path)
        if request_path.exists():
            request_path.unlink()

    for produced in (key_path, crt_path):
        set_perms(produced)


def copy_to_traefik(server_cert: Path, server_key: Path, ca_cert: Path) -> None:
    """Copy the server cert/key and CA cert into every Traefik certs directory.

    Each directory in TRAEFIK_CERT_TARGETS is created if necessary and gets
    ``domain.crt``, ``domain.key`` and ``client-ca.crt``. The copies are made
    with shutil.copy2, so permission bits and timestamps are carried over.
    """
    placements = (
        (server_cert, "domain.crt"),
        (server_key, "domain.key"),
        (ca_cert, "client-ca.crt"),
    )
    for cert_dir in TRAEFIK_CERT_TARGETS:
        cert_dir.mkdir(parents=True, exist_ok=True)
        for source, dest_name in placements:
            shutil.copy2(source, cert_dir / dest_name)


def _needs_issue(crt_path: Path, key_path: Path, *, force: bool) -> bool:
    """Return True if the leaf certificate at *crt_path* should be (re)issued.

    True when *force* is set, when the certificate or its private key is
    missing, or when the certificate is older than CA_CRT. An older leaf was
    most likely signed by a previous CA (e.g. the CA files were deleted and
    regenerated) and would no longer chain to the current one.
    """
    if force or not crt_path.exists() or not key_path.exists():
        return True
    return crt_path.stat().st_mtime < CA_CRT.stat().st_mtime


def main() -> None:
    """Generate or refresh the local CA, leaf certificates and Traefik copies.

    Steps:
    - verify ``openssl`` is on PATH;
    - make sure the CA exists (regenerated with ``--force``);
    - issue the Traefik/server and SAML SP certificates when ``_needs_issue``
      says so;
    - copy the server cert/key and CA cert into every Traefik certs directory.

    A failing ``openssl`` invocation is logged and turned into exit status 1
    instead of a Python traceback.
    """
    parser = argparse.ArgumentParser(description="Generate local CA + TLS assets")
    parser.add_argument(
        "--force",
        action="store_true",
        help="regenerate certificates even if they exist",
    )
    args = parser.parse_args()

    require_tool("openssl")
    external_host = parse_external_host()
    # SANs for the Traefik certificate: the external host plus loopback and
    # the container/host names other services use to reach Traefik.
    default_hosts = [
        external_host,
        "localhost",
        "127.0.0.1",
        "host.docker.internal",
        "kamiwaza-traefik",
        "kamiwaza-auth",
    ]

    try:
        ensure_ca(args.force)

        if _needs_issue(SERVER_CRT, SERVER_KEY, force=args.force):
            log("Issuing Traefik/server certificate")
            issue_cert(
                external_host,
                default_hosts,
                SERVER_KEY,
                SERVER_CRT,
                "serverAuth, clientAuth",
            )

        if _needs_issue(SAML_CRT, SAML_KEY, force=args.force):
            log("Issuing SAML SP certificate")
            issue_cert(
                "Kamiwaza SAML SP", [external_host], SAML_KEY, SAML_CRT, "serverAuth"
            )
    except subprocess.CalledProcessError as exc:
        log(
            f"ERROR: command failed with exit code {exc.returncode}: "
            f"{' '.join(exc.cmd)}"
        )
        sys.exit(1)

    copy_to_traefik(SERVER_CRT, SERVER_KEY, CA_CRT)
    log("Local CA + certificates ready.")


if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()
