#!/usr/bin/python3

import argparse
import platform
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable


# Defaults, in seconds, for the --connect-timeout and --max-time command-line
# options (both are forwarded to curl).
DEFAULT_CONNECT_TIMEOUT = 5
DEFAULT_MAX_TIME = 10
# Default for --arch: the machine type reported by platform.machine(), or
# "x86_64" when that call returns an empty string.
DEFAULT_ARCH = platform.machine() or "x86_64"


@dataclass
class MirrorResult:
    """Outcome of one curl download attempt against a single mirror."""

    # Server value as read from the mirrorlist (placeholders not expanded).
    mirror_url: str
    # Expanded URL of the repository database that was requested.
    db_url: str
    # True when curl exited with 0 and reported a 2xx or 3xx HTTP status.
    success: bool
    # curl's %{speed_download} value; None for any unsuccessful attempt.
    speed_bytes_per_sec: float | None
    # curl's %{http_code} as text; None when curl failed or its output could
    # not be split into the three expected fields.
    http_code: str | None
    # curl's %{time_total} value; None when curl failed or the metrics could
    # not be parsed.
    total_time: float | None
    # Human-readable failure reason; None on success.
    error: str | None


def parse_args() -> argparse.Namespace:
    """Parse the process command line.

    Positional arguments are the mirrorlist path (converted to ``Path``) and
    the repository name. Optional arguments select the architecture, the two
    curl timeouts (parsed as floats) and whether TLS validation is skipped.

    Returns:
        The namespace produced by ``argparse`` with attributes ``mirrorlist``,
        ``repo``, ``arch``, ``connect_timeout``, ``max_time`` and ``insecure``.
    """
    summary = (
        "Test download speed for each Arch Linux package mirror listed in a "
        "mirrorlist by downloading the repo database with curl."
    )
    cli = argparse.ArgumentParser(description=summary)

    # Registration order is kept as-is because it determines the order of the
    # generated usage/help text.
    specs = (
        ("mirrorlist", {"type": Path, "help": "Path to the mirrorlist file"}),
        ("repo", {"help": "Repository name used to build the .db URL"}),
        (
            "--arch",
            {
                "default": DEFAULT_ARCH,
                "help": f"Architecture to substitute for $arch (default: {DEFAULT_ARCH})",
            },
        ),
        (
            "--connect-timeout",
            {
                "type": float,
                "default": DEFAULT_CONNECT_TIMEOUT,
                "help": f"Curl connect timeout in seconds (default: {DEFAULT_CONNECT_TIMEOUT})",
            },
        ),
        (
            "--max-time",
            {
                "type": float,
                "default": DEFAULT_MAX_TIME,
                "help": f"Maximum time per download in seconds (default: {DEFAULT_MAX_TIME})",
            },
        ),
        (
            "--insecure",
            {
                "action": "store_true",
                "help": "Pass -k to curl to ignore TLS certificate validation errors",
            },
        ),
    )
    for flag, settings in specs:
        cli.add_argument(flag, **settings)

    return cli.parse_args()


def extract_mirrors(mirrorlist_path: Path) -> set[str]:
    """Collect the distinct ``Server`` values found in a mirrorlist file.

    The file is read as UTF-8, line by line. A single leading ``#`` is
    removed before parsing, so commented-out ``#Server = ...`` entries are
    collected as well. Lines without ``=``, with a key other than ``Server``,
    or with an empty value are ignored.

    Args:
        mirrorlist_path: Path of the mirrorlist file to read.

    Returns:
        The set of server URL templates, with surrounding whitespace removed.
    """
    found: set[str] = set()

    with mirrorlist_path.open("r", encoding="utf-8") as stream:
        for entry in map(str.strip, stream):
            # Only one comment marker is stripped; "##Server" stays ignored
            # because its key then reads "#Server".
            if entry.startswith("#"):
                entry = entry[1:].strip()

            name, separator, url = entry.partition("=")
            url = url.strip()
            if separator and name.strip() == "Server" and url:
                found.add(url)

    return found


def build_db_url(server_url: str, repo: str, arch: str) -> str:
    """Expand a mirrorlist Server template into the URL of ``<repo>.db``.

    Args:
        server_url: Server value, possibly containing ``$repo``/``${repo}``
            and ``$arch``/``${arch}`` placeholders.
        repo: Repository name substituted for the repo placeholders and used
            as the database file name.
        arch: Architecture substituted for the arch placeholders.

    Returns:
        The expanded URL with trailing slashes removed, followed by
        ``/<repo>.db``.
    """
    # Substitutions run strictly in this order, one after the other.
    substitutions = (
        ("$repo", repo),
        ("${repo}", repo),
        ("$arch", arch),
        ("${arch}", arch),
    )
    expanded = server_url
    for placeholder, replacement in substitutions:
        expanded = expanded.replace(placeholder, replacement)

    return expanded.rstrip("/") + "/" + repo + ".db"


def measure_mirror(
    mirror_url: str,
    repo: str,
    arch: str,
    connect_timeout: float,
    max_time: float,
    insecure: bool,
) -> MirrorResult:
    """Download one mirror's repository database with curl and report metrics.

    Args:
        mirror_url: Server template from the mirrorlist.
        repo: Repository name used to expand the template.
        arch: Architecture used to expand the template.
        connect_timeout: Value passed to curl's ``--connect-timeout``.
        max_time: Value passed to curl's ``--max-time``.
        insecure: When true, ``--insecure`` is passed to curl.

    Returns:
        A ``MirrorResult``. The attempt counts as successful only when curl
        exits with status 0 and reports an HTTP code starting with 2 or 3.
        Every failure (curl could not be started, non-zero exit, unexpected
        or unparseable output, other HTTP status) is returned as a result
        with ``success=False`` and an ``error`` message rather than raised.
    """
    db_url = build_db_url(mirror_url, repo, arch)

    def failure(message: str, http_code: str | None = None) -> MirrorResult:
        # All failure paths share the same shape: no speed and no timing.
        return MirrorResult(
            mirror_url=mirror_url,
            db_url=db_url,
            success=False,
            speed_bytes_per_sec=None,
            http_code=http_code,
            total_time=None,
            error=message,
        )

    cmd = ["curl"]
    if insecure:
        cmd.append("--insecure")
    cmd += [
        "--silent",
        "--show-error",
        "--location",
        "--output",
        "/dev/null",
        "--connect-timeout",
        str(connect_timeout),
        "--max-time",
        str(max_time),
        "--write-out",
        # Raw string: the backslash-t sequences are passed to curl verbatim
        # and curl turns them into the tabs that are split on below.
        r"%{speed_download}\t%{http_code}\t%{time_total}",
        db_url,
    ]

    try:
        completed = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # Never let undecodable bytes in curl's output abort the run.
            errors="replace",
            check=False,
        )
    except OSError as exc:
        # Typically FileNotFoundError when curl is not installed; report it
        # as a failed measurement instead of crashing with a traceback.
        return failure(f"failed to run curl: {exc}")

    stdout = completed.stdout.strip()
    stderr = completed.stderr.strip() or None

    if completed.returncode != 0:
        return failure(stderr or f"curl exited with code {completed.returncode}")

    parts = stdout.split("\t")
    if len(parts) != 3:
        return failure(f"unexpected curl output: {stdout!r}")

    speed_text, http_code, time_text = parts

    try:
        speed = float(speed_text)
        total_time = float(time_text)
    except ValueError:
        return failure(f"failed to parse curl metrics: {stdout!r}", http_code)

    success = http_code.startswith(("2", "3"))
    return MirrorResult(
        mirror_url=mirror_url,
        db_url=db_url,
        success=success,
        # A speed measured for an error page says nothing about the mirror.
        speed_bytes_per_sec=speed if success else None,
        http_code=http_code,
        total_time=total_time,
        error=None if success else f"HTTP {http_code}",
    )


def format_speed(speed_bytes_per_sec: float | None) -> str:
    if speed_bytes_per_sec is None:
        return "-"
    mib_per_sec = speed_bytes_per_sec / (1024 * 1024)
    return f"{mib_per_sec:.2f} MiB/s"


def format_result_line(result: MirrorResult) -> str:
    """Format one result as a fixed-width table row.

    Columns are status (left-aligned, width 7), speed (right, 12), total time
    (right, 8), HTTP code (right, 5) and the database URL, separated by single
    spaces. A parenthesised error message is appended when one is present.

    Args:
        result: The measurement to render.

    Returns:
        The formatted row, without a trailing newline.
    """
    if result.success:
        label = "OK"
    else:
        label = "FAIL"

    elapsed = "-" if result.total_time is None else f"{result.total_time:.2f}s"
    code = result.http_code if result.http_code else "-"

    row = " ".join(
        (
            label.ljust(7),
            format_speed(result.speed_bytes_per_sec).rjust(12),
            elapsed.rjust(8),
            code.rjust(5),
            result.db_url,
        )
    )
    if result.error:
        row += f" ({result.error})"
    return row


def print_results(results: Iterable[MirrorResult]) -> None:
    """Print a header followed by one row per result, best mirrors first.

    Successful results come before failed ones; within each group rows are
    ordered by descending speed (a missing speed counts as 0) and then by
    mirror URL.

    Args:
        results: Measurements to print; consumed once.
    """

    def rank(entry: MirrorResult) -> tuple[bool, float, str]:
        # False sorts before True, so successes lead; the speed is negated to
        # get descending order from an ascending sort.
        return (
            not entry.success,
            -(entry.speed_bytes_per_sec or 0.0),
            entry.mirror_url,
        )

    # The input is fully consumed and sorted before anything is printed.
    ranked = list(results)
    ranked.sort(key=rank)

    header = f"{'status':<7} {'speed':>12} {'time':>8} {'http':>5} mirror"
    print(header, flush=True)
    for entry in ranked:
        print(format_result_line(entry))


def main() -> int:
    """Run the mirror speed test described by the command line.

    Each mirror found in the mirrorlist is probed in turn, with a progress
    line before and after every probe, followed by a sorted summary table and
    a success count.

    Returns:
        0 when at least one mirror succeeded, 1 when none did, and 2 when the
        mirrorlist is missing, cannot be read, or contains no Server entries.
    """
    args = parse_args()

    if not args.mirrorlist.is_file():
        print(f"mirrorlist file not found: {args.mirrorlist}", file=sys.stderr)
        return 2

    try:
        mirrors = extract_mirrors(args.mirrorlist)
    except (OSError, UnicodeDecodeError) as exc:
        # The file exists but is unreadable (permissions, I/O error) or is not
        # valid UTF-8; report it like the other mirrorlist problems.
        print(f"cannot read mirrorlist {args.mirrorlist}: {exc}", file=sys.stderr)
        return 2

    if not mirrors:
        print("no Server entries found in mirrorlist", file=sys.stderr)
        return 2

    results: list[MirrorResult] = []
    total = len(mirrors)

    # Set iteration order for strings changes between runs (hash
    # randomization); sorting keeps the probe order and the progress
    # numbering reproducible.
    for index, mirror in enumerate(sorted(mirrors), start=1):
        db_url = build_db_url(mirror, args.repo, args.arch)
        print(f"[{index}/{total}] Testing {db_url}", flush=True)
        result = measure_mirror(
            mirror_url=mirror,
            repo=args.repo,
            arch=args.arch,
            connect_timeout=args.connect_timeout,
            max_time=args.max_time,
            insecure=args.insecure,
        )
        results.append(result)
        print(f"[{index}/{total}] {format_result_line(result)}", flush=True)

    print_results(results)
    successful = sum(1 for result in results if result.success)
    print(f"\n{successful}/{len(results)} mirrors succeeded")
    return 0 if successful else 1


# Run only when executed as a script; main()'s return value becomes the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
