#!/usr/bin/env python3
"""
DMR User Database Loader  —  dmr_users
=======================================
Downloads the RadioID.net DMR user database, loads US and Canadian
registered users into the  dmr_users  table, then cross-references
against  uls_cur  to populate license details.

Source:
  RadioID.net  —  https://radioid.net/static/users.json

dmr_users schema:
  dmr_id             INTEGER   PRIMARY KEY  (RadioID / CCS7 ID)
  callsign           TEXT
  first_name         TEXT
  last_name          TEXT
  city               TEXT
  state              TEXT
  country_name       TEXT      Full country name as provided by RadioID
  lic_country        CHAR(2)   'US' or 'CA' if matched in uls_cur
  lic_operator_class TEXT      From uls_cur
  lic_first_name     TEXT      From uls_cur
  lic_last_name      TEXT      From uls_cur
  lic_street_address TEXT      From uls_cur
  lic_city           TEXT      From uls_cur
  lic_state          TEXT      From uls_cur
  lic_zip_code       TEXT      From uls_cur
  lic_expired_date   DATE      From uls_cur (US only)
  lic_matched        BOOLEAN   TRUE if call sign found in uls_cur

Usage:
    python3 dmr_loader.py [options]

Options:
    --host          PostgreSQL host      (default: localhost)
    --port          PostgreSQL port      (default: 5432)
    --dbname        Database name        (default: ham_radio)
    --user          Database user        (default: postgres)
    --password      Database password    (default: env HAM_DB_PASSWORD)
    --schema        Target schema        (default: public)
    --skip-download Use existing JSON file instead of downloading
    --dmr-json      Local path for DMR JSON  (default: ./users.json)
    --batch-size    Rows per upsert batch    (default: 2000)
  
Requirements:
    pip install psycopg2-binary requests

Notes:
  - uls_cur must be populated before running this script.
  - Only US and Canadian DMR registrations are loaded; other countries
    have no corresponding data in uls_cur and are skipped entirely.
  - All text fields are stored in upper-case.
  - RadioID uses full country names; the COUNTRY_MAP dict maps these to
    the 2-letter codes used in uls_cur.
  - Re-running this script fully refreshes dmr_users and re-runs the
    cross-reference, picking up any changes in both datasets.

2021-03-21 KB1B nedecn@kb1b.org
"""

import argparse
import json
import logging
import os
import sys
import time
from contextlib import contextmanager
from pathlib import Path

import psycopg2
import psycopg2.extras

try:
    import requests
    HAS_REQUESTS = True
except ImportError:
    HAS_REQUESTS = False

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Root logger: timestamped lines with column-aligned level names.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger used by every function in this script.
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# URLs / defaults
# ---------------------------------------------------------------------------
DMR_URL          = "https://radioid.net/static/users.json"  # full RadioID export
DEFAULT_DMR_JSON = "./users.json"                           # local download target

# ---------------------------------------------------------------------------
# RadioID full country name → uls_cur 2-letter country code
# Only US and CA have license data in uls_cur; all others will be unmatched.
# ---------------------------------------------------------------------------
# NOTE: keys must be upper-case — lookups use values normalized by _up().
COUNTRY_MAP = {
    "UNITED STATES":          "US",
    "UNITED STATES OF AMERICA": "US",
    "USA":                    "US",
    "US":                     "US",
    "CANADA":                 "CA",
    "CA":                     "CA",
}

# ---------------------------------------------------------------------------
# DDL
# ---------------------------------------------------------------------------

# Table definition for dmr_users.  IF NOT EXISTS makes repeated runs safe;
# column meanings are documented in the module docstring above.
DMR_DDL = """
CREATE TABLE IF NOT EXISTS {schema}.dmr_users (
    dmr_id             INTEGER  PRIMARY KEY,
    callsign           TEXT,
    first_name         TEXT,
    last_name          TEXT,
    city               TEXT,
    state              TEXT,
    country_name       TEXT,
    lic_country        CHAR(2),
    lic_operator_class TEXT,
    lic_first_name     TEXT,
    lic_last_name      TEXT,
    lic_street_address TEXT,
    lic_city           TEXT,
    lic_state          TEXT,
    lic_zip_code       TEXT,
    lic_expired_date   DATE,
    lic_matched        BOOLEAN  DEFAULT FALSE
);
"""

# Secondary indexes for the common lookup patterns (call sign, country,
# match status, last name).  Each entry is two adjacent string literals
# joined by Python implicit concatenation into one CREATE INDEX statement.
DMR_INDEXES = [
    "CREATE INDEX IF NOT EXISTS idx_dmr_callsign"
    "  ON {schema}.dmr_users (callsign);",
    "CREATE INDEX IF NOT EXISTS idx_dmr_country_name"
    "  ON {schema}.dmr_users (country_name);",
    "CREATE INDEX IF NOT EXISTS idx_dmr_lic_country"
    "  ON {schema}.dmr_users (lic_country);",
    "CREATE INDEX IF NOT EXISTS idx_dmr_lic_matched"
    "  ON {schema}.dmr_users (lic_matched);",
    "CREATE INDEX IF NOT EXISTS idx_dmr_last_name"
    "  ON {schema}.dmr_users (last_name);",
]

# Migrate columns added after initial release
# (ADD COLUMN IF NOT EXISTS is a no-op when the column already exists,
# so these run unconditionally on every start.)
DMR_MIGRATE = [
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_country        CHAR(2);",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_operator_class TEXT;",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_first_name     TEXT;",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_last_name      TEXT;",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_street_address TEXT;",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_city           TEXT;",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_state          TEXT;",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_zip_code       TEXT;",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_expired_date   DATE;",
    "ALTER TABLE {schema}.dmr_users ADD COLUMN IF NOT EXISTS"
    "  lic_matched        BOOLEAN DEFAULT FALSE;",
]

# Upsert keyed on dmr_id: inserts new registrations and refreshes every
# field of existing ones, so re-runs pick up RadioID and uls_cur changes.
UPSERT_SQL = """
INSERT INTO {schema}.dmr_users (
    dmr_id, callsign, first_name, last_name,
    city, state, country_name,
    lic_country, lic_operator_class,
    lic_first_name, lic_last_name,
    lic_street_address, lic_city, lic_state, lic_zip_code,
    lic_expired_date, lic_matched
)
VALUES (
    %s, %s, %s, %s,
    %s, %s, %s,
    %s, %s,
    %s, %s,
    %s, %s, %s, %s,
    %s, %s
)
ON CONFLICT (dmr_id) DO UPDATE SET
    callsign           = EXCLUDED.callsign,
    first_name         = EXCLUDED.first_name,
    last_name          = EXCLUDED.last_name,
    city               = EXCLUDED.city,
    state              = EXCLUDED.state,
    country_name       = EXCLUDED.country_name,
    lic_country        = EXCLUDED.lic_country,
    lic_operator_class = EXCLUDED.lic_operator_class,
    lic_first_name     = EXCLUDED.lic_first_name,
    lic_last_name      = EXCLUDED.lic_last_name,
    lic_street_address = EXCLUDED.lic_street_address,
    lic_city           = EXCLUDED.lic_city,
    lic_state          = EXCLUDED.lic_state,
    lic_zip_code       = EXCLUDED.lic_zip_code,
    lic_expired_date   = EXCLUDED.lic_expired_date,
    lic_matched        = EXCLUDED.lic_matched;
"""

# ---------------------------------------------------------------------------
# Schema setup
# ---------------------------------------------------------------------------

def ensure_table(conn, schema: str) -> None:
    """
    Make sure {schema}.dmr_users exists with the current column set.

    Runs the CREATE TABLE DDL, the ADD COLUMN migrations for columns
    introduced after the initial release, and the index definitions, then
    commits.  Every statement uses IF NOT EXISTS, so repeated calls are
    idempotent.
    """
    statements = (
        [f"CREATE SCHEMA IF NOT EXISTS {schema};",
         DMR_DDL.format(schema=schema)]
        + [s.format(schema=schema) for s in DMR_MIGRATE]
        + [s.format(schema=schema) for s in DMR_INDEXES]
    )
    with conn.cursor() as cur:
        for sql in statements:
            cur.execute(sql)
    conn.commit()
    log.info("dmr_users table and indexes verified/created.")


# ---------------------------------------------------------------------------
# HTTP download
# ---------------------------------------------------------------------------

def download_file(url: str, local_path: str, label: str) -> None:
    """Download *url* to *local_path*, preferring requests over urllib."""
    log.info("Downloading %s from %s", label, url)
    fetch = _dl_requests if HAS_REQUESTS else _dl_urllib
    fetch(url, local_path, label)


def _dl_requests(url: str, local_path: str, label: str) -> None:
    """Streaming download via the requests library, with progress output."""
    import requests as req
    t_start  = time.time()
    received = 0
    with req.get(url, stream=True, timeout=120) as resp:
        resp.raise_for_status()
        # content-length may be absent or "0"; either way report size unknown
        size = int(resp.headers.get("content-length", 0)) or None
        with open(local_path, "wb") as out:
            for block in resp.iter_content(chunk_size=65536):
                if not block:
                    continue
                out.write(block)
                received += len(block)
                _progress(label, received, size, t_start)
    print()
    log.info("%s: %.1f MB in %.1f s → %s",
             label, received / 1_048_576, time.time() - t_start, local_path)


def _dl_urllib(url: str, local_path: str, label: str) -> None:
    """Fallback streaming download using only the standard library."""
    import urllib.request
    log.info("(requests not installed; using urllib)")
    t_start  = time.time()
    received = 0
    with urllib.request.urlopen(url, timeout=120) as resp:
        header = resp.getheader("Content-Length")
        size   = int(header) if header else None
        with open(local_path, "wb") as out:
            # read fixed-size chunks until EOF
            while block := resp.read(65536):
                out.write(block)
                received += len(block)
                _progress(label, received, size, t_start)
    print()
    log.info("%s: %.1f MB in %.1f s → %s",
             label, received / 1_048_576, time.time() - t_start, local_path)


def _progress(label: str, done: int, total: int | None, start: float) -> None:
    mb    = done / 1_048_576
    speed = done / max(time.time() - start, 0.001) / 1_048_576
    if total:
        pct = done / total * 100
        tmb = total / 1_048_576
        print(f"\r  {label:<4}  {mb:7.1f} / {tmb:.1f} MB"
              f"  ({pct:5.1f}%)  {speed:.2f} MB/s   ",
              end="", flush=True)
    else:
        print(f"\r  {label:<4}  {mb:7.1f} MB  {speed:.2f} MB/s   ",
              end="", flush=True)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _up(value) -> str | None:
    """Return string value upper-cased and stripped, or None if empty."""
    if not value:
        return None
    s = str(value).strip()
    return s.upper() if s else None


def _load_uls_cur(conn, schema: str) -> dict:
    """
    Build an in-memory call-sign lookup from uls_cur.

    Only US and CA rows are fetched — those are the only countries the
    DMR load can match against.  Each value is the tuple
    (country, operator_class, first_name, last_name, street_address,
     city, state, zip_code, expired_date).
    """
    log.info("  Loading uls_cur into memory for cross-reference ...")
    started = time.time()
    query = f"""
            SELECT call_sign, country, operator_class,
                   first_name, last_name, street_address,
                   city, state, zip_code, expired_date
            FROM   {schema}.uls_cur
            WHERE  country IN ('US', 'CA');
        """
    with conn.cursor() as cur:
        cur.execute(query)
        records = cur.fetchall()

    lookup = {rec[0]: rec[1:] for rec in records}
    log.info("  %d uls_cur records loaded in %.1f s.",
             len(lookup), time.time() - started)
    return lookup


# ---------------------------------------------------------------------------
# Main loader
# ---------------------------------------------------------------------------

def _read_users_json(json_path: str) -> list:
    """Read the RadioID export and return the list of user records.

    RadioID has served the export both as a bare JSON list and wrapped in
    an object ({"users": [...]} or {"results": [...]}); accept all three
    shapes.
    """
    with open(json_path, encoding="utf-8", errors="replace") as fh:
        raw = json.load(fh)
    if isinstance(raw, dict):
        return raw.get("users", raw.get("results", []))
    return raw


def _parse_radioid_user(user: dict):
    """Normalize one RadioID record for loading.

    Returns (dmr_id, callsign, first_name, last_name, city, state,
    country_raw, lic_country), or None when the record should be skipped:
    missing/zero/non-numeric DMR ID, or a country other than US/CA
    (uls_cur has no data for other countries).
    """
    raw_id = user.get("id") or user.get("dmr_id")
    if not raw_id:                      # also rejects a literal 0 ID
        return None
    try:
        dmr_id = int(raw_id)
    except (ValueError, TypeError):
        return None

    callsign    = _up(user.get("callsign", ""))
    first_name  = _up(user.get("fname",    ""))
    # RadioID has used both "name" and "surname" for the last name.
    last_name   = _up(user.get("name", "") or user.get("surname", ""))
    city        = _up(user.get("city",     ""))
    state       = _up(user.get("state",    ""))
    country_raw = _up(user.get("country",  ""))

    lic_country = COUNTRY_MAP.get(country_raw) if country_raw else None
    if lic_country not in ("US", "CA"):
        return None

    return (dmr_id, callsign, first_name, last_name,
            city, state, country_raw, lic_country)


def load_dmr(conn, schema: str, json_path: str, batch_size: int) -> None:
    """
    Parse the RadioID users.json, load only US and Canadian users,
    cross-reference against uls_cur, and upsert into dmr_users.

    Parameters:
        conn       — open psycopg2 connection (committed at the end)
        schema     — target schema containing dmr_users and uls_cur
        json_path  — path to the downloaded RadioID users.json
        batch_size — rows per execute_batch page
    """
    log.info("=== Loading DMR user database ===")

    # ── Read JSON ──────────────────────────────────────────────────────────
    log.info("  Parsing %s ...", json_path)
    t0 = time.time()
    users = _read_users_json(json_path)
    log.info("  %d total DMR registrations found (%.1f s).",
             len(users), time.time() - t0)

    # ── Load uls_cur for cross-reference ───────────────────────────────────
    uls = _load_uls_cur(conn, schema)

    # ── Process and upsert ─────────────────────────────────────────────────
    sql   = UPSERT_SQL.format(schema=schema)
    batch = []

    total      = 0
    matched_us = 0
    matched_ca = 0
    unmatched  = 0
    skipped    = 0   # missing/invalid ID or non-US/CA country

    t0 = time.time()

    with conn.cursor() as cur:
        for user in users:
            parsed = _parse_radioid_user(user)
            if parsed is None:
                skipped += 1
                continue
            (dmr_id, callsign, first_name, last_name,
             city, state, country_raw, lic_country) = parsed

            # --- Cross-reference against uls_cur ---
            lic_data = uls.get(callsign) if callsign else None

            if lic_data:
                (lic_ctry, lic_op_class, lic_fname, lic_lname,
                 lic_addr, lic_city, lic_state, lic_zip,
                 lic_exp) = lic_data
                lic_matched = True
                if lic_ctry == "US":
                    matched_us += 1
                else:
                    matched_ca += 1
            else:
                lic_ctry = lic_country   # best guess from RadioID country
                lic_op_class = lic_fname = lic_lname = None
                lic_addr = lic_city = lic_state = lic_zip = None
                lic_exp  = None
                lic_matched = False
                unmatched += 1

            batch.append((
                dmr_id, callsign, first_name, last_name,
                city, state, country_raw,
                lic_ctry, lic_op_class, lic_fname, lic_lname,
                lic_addr, lic_city, lic_state, lic_zip,
                lic_exp, lic_matched,
            ))
            total += 1

            # Flush a full batch and update the in-place progress line.
            if len(batch) >= batch_size:
                psycopg2.extras.execute_batch(cur, sql, batch,
                                              page_size=batch_size)
                batch = []
                elapsed = time.time() - t0
                rate = total / max(elapsed, 0.001)
                print(f"\r  {total:,} rows processed  ({rate:.0f}/s)   ",
                      end="", flush=True)

        # Flush any partial final batch.
        if batch:
            psycopg2.extras.execute_batch(cur, sql, batch,
                                          page_size=batch_size)

    print()
    conn.commit()

    elapsed = time.time() - t0
    log.info("  DMR load complete in %.1f s.", elapsed)
    log.info("  Total rows upserted : %d", total)
    log.info("  Skipped (non-US/CA or no ID): %d", skipped)
    log.info("  ── Match summary ──────────────────")
    log.info("  Matched US  (uls_cur): %d", matched_us)
    log.info("  Matched CA  (uls_cur): %d", matched_ca)
    log.info("  Unmatched            : %d", unmatched)
    if total > 0:
        match_pct = (matched_us + matched_ca) / total * 100
        log.info("  Match rate (US+CA)   : %.1f%%", match_pct)

    # ── Delete stale rows (DMR IDs no longer in the RadioID export) ────────
    log.info("  Purging stale DMR rows ...")
    _purge_stale(conn, schema, users)


def _purge_stale(conn, schema: str, users: list) -> None:
    """Delete dmr_users rows whose DMR ID is no longer in the RadioID export.

    The active-ID set is built with the same filter as the loader (valid
    non-zero integer ID, US/CA country), so rows just upserted are never
    considered stale.  Deletes are issued in chunks to keep statement
    parameter lists bounded.
    """
    active_ids = set()
    for user in users:
        raw_id = user.get("id") or user.get("dmr_id")
        if not raw_id:
            continue
        try:
            dmr_id = int(raw_id)
        except (ValueError, TypeError):
            continue
        if COUNTRY_MAP.get(_up(user.get("country", ""))) not in ("US", "CA"):
            continue
        active_ids.add(dmr_id)

    # Safety valve: an empty/unreadable export must not wipe the table.
    if not active_ids:
        return

    deleted = 0
    chunk_size = 1000

    with conn.cursor() as cur:
        cur.execute(f"SELECT dmr_id FROM {schema}.dmr_users;")
        existing = {row[0] for row in cur.fetchall()}
        # Set difference instead of a per-element list scan; the previously
        # built (and unused) active_list has been removed.
        stale = list(existing - active_ids)

        for i in range(0, len(stale), chunk_size):
            cur.execute(
                f"DELETE FROM {schema}.dmr_users WHERE dmr_id = ANY(%s);",
                (stale[i : i + chunk_size],),
            )
            deleted += cur.rowcount

    conn.commit()
    log.info("  %d stale DMR rows removed.", deleted)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for this loader."""
    parser = argparse.ArgumentParser(
        description="Load RadioID DMR users and cross-reference with uls_cur.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--host", default="localhost", help="PostgreSQL host")
    parser.add_argument("--port", type=int, default=5432, help="PostgreSQL port")
    parser.add_argument("--dbname", default="ham_radio", help="Database name")
    parser.add_argument("--user", default="postgres", help="Database user")
    parser.add_argument(
        "--password",
        default=None,
        help="Database password (or set HAM_DB_PASSWORD env var)",
    )
    parser.add_argument("--schema", default="public", help="Target schema")
    parser.add_argument(
        "--skip-download",
        action="store_true",
        help="Skip download; use existing JSON file",
    )
    parser.add_argument(
        "--dmr-json",
        default=DEFAULT_DMR_JSON,
        help="Local path for RadioID users.json",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2000,
        help="Rows per upsert batch",
    )
    return parser.parse_args()


@contextmanager
def get_connection(args: argparse.Namespace):
    """Yield a psycopg2 connection built from CLI args; close it on exit.

    Falls back to the HAM_DB_PASSWORD environment variable when no
    --password was given.
    """
    pw = args.password if args.password else os.environ.get("HAM_DB_PASSWORD", "")
    conn = psycopg2.connect(
        host=args.host,
        port=args.port,
        dbname=args.dbname,
        user=args.user,
        password=pw,
    )
    try:
        yield conn
    finally:
        conn.close()


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """CLI entry point: download (unless skipped), verify that uls_cur
    exists, create/migrate dmr_users, and run the load.

    Exits with status 1 when --skip-download points at a missing file or
    when the prerequisite uls_cur table is absent.
    """
    args = parse_args()

    # ── Download ───────────────────────────────────────────────────────────
    if not args.skip_download:
        download_file(DMR_URL, args.dmr_json, "DMR")
    else:
        if not Path(args.dmr_json).exists():
            log.error("DMR JSON not found at %s", args.dmr_json)
            sys.exit(1)
        log.info("Skipping download; using: %s", args.dmr_json)

    # ── Database ───────────────────────────────────────────────────────────
    log.info(
        "Connecting to PostgreSQL: host=%s port=%s dbname=%s user=%s schema=%s",
        args.host, args.port, args.dbname, args.user, args.schema,
    )

    with get_connection(args) as conn:
        # uls_cur must be populated first — the load cross-references it.
        # (Plain string, not an f-string: the %s placeholders are psycopg2
        # query parameters, and an f-prefix here risked brace interpolation.)
        with conn.cursor() as cur:
            cur.execute("""
                SELECT COUNT(*) FROM information_schema.tables
                WHERE  table_schema = %s AND table_name = 'uls_cur';
            """, (args.schema,))
            if cur.fetchone()[0] == 0:
                log.error(
                    "uls_cur table not found in schema '%s'. "
                    "Run uls_cur_loader.py first.", args.schema
                )
                sys.exit(1)

        ensure_table(conn, args.schema)
        t_start = time.time()
        load_dmr(conn, args.schema, args.dmr_json, args.batch_size)
        log.info("Total elapsed: %.1f s.", time.time() - t_start)

    log.info("Done.")


# Script guard: run main() only when executed directly, so the module can
# be imported without side effects.
if __name__ == "__main__":
    main()

