from __future__ import annotations

from datetime import datetime, timezone
from threading import Lock
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence

# DDL for the ``op_emsp_tokens`` table, executed by TokenService.ensure_table().
# ``CREATE TABLE IF NOT EXISTS`` leaves an already-existing table untouched, so
# ensure_table() additionally back-fills the columns/indexes that older
# versions of this table lack (status, type, contract_id, source, valid_until,
# idx_status, idx_source) with explicit ALTER statements.
# The syntax is MySQL-specific (TINYINT(1), ON UPDATE CURRENT_TIMESTAMP,
# CHARACTER SET utf8mb4). ``uid`` is the primary key that the upsert's
# ``ON DUPLICATE KEY UPDATE`` clause keys on; ``status`` holds the internal
# vocabulary "valid" / "expired" / "blocked" and ``valid`` is written as
# 1 exactly when status == "valid".
EMSP_TOKEN_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS op_emsp_tokens (
    uid VARCHAR(255) PRIMARY KEY,
    auth_id VARCHAR(255),
    issuer VARCHAR(255),
    type VARCHAR(32) NOT NULL DEFAULT 'RFID',
    contract_id VARCHAR(255),
    valid TINYINT(1) DEFAULT 1,
    whitelist VARCHAR(32),
    local_rfid VARCHAR(255),
    status VARCHAR(32) NOT NULL DEFAULT 'valid',
    source VARCHAR(16) NOT NULL DEFAULT 'emsp',
    valid_until DATETIME NULL,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    KEY idx_auth_id (auth_id),
    KEY idx_local_rfid (local_rfid),
    KEY idx_status (status),
    KEY idx_source (source)
) CHARACTER SET utf8mb4
"""

# Upper-cased whitelist values that may arrive in payloads or rows, mapped onto
# the internal status vocabulary ("valid" / "expired" / "blocked") consumed by
# _normalize_status().
_STATUS_FROM_WHITELIST = dict(
    ALLOWED="valid",
    WHITELISTED="valid",
    EXPIRED="expired",
    BLOCKED="blocked",
    NO_LONGER_VALID="blocked",
)

# Reverse direction: the whitelist value reported for a token whose own
# whitelist field is empty, chosen from its internal status (callers fall back
# to "ALLOWED" for statuses missing here).
_WHITELIST_BY_STATUS = dict(
    valid="ALLOWED",
    expired="EXPIRED",
    blocked="BLOCKED",
)


def _as_bool(value: Any) -> Optional[bool]:
    """Coerce a loosely-typed flag to ``True``/``False``, or ``None`` if unknown.

    ``None`` and real booleans pass through unchanged; ints/floats are truthy
    when non-zero; strings are matched case-insensitively (after trimming)
    against a small yes/no vocabulary. Anything else yields ``None``.
    """
    if value is None or isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value != 0
    if not isinstance(value, str):
        return None
    vocabulary = {
        "1": True,
        "true": True,
        "yes": True,
        "y": True,
        "valid": True,
        "0": False,
        "false": False,
        "no": False,
        "n": False,
        "invalid": False,
    }
    return vocabulary.get(value.strip().lower())


def _serialize_dt(value: Any) -> Optional[str]:
    """Render *value* as an ISO-8601 string in UTC.

    ``None`` stays ``None``. Aware datetimes are converted to UTC; naive ones
    are interpreted as already being UTC. Non-datetime values are returned via
    ``str()`` unchanged (e.g. strings that are already serialized).
    """
    if value is None:
        return None
    if not isinstance(value, datetime):
        return str(value)
    as_utc = (
        value.astimezone(timezone.utc)
        if value.tzinfo
        else value.replace(tzinfo=timezone.utc)
    )
    return as_utc.isoformat()


def _parse_dt(value: Any) -> Optional[datetime]:
    """Best-effort conversion of *value* to a ``datetime``; ``None`` if impossible.

    Accepted inputs:
    * ``datetime`` instances;
    * POSIX timestamps in seconds (``int``/``float``), returned as aware UTC;
    * ISO-8601 strings (surrounding whitespace ignored, trailing ``Z`` = UTC).

    Timezone-aware results are normalised to UTC so they agree with
    ``_serialize_dt``, which treats naive datetimes as UTC. This also keeps the
    stored wall-clock value correct if the DB layer drops tzinfo when binding
    (assumption about the driver — confirm). Naive inputs are returned
    unchanged. ``bool`` is rejected even though it subclasses ``int``, because
    ``True``/``False`` are not meaningful timestamps.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, datetime):
        parsed = value
    elif isinstance(value, (int, float)):
        try:
            return datetime.fromtimestamp(value, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            # Out-of-range / NaN timestamps: treat as "no date".
            return None
    elif isinstance(value, str) and value.strip():
        try:
            parsed = datetime.fromisoformat(value.strip().replace("Z", "+00:00"))
        except ValueError:
            return None
    else:
        return None

    if parsed.tzinfo is None:
        return parsed
    try:
        return parsed.astimezone(timezone.utc)
    except (OverflowError, ValueError):
        # Extreme dates near datetime.min/max cannot be shifted; keep as-is.
        return parsed


def _normalize_status(value: Any, *, valid: Optional[bool], whitelist: Optional[str]) -> str:
    """Map loosely-formatted token state onto ``"valid"``, ``"expired"`` or ``"blocked"``.

    Precedence:

    1. A recognised explicit *value* (case-insensitive, trimmed) wins:
       valid/active, expired/inactive, blocked/locked/blacklisted.
    2. Otherwise a recognised *whitelist* value (see ``_STATUS_FROM_WHITELIST``)
       decides — except that a whitelist implying "valid" never overrides an
       explicit ``valid=False``.
    3. Otherwise ``valid=False`` means ``"blocked"``; anything else (including
       ``valid=None``) defaults to ``"valid"``.
    """
    if isinstance(value, str) and value.strip():
        lowered = value.strip().lower()
        if lowered in {"valid", "active"}:
            return "valid"
        if lowered in {"expired", "inactive"}:
            return "expired"
        if lowered in {"blocked", "locked", "blacklisted"}:
            return "blocked"
    # Trim as well as upper-case: callers may pass the raw payload value.
    wl_status = _STATUS_FROM_WHITELIST.get(str(whitelist or "").strip().upper())
    # An ALLOWED/WHITELISTED whitelist must not resurrect a token the sender
    # explicitly flagged as invalid; fall through to the ``valid`` check then.
    if wl_status and not (wl_status == "valid" and valid is False):
        return wl_status
    if valid is False:
        return "blocked"
    return "valid"


def _normalize_source(value: Any) -> str:
    """Return ``"emsp"`` or ``"cpo"`` (case-insensitive, trimmed); default ``"emsp"``."""
    candidate = value.strip().lower() if isinstance(value, str) else ""
    return candidate if candidate in ("emsp", "cpo") else "emsp"


def _normalize_token_type(value: Any) -> str:
    """Return the trimmed, upper-cased token type, or ``"RFID"`` when absent/blank."""
    if not isinstance(value, str):
        return "RFID"
    trimmed = value.strip()
    return trimmed.upper() if trimmed else "RFID"


class TokenCache:
    """In-process cache of token dicts keyed by trimmed, upper-cased uid.

    Every access happens under a single lock. Entries are shallow-copied on the
    way in and out, so callers cannot mutate cached state through returned
    dicts. There is no expiry or size bound: entries live until
    :meth:`invalidate` removes them. Hit/miss counters feed :meth:`stats`.
    """

    def __init__(self) -> None:
        self._cache: dict[str, dict[str, Any]] = {}
        self._hits = 0
        self._misses = 0
        # Aware UTC timestamp; ``datetime.utcnow()`` is deprecated since
        # Python 3.12 and returns a naive value.
        self._last_reset = datetime.now(timezone.utc)
        self._lock = Lock()

    def get(self, uid: str) -> Optional[dict[str, Any]]:
        """Return a copy of the cached token for *uid*, or ``None`` on a miss.

        Lookups are case-insensitive and ignore surrounding whitespace. Every
        call counts as either a hit or a miss.
        """
        key = uid.strip().upper()
        with self._lock:
            cached = self._cache.get(key)
            if cached is None:
                self._misses += 1
                return None
            self._hits += 1
            return dict(cached)

    def set(self, token: Mapping[str, Any]) -> None:
        """Store a copy of *token* under its ``uid``; tokens without a uid are ignored."""
        uid = str(token.get("uid") or "").strip()
        if not uid:
            return
        key = uid.upper()
        with self._lock:
            self._cache[key] = dict(token)

    def invalidate(self, uid: Optional[str] = None) -> None:
        """Drop one entry, or, when *uid* is falsy, clear everything and reset stats."""
        with self._lock:
            if uid:
                self._cache.pop(uid.strip().upper(), None)
            else:
                self._cache.clear()
                self._hits = 0
                self._misses = 0
                self._last_reset = datetime.now(timezone.utc)

    def stats(self) -> dict[str, Any]:
        """Return entry count, hit/miss counters, hit rate and last full-reset time."""
        with self._lock:
            total_requests = self._hits + self._misses
            hit_rate = (self._hits / total_requests) if total_requests else 0.0
            return {
                "entries": len(self._cache),
                "hits": self._hits,
                "misses": self._misses,
                "hit_rate": round(hit_rate, 4),
                "last_reset": _serialize_dt(self._last_reset),
            }


class TokenService:
    """Persist, look up and serialize eMSP tokens stored in ``op_emsp_tokens``.

    Every data method obtains its own connection from ``connect_fn`` and closes
    it before returning. Failure to *obtain* a connection is treated as
    best-effort (an empty/negative result is returned instead of raising);
    errors raised once a connection exists propagate to the caller. Rows are
    read with ``row.get(...)``, so the cursor is assumed to yield mappings
    (NOTE(review): confirm ``connect_fn`` configures a dict-style cursor).
    """

    # Columns that may be missing from tables created by older versions of the
    # schema, paired with the DDL that adds each one.
    _MIGRATION_COLUMNS: tuple[tuple[str, str], ...] = (
        (
            "status",
            "ALTER TABLE op_emsp_tokens ADD COLUMN status VARCHAR(32) NOT NULL DEFAULT 'valid'",
        ),
        (
            "type",
            "ALTER TABLE op_emsp_tokens ADD COLUMN type VARCHAR(32) NOT NULL DEFAULT 'RFID'",
        ),
        ("contract_id", "ALTER TABLE op_emsp_tokens ADD COLUMN contract_id VARCHAR(255)"),
        (
            "source",
            "ALTER TABLE op_emsp_tokens ADD COLUMN source VARCHAR(16) NOT NULL DEFAULT 'emsp'",
        ),
        ("valid_until", "ALTER TABLE op_emsp_tokens ADD COLUMN valid_until DATETIME NULL"),
    )

    # Secondary indexes that may be missing from older tables, with their DDL.
    _MIGRATION_INDEXES: tuple[tuple[str, str], ...] = (
        ("idx_status", "ALTER TABLE op_emsp_tokens ADD KEY idx_status (status)"),
        ("idx_source", "ALTER TABLE op_emsp_tokens ADD KEY idx_source (source)"),
    )

    # Column list shared by every SELECT so _row_to_token() always sees the same shape.
    _SELECT_COLUMNS = (
        "uid, auth_id, issuer, type, contract_id, valid, whitelist, "
        "local_rfid, status, source, valid_until, updated_at"
    )

    # Insert-or-update keyed on the ``uid`` primary key.
    _UPSERT_SQL = """
        INSERT INTO op_emsp_tokens (uid, auth_id, issuer, type, contract_id, valid, whitelist, local_rfid, status, source, valid_until)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        ON DUPLICATE KEY UPDATE
            auth_id=VALUES(auth_id),
            issuer=VALUES(issuer),
            type=VALUES(type),
            contract_id=VALUES(contract_id),
            valid=VALUES(valid),
            whitelist=VALUES(whitelist),
            local_rfid=VALUES(local_rfid),
            status=VALUES(status),
            source=VALUES(source),
            valid_until=VALUES(valid_until)
    """

    def __init__(self, connect_fn: Callable[[], Any]) -> None:
        """Create a service that opens database connections via *connect_fn*."""
        self._connect_fn = connect_fn
        self.cache = TokenCache()
        # Flipped to True once ensure_table() has succeeded, so request paths
        # stop repeating the CREATE/SHOW/ALTER round-trips on every call.
        # Assumes connect_fn always targets the same database.
        self._table_ready = False

    # ------------------------------------------------------------------ schema

    def ensure_table(self, conn: Any) -> None:
        """Create ``op_emsp_tokens`` if needed, add missing columns/indexes, and commit.

        Always performs the checks when called directly; request paths use
        :meth:`_ensure_table_once` instead.
        """
        with conn.cursor() as cur:
            cur.execute(EMSP_TOKEN_TABLE_SQL)
            # Names interpolated below come from the class constants above,
            # never from caller input.
            for column, ddl in self._MIGRATION_COLUMNS:
                cur.execute(f"SHOW COLUMNS FROM op_emsp_tokens LIKE '{column}'")
                if not cur.fetchone():
                    cur.execute(ddl)
            for index_name, ddl in self._MIGRATION_INDEXES:
                cur.execute(f"SHOW INDEX FROM op_emsp_tokens WHERE Key_name='{index_name}'")
                if not cur.fetchall():
                    cur.execute(ddl)
        conn.commit()
        self._table_ready = True

    def _ensure_table_once(self, conn: Any) -> None:
        """Run :meth:`ensure_table` until it has succeeded once for this service."""
        if not self._table_ready:
            self.ensure_table(conn)

    @staticmethod
    def _close_quietly(conn: Any) -> None:
        """Close *conn*, ignoring errors so a failed close never masks the real outcome."""
        try:
            conn.close()
        except Exception:
            pass

    # ---------------------------------------------------------- normalization

    @staticmethod
    def normalize_payload(entry: Mapping[str, Any]) -> Optional[dict[str, Any]]:
        """Translate one incoming token payload into the internal token dict.

        Several key spellings are accepted (e.g. ``uid``/``UID``/``id``,
        ``authId``/``Auth-ID``/``auth_id``, ``whitelistStatus``/``whitelist``).
        Returns ``None`` when the payload carries no non-blank uid.
        """
        raw_uid = entry.get("uid") or entry.get("UID") or entry.get("id")
        # Strip before the emptiness check so a whitespace-only uid cannot
        # become an empty primary key.
        uid = str(raw_uid).strip() if raw_uid else ""
        if not uid:
            return None
        whitelist = entry.get("whitelistStatus") or entry.get("whitelist")
        token_type = _normalize_token_type(entry.get("type"))
        raw_valid = entry.get("valid")
        if raw_valid is None and "validity" in entry:
            raw_valid = entry.get("validity")
        valid = _as_bool(raw_valid)

        status = _normalize_status(entry.get("status"), valid=valid, whitelist=whitelist)
        valid_until = _parse_dt(
            entry.get("validUntil") or entry.get("valid_until") or entry.get("expires_at")
        )
        local_rfid_raw = entry.get("localRfid") or entry.get("rfid") or entry.get("localId")

        return {
            "uid": uid,
            "auth_id": entry.get("authId") or entry.get("Auth-ID") or entry.get("auth_id"),
            "issuer": entry.get("issuer"),
            "type": token_type,
            "valid": valid if valid is not None else status == "valid",
            "whitelist": whitelist.strip().upper() if isinstance(whitelist, str) else None,
            "local_rfid": str(local_rfid_raw).strip().upper() if local_rfid_raw else None,
            "status": status,
            "source": _normalize_source(entry.get("source") or entry.get("origin")),
            "valid_until": valid_until,
            "contract_id": entry.get("contract_id") or entry.get("contractId") or entry.get("contract"),
            "last_updated": _serialize_dt(entry.get("last_updated") or entry.get("lastUpdated")),
        }

    @staticmethod
    def _prepare_for_storage(token: dict[str, Any]) -> dict[str, Any]:
        """Fill derived fields (type, contract_id, status, whitelist, valid) in place and return *token*."""
        token["type"] = _normalize_token_type(token.get("type"))
        # normalize_payload() always emits a "contract_id" key (possibly None),
        # so setdefault() would never apply the fallback; do it explicitly.
        token["contract_id"] = (
            token.get("contract_id") or token.get("auth_id") or token.get("uid")
        )
        status = _normalize_status(
            token.get("status"), valid=token.get("valid"), whitelist=token.get("whitelist")
        )
        token["status"] = status
        token["whitelist"] = token.get("whitelist") or _WHITELIST_BY_STATUS.get(status, "ALLOWED")
        token["valid"] = status == "valid"
        return token

    def _row_to_token(self, row: Mapping[str, Any]) -> dict[str, Any]:
        """Convert a DB row mapping into the internal token dict, filling defaults."""
        valid_raw = _as_bool(row.get("valid", 1))
        status = _normalize_status(
            row.get("status"),
            valid=valid_raw,
            whitelist=row.get("whitelist"),
        )
        whitelist = row.get("whitelist") or _WHITELIST_BY_STATUS.get(status, "ALLOWED")
        valid_until = row.get("valid_until")
        if isinstance(valid_until, str):
            valid_until = _parse_dt(valid_until)
        return {
            "uid": row.get("uid"),
            "auth_id": row.get("auth_id"),
            "issuer": row.get("issuer") or "Pipelet",
            # A token is only valid when its status agrees AND the stored flag is not false.
            "valid": status == "valid" and valid_raw is not False,
            "whitelist": whitelist,
            "local_rfid": row.get("local_rfid"),
            "status": status,
            "source": _normalize_source(row.get("source")),
            "valid_until": valid_until,
            "updated_at": row.get("updated_at"),
            "type": _normalize_token_type(row.get("type")),
            "contract_id": row.get("contract_id") or row.get("auth_id") or row.get("uid"),
        }

    # ------------------------------------------------------------------ reads

    @staticmethod
    def _build_where(
        *,
        search: Optional[str],
        status: Optional[str],
        date_from: Optional[datetime],
        date_to: Optional[datetime],
    ) -> tuple[str, list[Any]]:
        """Build list_tokens()'s WHERE clause and bind values.

        Only fixed SQL fragments are concatenated; every user-supplied value is
        passed as a bind parameter. Blank search/status strings add no filter.
        """
        clauses: list[str] = ["1=1"]
        params: list[Any] = []
        needle = search.strip().upper() if search else ""
        if needle:
            clauses.append("(UPPER(uid)=%s OR UPPER(auth_id)=%s OR UPPER(local_rfid)=%s)")
            params.extend([needle, needle, needle])
        wanted_status = status.strip().lower() if status else ""
        if wanted_status:
            clauses.append("LOWER(status)=%s")
            params.append(wanted_status)
        if date_from:
            clauses.append("updated_at >= %s")
            params.append(date_from)
        if date_to:
            clauses.append("updated_at <= %s")
            params.append(date_to)
        return " AND ".join(clauses), params

    def list_tokens(
        self,
        *,
        search: Optional[str] = None,
        status: Optional[str] = None,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
        offset: int = 0,
        limit: int = 50,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of matching tokens (newest first) and the total match count.

        *search* matches uid, auth_id or local_rfid exactly (case-insensitive);
        *status* matches case-insensitively; *date_from*/*date_to* bound
        ``updated_at`` inclusively. Negative *offset*/*limit* are clamped to 0.
        Returned tokens are also written to the cache. Returns ``([], 0)`` when
        no connection can be obtained.
        """
        offset = max(0, int(offset))
        limit = max(0, int(limit))
        where_sql, params = self._build_where(
            search=search, status=status, date_from=date_from, date_to=date_to
        )
        try:
            conn = self._connect_fn()
        except Exception:
            return [], 0
        try:
            self._ensure_table_once(conn)
            with conn.cursor() as cur:
                cur.execute(
                    f"SELECT COUNT(*) AS total_count FROM op_emsp_tokens WHERE {where_sql}",
                    params,
                )
                count_row = cur.fetchone() or {}
                total = count_row.get("total_count", 0)
                # uid breaks updated_at ties so LIMIT/OFFSET pages are stable.
                cur.execute(
                    f"SELECT {self._SELECT_COLUMNS} FROM op_emsp_tokens "
                    f"WHERE {where_sql} ORDER BY updated_at DESC, uid "
                    "LIMIT %s OFFSET %s",
                    params + [limit, offset],
                )
                rows = cur.fetchall()
        finally:
            self._close_quietly(conn)
        tokens = [self._row_to_token(r) for r in rows]
        for token in tokens:
            self.cache.set(token)
        return tokens, int(total or 0)

    def get_token(self, uid: str) -> Optional[dict[str, Any]]:
        """Return the token for *uid* from the cache, falling back to the DB.

        A DB hit is cached. Returns ``None`` for a blank uid, an unknown uid,
        or when no connection can be obtained.
        """
        cached = self.cache.get(uid)
        if cached:
            return cached
        # Stored uids are trimmed (normalize_payload), and the cache ignores
        # surrounding whitespace; query the DB the same way.
        lookup_uid = uid.strip()
        if not lookup_uid:
            return None
        try:
            conn = self._connect_fn()
        except Exception:
            return None
        try:
            self._ensure_table_once(conn)
            with conn.cursor() as cur:
                cur.execute(
                    f"SELECT {self._SELECT_COLUMNS} FROM op_emsp_tokens WHERE uid=%s LIMIT 1",
                    (lookup_uid,),
                )
                row = cur.fetchone()
        finally:
            self._close_quietly(conn)
        if not row:
            return None
        token = self._row_to_token(row)
        self.cache.set(token)
        return token

    # ----------------------------------------------------------------- writes

    def upsert_tokens(self, entries: Iterable[Mapping[str, Any]]) -> int:
        """Normalize *entries*, insert-or-update them in one transaction, refresh the cache.

        Non-mapping entries and entries without a usable uid are skipped.
        Returns the number of tokens written; ``0`` when there is nothing to
        write or no connection can be obtained.
        """
        candidates = (
            self.normalize_payload(entry) for entry in entries if isinstance(entry, Mapping)
        )
        rows = [self._prepare_for_storage(token) for token in candidates if token]
        if not rows:
            return 0
        try:
            conn = self._connect_fn()
        except Exception:
            return 0
        try:
            self._ensure_table_once(conn)
            with conn.cursor() as cur:
                for token in rows:
                    cur.execute(
                        self._UPSERT_SQL,
                        (
                            token.get("uid"),
                            token.get("auth_id"),
                            token.get("issuer"),
                            token["type"],
                            token["contract_id"],
                            1 if token["valid"] else 0,
                            token["whitelist"],
                            token.get("local_rfid"),
                            token["status"],
                            token.get("source") or "emsp",
                            token.get("valid_until"),
                        ),
                    )
            conn.commit()
        finally:
            self._close_quietly(conn)
        # Cache only after a successful commit so the cache never runs ahead of the DB.
        for token in rows:
            self.cache.set(token)
        return len(rows)

    def delete_token(self, uid: str) -> bool:
        """Delete the token row for *uid* and evict it from the cache.

        Returns ``False`` for a blank uid or when no connection can be
        obtained, ``True`` once the DELETE has been committed (whether or not a
        row existed).
        """
        lookup_uid = (uid or "").strip()
        if not lookup_uid:
            # Guard: cache.invalidate() with a falsy uid would wipe the whole cache.
            return False
        try:
            conn = self._connect_fn()
        except Exception:
            return False
        try:
            self._ensure_table_once(conn)
            with conn.cursor() as cur:
                cur.execute("DELETE FROM op_emsp_tokens WHERE uid=%s", (lookup_uid,))
            conn.commit()
        finally:
            self._close_quietly(conn)
        self.cache.invalidate(lookup_uid)
        return True

    def sync_tokens(
        self, entries: Sequence[Mapping[str, Any]], *, clear_cache: bool = False
    ) -> dict[str, Any]:
        """Upsert *entries*, optionally flush the whole cache, and report the outcome."""
        stored = self.upsert_tokens(entries)
        if clear_cache:
            self.cache.invalidate()
        return {
            "stored": stored,
            "cache": self.cache.stats(),
        }

    def cache_stats(self) -> dict[str, Any]:
        """Return the cache statistics (see ``TokenCache.stats``)."""
        return self.cache.stats()

    # ---------------------------------------------------------- serialization

    @staticmethod
    def serialize_rest(token: Mapping[str, Any]) -> dict[str, Any]:
        """Render an internal token dict for the REST API (camelCase keys)."""
        return {
            "uid": token.get("uid"),
            "authId": token.get("auth_id"),
            "issuer": token.get("issuer"),
            "valid": bool(token.get("valid")),
            "status": token.get("status"),
            "source": token.get("source"),
            "whitelistStatus": token.get("whitelist"),
            "localRfid": token.get("local_rfid"),
            "validUntil": _serialize_dt(token.get("valid_until")),
            "updatedAt": _serialize_dt(token.get("updated_at")),
        }

    @staticmethod
    def serialize_ocpi(token: Mapping[str, Any]) -> dict[str, Any]:
        """Render an internal token dict as an OCPI-style token object.

        ``valid`` is derived from ``status`` (missing status counts as valid);
        ``auth_id``/``contract_id``/``issuer``/``whitelist`` fall back to
        derived defaults. ``visual_number`` and ``valid_until`` are included
        only when set; ``status`` and ``source`` are appended as extensions.
        """
        status = token.get("status", "valid")
        whitelist_raw = token.get("whitelist") or _WHITELIST_BY_STATUS.get(status, "ALLOWED")
        whitelist = whitelist_raw.strip().upper() if isinstance(whitelist_raw, str) else whitelist_raw
        last_updated_raw = token.get("last_updated") or token.get("updated_at")
        payload = {
            "uid": token.get("uid"),
            "type": _normalize_token_type(token.get("type")),
            "auth_id": token.get("auth_id") or token.get("uid"),
            "issuer": token.get("issuer") or "Pipelet",
            "valid": status == "valid",
            "whitelist": whitelist,
            "contract_id": token.get("contract_id") or token.get("auth_id") or token.get("uid"),
            "last_updated": _serialize_dt(last_updated_raw),
        }
        if token.get("local_rfid"):
            payload["visual_number"] = token.get("local_rfid")
        if token.get("valid_until"):
            payload["valid_until"] = _serialize_dt(token.get("valid_until"))
        payload["status"] = token.get("status")
        payload["source"] = token.get("source")
        return payload
