#!/usr/bin/env python3
"""Generate self-signed TLS material for the retrieval gRPC service."""
from __future__ import annotations

import argparse
import ipaddress
import os
import socket
from datetime import datetime, timedelta, timezone
from pathlib import Path

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID


def _build_subject(common_name: str) -> x509.Name:
    """Return an X.509 distinguished name holding only a CN of *common_name*."""
    cn_attribute = x509.NameAttribute(NameOID.COMMON_NAME, common_name)
    return x509.Name([cn_attribute])


def _collect_subject_alt_names(
    explicit_hosts: list[str] | None = None,
) -> list[x509.GeneralName]:
    """Collect the subjectAltName entries for the generated certificate.

    The result always contains ``localhost``, ``kamiwaza-retrieval``,
    ``127.0.0.1`` and ``::1``. On top of those it considers, in order: any
    *explicit_hosts*, the ``RETRIEVAL_GRPC_HOST`` environment variable (when
    set and non-empty), the local hostname, and the local FQDN (when it differs
    from the hostname). Each candidate is stripped of surrounding whitespace;
    blank candidates are skipped, candidates that parse as an IP address become
    ``x509.IPAddress`` entries, and everything else becomes ``x509.DNSName``.

    Args:
        explicit_hosts: Extra host names or IP literals to include. The list is
            only read, never modified.

    Returns:
        De-duplicated general names in a deterministic order: DNS names first,
        then IP addresses (sorted by their exploded text form), then anything
        else.
    """
    names: set[x509.GeneralName] = {
        x509.DNSName("localhost"),
        x509.DNSName("kamiwaza-retrieval"),
        x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
        x509.IPAddress(ipaddress.IPv6Address("::1")),
    }

    # Copy rather than alias: appending to ``explicit_hosts or []`` would
    # mutate the caller's list whenever a non-empty list was passed in.
    candidates: list[str] = list(explicit_hosts or [])
    env_host = os.getenv("RETRIEVAL_GRPC_HOST")
    if env_host:
        candidates.append(env_host)
    hostname = socket.gethostname()
    if hostname:
        candidates.append(hostname)
    fqdn = socket.getfqdn()
    if fqdn and fqdn != hostname:
        candidates.append(fqdn)

    for entry in candidates:
        candidate = entry.strip()
        if not candidate:
            continue
        try:
            ip = ipaddress.ip_address(candidate)
        except ValueError:
            # Not an IP literal, so treat it as a DNS name.
            names.add(x509.DNSName(candidate))
        else:
            names.add(x509.IPAddress(ip))

    def _general_name_sort_key(value: x509.GeneralName) -> tuple[int, str]:
        """Provide a stable sort key across DNS/IP entries."""
        if isinstance(value, x509.DNSName):
            return (0, value.value)
        if isinstance(value, x509.IPAddress):
            ip_value = value.value
            if isinstance(ip_value, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
                return (1, ip_value.exploded)
            return (1, str(ip_value))
        return (2, str(value))

    # Sets have no stable iteration order, so sort for reproducible output.
    return sorted(names, key=_general_name_sort_key)


def _write_private_file(path: Path, data: bytes) -> None:
    """Write *data* to *path* so that only the owning user may read/write it (0o600)."""
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    with os.fdopen(os.open(path, flags, 0o600), "wb") as handle:
        # The ``mode`` passed to os.open only applies when the file is newly
        # created; tighten a pre-existing file before writing secret bytes.
        if hasattr(os, "fchmod"):
            os.fchmod(handle.fileno(), 0o600)
        handle.write(data)


def generate_self_signed_certificate(
    cert_path: Path, key_path: Path, ca_path: Path | None = None
) -> None:
    """Create a self-signed RSA certificate/key pair and write it as PEM.

    The certificate uses the common name ``Kamiwaza Retrieval gRPC`` and carries
    the subjectAltNames from ``_collect_subject_alt_names()``. It is valid from
    five minutes before now (to tolerate small clock skew) until 825 days from
    now. It is marked as a CA and allowed for both server and client auth.

    Args:
        cert_path: Destination of the PEM certificate. Parent dirs are created.
        key_path: Destination of the unencrypted PEM (PKCS#1 /
            "TraditionalOpenSSL") RSA-4096 private key. The file is written
            with mode 0o600 where the platform supports it.
        ca_path: Optional destination for a CA bundle. When given, it receives
            the exact same bytes as *cert_path*, because the certificate is its
            own trust anchor.

    Existing files at any of the paths are overwritten.
    """
    cert_path.parent.mkdir(parents=True, exist_ok=True)
    key_path.parent.mkdir(parents=True, exist_ok=True)
    if ca_path is not None:
        ca_path.parent.mkdir(parents=True, exist_ok=True)

    private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096)
    public_key = private_key.public_key()
    subject = issuer = _build_subject("Kamiwaza Retrieval gRPC")
    san_values = _collect_subject_alt_names()
    # One timestamp so both validity bounds are computed from the same instant.
    now = datetime.now(timezone.utc)

    builder = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(public_key)
        .serial_number(x509.random_serial_number())
        .not_valid_before(now - timedelta(minutes=5))
        .not_valid_after(now + timedelta(days=825))
        .add_extension(x509.SubjectAlternativeName(san_values), critical=False)
        # The certificate doubles as the CA bundle (see ca_path), so it must be
        # a CA able to sign certificates (key_cert_sign below).
        .add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
        # RFC 5280 requires a Subject Key Identifier on CA certificates; strict
        # verifiers (e.g. OpenSSL's X509_STRICT mode) reject CAs without one.
        .add_extension(
            x509.SubjectKeyIdentifier.from_public_key(public_key),
            critical=False,
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=True,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False,
            ),
            critical=True,
        )
        .add_extension(
            x509.ExtendedKeyUsage(
                [
                    ExtendedKeyUsageOID.SERVER_AUTH,
                    ExtendedKeyUsageOID.CLIENT_AUTH,
                ]
            ),
            critical=False,
        )
    )

    certificate = builder.sign(private_key=private_key, algorithm=hashes.SHA256())
    cert_bytes = certificate.public_bytes(serialization.Encoding.PEM)
    key_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )

    cert_path.write_bytes(cert_bytes)
    # The key is unencrypted, so never leave it readable by other users.
    _write_private_file(key_path, key_bytes)
    if ca_path is not None:
        ca_path.write_bytes(cert_bytes)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the TLS generator.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is how the
            script is normally invoked.

    Returns:
        A namespace with ``cert_path`` and ``key_path`` (required strings),
        ``ca_path`` (string or ``None``) and ``force`` (bool).

    Raises:
        SystemExit: On invalid or missing arguments, or for ``--help``
            (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser(
        description="Generate self-signed TLS assets for retrieval gRPC."
    )
    parser.add_argument(
        "--cert-path",
        required=True,
        help="Destination path for the PEM-encoded certificate.",
    )
    parser.add_argument(
        "--key-path",
        required=True,
        help="Destination path for the PEM-encoded private key.",
    )
    parser.add_argument(
        "--ca-path",
        required=False,
        help=(
            "Optional destination path for a CA bundle; when given, it receives "
            "a copy of the self-signed certificate."
        ),
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Overwrite existing files instead of exiting when certificate material already exists.",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Entry point: create TLS material unless it already exists.

    Paths are ``~``-expanded and resolved to absolute paths. Without
    ``--force``, an existing certificate/key pair is left untouched. If a
    ``--ca-path`` was requested but does not exist yet, it is filled in with a
    copy of the existing certificate. Otherwise (``--force``, or the cert or key
    is missing) fresh material is generated, overwriting any existing files.
    """
    args = parse_args()
    cert_path = Path(args.cert_path).expanduser().resolve()
    key_path = Path(args.key_path).expanduser().resolve()
    ca_path = Path(args.ca_path).expanduser().resolve() if args.ca_path else None

    if not args.force and cert_path.exists() and key_path.exists():
        # The CA bundle this script writes is a byte copy of the self-signed
        # certificate. A requested-but-missing bundle can therefore be restored
        # from the existing cert without rotating the key pair.
        if ca_path is not None and not ca_path.exists():
            ca_path.parent.mkdir(parents=True, exist_ok=True)
            ca_path.write_bytes(cert_path.read_bytes())
        return

    generate_self_signed_certificate(cert_path, key_path, ca_path)


if __name__ == "__main__":
    main()
