import json
import logging
from datetime import datetime, timezone
from typing import Any, Mapping, Optional, Sequence

import pymysql

# Fallback connection settings; _mysql_config layers caller-supplied values on top.
DEFAULT_MYSQL_CONFIG: dict[str, Any] = dict(
    host="127.0.0.1",
    user="root",
    password="",
    db="op",
    charset="utf8mb4",
)


def _mysql_config(cfg: Mapping[str, Any]) -> dict[str, Any]:
    """Return a fresh dict of DEFAULT_MYSQL_CONFIG with ``cfg`` applied on top.

    A falsy ``cfg`` (``None`` or empty) yields a plain copy of the defaults.
    """
    combined = DEFAULT_MYSQL_CONFIG.copy()
    if cfg:
        combined.update(cfg)
    return combined


def _connect_db(mysql_cfg: Mapping[str, Any]):
    """Open a pymysql connection whose cursors return rows as dicts.

    Keys missing from ``mysql_cfg`` fall back to the same literal values as
    DEFAULT_MYSQL_CONFIG.
    """
    connect_kwargs = {
        option: mysql_cfg.get(option, fallback)
        for option, fallback in (
            ("host", "127.0.0.1"),
            ("user", "root"),
            ("password", ""),
            ("db", "op"),
            ("charset", "utf8mb4"),
        )
    }
    return pymysql.connect(cursorclass=pymysql.cursors.DictCursor, **connect_kwargs)


def timestamp_str(value: Optional[Any]) -> str:
    """Format ``value`` as a UTC timestamp string ``YYYY-MM-DDTHH:MM:SSZ``.

    * ``datetime``: microseconds are dropped. Timezone-aware values are
      converted to UTC first, so the result never carries both an offset and
      ``Z``. Naive values are treated as already being UTC
      (NOTE(review): MySQL TIMESTAMP columns come back in the session time
      zone -- confirm the session runs in UTC).
    * ``str``: returned unchanged; it is assumed to be pre-formatted.
    * anything else (e.g. ``None``): the current UTC time.
    """
    if isinstance(value, datetime):
        if value.tzinfo is not None:
            # Without this, isoformat() would emit "+00:00" before the "Z".
            value = value.astimezone(timezone.utc).replace(tzinfo=None)
        return value.replace(microsecond=0).isoformat() + "Z"
    if isinstance(value, str):
        return value
    # datetime.utcnow() is deprecated; build an aware value and drop tzinfo
    # so the textual form is unchanged.
    now = datetime.now(timezone.utc).replace(tzinfo=None, microsecond=0)
    return now.isoformat() + "Z"


class LocationRepository:
    """Shared helper to read and manage OCPI location/EVSE payloads."""

    def __init__(self, mysql_cfg: Mapping[str, Any], tariff_service=None):
        self.mysql_cfg = _mysql_config(mysql_cfg)
        self.tariff_service = tariff_service

    @staticmethod
    def _normalize_station_id(raw: Optional[str]) -> str:
        if not raw:
            return ""
        return str(raw).strip().strip("/")

    @staticmethod
    def _normalize_evse_uid(raw: Optional[str]) -> str:
        """Ensure EVSE IDs are stored with non-null values."""

        if raw is None:
            return ""
        return str(raw).strip()

    def _available_columns(self, conn, table: str) -> set[str]:
        try:
            with conn.cursor() as cur:
                cur.execute(f"SHOW COLUMNS FROM {table}")
                return {row["Field"] for row in cur.fetchall()}
        except Exception:
            return set()

    def _ensure_tables(self, conn) -> None:
        with conn.cursor() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS op_ocpi_location_overrides (
                    location_id VARCHAR(255) NOT NULL,
                    evse_uid VARCHAR(255) NOT NULL DEFAULT '',
                    payload JSON NOT NULL,
                    last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                    PRIMARY KEY (location_id, evse_uid)
                ) CHARACTER SET utf8mb4
                """
            )
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS op_ocpi_location_sync_log (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    backend_id INT NULL,
                    location_id VARCHAR(255) NULL,
                    evse_uid VARCHAR(255) NULL,
                    status_code INT NULL,
                    success TINYINT(1) NOT NULL DEFAULT 0,
                    error_message TEXT NULL,
                    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    KEY idx_backend_id (backend_id),
                    KEY idx_location_id (location_id)
                ) CHARACTER SET utf8mb4
                """
            )
        conn.commit()

    def _fetch_redirect_rows(self, conn) -> list[dict[str, Any]]:
        cols = self._available_columns(conn, "op_redirects")
        if not cols:
            return []
        selected = [
            col
            for col in (
                "id",
                "source_url",
                "ws_url",
                "activity",
                "location_name",
                "location_link",
                "webui_remote_access_url",
                "load_management_remote_access_url",
                "created_at",
            )
            if col in cols
        ]
        if not selected:
            return []
        with conn.cursor() as cur:
            cur.execute(f"SELECT {', '.join(selected)} FROM op_redirects")
            return cur.fetchall()

    def _fetch_backend_assignments(self, conn) -> dict[str, list[dict[str, Any]]]:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT wb.station_id,
                           wb.enabled,
                           wb.priority,
                           b.backend_id,
                           b.name,
                           b.url,
                           b.token,
                           b.modules,
                           b.enabled AS backend_enabled
                    FROM op_ocpi_wallbox_backends AS wb
                    LEFT JOIN op_ocpi_backends AS b ON b.backend_id = wb.backend_id
                    """
                )
                assignments: dict[str, list[dict[str, Any]]] = {}
                for row in cur.fetchall():
                    assignments.setdefault(row.get("station_id") or "", []).append(row)
                return assignments
        except Exception:
            logging.getLogger("location_repository").exception("Failed to load backend assignments")
            return {}

    def _fetch_overrides(self, conn) -> list[dict[str, Any]]:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT location_id, evse_uid, payload, last_updated FROM op_ocpi_location_overrides"
                )
                return cur.fetchall()
        except Exception:
            logging.getLogger("location_repository").exception("Failed to load overrides")
            return []

    def _base_location_payload(self, row: Mapping[str, Any]) -> dict[str, Any]:
        station_id = self._normalize_station_id(
            row.get("source_url") or row.get("ws_url") or row.get("id")
        )
        last_updated = timestamp_str(row.get("created_at"))
        evse_uid = station_id or "unknown-evse"
        evse_id = row.get("ws_url") or evse_uid
        payload = {
            "id": station_id or evse_uid,
            "name": row.get("location_name") or station_id or "Unknown",
            "last_updated": last_updated,
            "evses": [
                {
                    "uid": evse_uid,
                    "evse_id": evse_id,
                    "status": "AVAILABLE",
                    "last_updated": last_updated,
                    "connectors": [
                        {
                            "id": "1",
                            "standard": "IEC_62196_T2",
                            "format": "UNKNOWN",
                            "power_type": "AC_1_PHASE",
                        }
                    ],
                }
            ],
        }
        link = row.get("location_link")
        if link:
            payload["directions"] = [link]
        remote_access = row.get("webui_remote_access_url") or row.get(
            "load_management_remote_access_url"
        )
        if remote_access:
            payload.setdefault("additional_geo_info", []).append(
                {"source": "remote_access", "value": remote_access}
            )
        return payload

    def _apply_location_override(
        self, base: dict[str, Any], override: Mapping[str, Any]
    ) -> dict[str, Any]:
        payload = override.get("payload") or {}
        merged = {**base, **payload}
        if "evses" in payload:
            merged["evses"] = payload.get("evses") or []
        merged["last_updated"] = timestamp_str(override.get("last_updated"))
        return merged

    def _apply_evse_override(
        self, base: dict[str, Any], override: Mapping[str, Any]
    ) -> dict[str, Any]:
        evse_uid = override.get("evse_uid") or ""
        payload = override.get("payload") or {}
        payload.setdefault("uid", evse_uid)
        evses: list[dict[str, Any]] = list(base.get("evses") or [])
        replaced = False
        for idx, evse in enumerate(evses):
            if evse.get("uid") == evse_uid:
                evses[idx] = {**evse, **payload}
                replaced = True
                break
        if not replaced:
            evses.append(payload)
        base["evses"] = evses
        base["last_updated"] = timestamp_str(override.get("last_updated"))
        return base

    def _attach_overrides(
        self, location: dict[str, Any], overrides: Sequence[Mapping[str, Any]]
    ) -> dict[str, Any]:
        result = dict(location)
        for override in overrides:
            if override.get("evse_uid"):
                result = self._apply_evse_override(result, override)
            else:
                result = self._apply_location_override(result, override)
        return result

    def _tariff_assignment_map(self) -> dict[tuple[str, str | None], list[str]]:
        if not self.tariff_service:
            return {}
        try:
            assignments = self.tariff_service.list_assignments()
        except Exception:
            logging.getLogger("location_repository").debug(
                "Failed to load tariff assignments", exc_info=True
            )
            return {}
        mapping: dict[tuple[str, str | None], list[str]] = {}
        for assignment in assignments:
            location_id = self._normalize_station_id(assignment.get("location_id"))
            evse_uid = assignment.get("evse_uid") or None
            tariff_id = assignment.get("tariff_id")
            if not location_id or not tariff_id:
                continue
            mapping.setdefault((location_id, evse_uid), []).append(str(tariff_id))
        return mapping

    def _attach_tariffs_to_payload(
        self, payload: dict[str, Any], station_id: str, assignments: Mapping[tuple[str, str | None], list[str]]
    ) -> dict[str, Any]:
        result = dict(payload)
        normalized_station = self._normalize_station_id(station_id)
        location_tariffs = assignments.get((normalized_station, None))
        if location_tariffs:
            result["tariffs"] = location_tariffs
        evses = []
        for evse in result.get("evses", []):
            evse_uid = evse.get("uid")
            tariffs = assignments.get((normalized_station, evse_uid))
            updated_evse = dict(evse)
            if tariffs:
                updated_evse["tariffs"] = tariffs
            evses.append(updated_evse)
        if evses:
            result["evses"] = evses
        return result

    def location_records(self) -> list[dict[str, Any]]:
        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception:
            logging.getLogger("location_repository").exception("Failed to connect to database")
            return []

        try:
            self._ensure_tables(conn)
            redirects = self._fetch_redirect_rows(conn)
            assignments = self._fetch_backend_assignments(conn)
            overrides = self._fetch_overrides(conn)
            tariff_assignments = self._tariff_assignment_map()
            override_map: dict[str, list[dict[str, Any]]] = {}
            for override in overrides:
                key = self._normalize_station_id(override.get("location_id"))
                override_map.setdefault(key, []).append(override)

            records: list[dict[str, Any]] = []
            for row in redirects:
                station_id = self._normalize_station_id(
                    row.get("source_url") or row.get("ws_url")
                )
                if not station_id:
                    continue
                base_payload = self._base_location_payload(row)
                payload = self._attach_overrides(
                    base_payload, override_map.get(station_id, [])
                )
                payload = self._attach_tariffs_to_payload(
                    payload, station_id, tariff_assignments
                )
                records.append(
                    {
                        "location": payload,
                        "station_id": station_id,
                        "backends": assignments.get(station_id, []),
                    }
                )

            for station_id, backend_rows in assignments.items():
                if any(r.get("station_id") == station_id for r in redirects):
                    continue
                payload = {
                    "id": station_id,
                    "name": station_id or "Unknown",
                    "last_updated": timestamp_str(None),
                    "evses": [
                        {
                            "uid": station_id,
                            "evse_id": station_id,
                            "status": "AVAILABLE",
                            "last_updated": timestamp_str(None),
                            "connectors": [],
                        }
                    ],
                }
                payload = self._attach_overrides(
                    payload, override_map.get(station_id, [])
                )
                payload = self._attach_tariffs_to_payload(
                    payload, station_id, tariff_assignments
                )
                records.append(
                    {
                        "location": payload,
                        "station_id": station_id,
                        "backends": backend_rows,
                    }
                )
            return records
        finally:
            try:
                conn.close()
            except Exception:
                pass

    def list_locations(self) -> list[dict[str, Any]]:
        return [record.get("location") for record in self.location_records()]

    def get_location(self, location_id: str) -> Optional[dict[str, Any]]:
        normalized = self._normalize_station_id(location_id)
        for record in self.location_records():
            if self._normalize_station_id(record.get("station_id")) == normalized:
                return record.get("location")
        return None

    def upsert_location_override(
        self, location_id: str, payload: Mapping[str, Any], evse_uid: Optional[str] = None
    ) -> None:
        conn = _connect_db(self.mysql_cfg)
        try:
            self._ensure_tables(conn)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO op_ocpi_location_overrides (location_id, evse_uid, payload)
                    VALUES (%s, %s, %s)
                    ON DUPLICATE KEY UPDATE payload = VALUES(payload)
                    """,
                    (location_id, self._normalize_evse_uid(evse_uid), json.dumps(payload)),
                )
            conn.commit()
        finally:
            try:
                conn.close()
            except Exception:
                pass

    def _load_override_payload(
        self, conn, location_id: str, evse_uid: str
    ) -> dict[str, Any]:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT payload
                    FROM op_ocpi_location_overrides
                    WHERE location_id=%s AND evse_uid=%s
                    LIMIT 1
                    """,
                    (location_id, evse_uid),
                )
                row = cur.fetchone()
        except Exception:
            logging.getLogger("location_repository").debug(
                "Failed to fetch override payload", exc_info=True
            )
            return {}
        if not row:
            return {}
        payload = row.get("payload")
        if isinstance(payload, str):
            try:
                return json.loads(payload)
            except Exception:
                return {}
        return payload or {}

    def merge_location_override(
        self, location_id: str, payload: Mapping[str, Any], evse_uid: Optional[str] = None
    ) -> dict[str, Any]:
        conn = _connect_db(self.mysql_cfg)
        normalized_evse_uid = self._normalize_evse_uid(evse_uid)
        try:
            self._ensure_tables(conn)
            existing = self._load_override_payload(conn, location_id, normalized_evse_uid)
            merged = {**existing, **payload}
            if normalized_evse_uid:
                merged.setdefault("uid", normalized_evse_uid)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO op_ocpi_location_overrides (location_id, evse_uid, payload)
                    VALUES (%s, %s, %s)
                    ON DUPLICATE KEY UPDATE payload = VALUES(payload)
                    """,
                    (location_id, normalized_evse_uid, json.dumps(merged)),
                )
            conn.commit()
            return merged
        finally:
            try:
                conn.close()
            except Exception:
                pass

    def delete_location_override(self, location_id: str) -> None:
        conn = _connect_db(self.mysql_cfg)
        try:
            self._ensure_tables(conn)
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM op_ocpi_location_overrides WHERE location_id=%s",
                    (location_id,),
                )
            conn.commit()
        finally:
            try:
                conn.close()
            except Exception:
                pass

    def delete_evse_override(self, location_id: str, evse_uid: str) -> None:
        conn = _connect_db(self.mysql_cfg)
        try:
            self._ensure_tables(conn)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    DELETE FROM op_ocpi_location_overrides
                    WHERE location_id=%s AND evse_uid=%s
                    """,
                    (location_id, self._normalize_evse_uid(evse_uid)),
                )
            conn.commit()
        finally:
            try:
                conn.close()
            except Exception:
                pass

    def log_sync_result(
        self,
        backend_id: Optional[int],
        location_id: Optional[str],
        evse_uid: Optional[str],
        success: bool,
        status_code: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> None:
        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception:
            logging.getLogger("location_repository").exception(
                "Failed to open DB connection for sync logging"
            )
            return

        try:
            self._ensure_tables(conn)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO op_ocpi_location_sync_log (
                        backend_id, location_id, evse_uid, status_code, success, error_message
                    ) VALUES (%s, %s, %s, %s, %s, %s)
                    """,
                    (
                        backend_id,
                        location_id,
                        evse_uid,
                        status_code,
                        1 if success else 0,
                        error_message,
                    ),
                )
            conn.commit()
        except Exception:
            logging.getLogger("location_repository").exception("Failed to log sync result")
        finally:
            try:
                conn.close()
            except Exception:
                pass

    def latest_sync_results(
        self, location_ids: Optional[Sequence[str]] = None, *, limit: int = 500
    ) -> dict[str, list[dict[str, Any]]]:
        """Return the most recent sync status per backend/EVSE for the given locations."""

        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception:
            logging.getLogger("location_repository").exception(
                "Failed to open DB connection for sync status"
            )
            return {}

        try:
            self._ensure_tables(conn)
            where = []
            params: list[Any] = []
            if location_ids:
                where.append(
                    "location_id IN ({})".format(",".join(["%s"] * len(location_ids)))
                )
                params.extend(location_ids)
            where_clause = f"WHERE {' AND '.join(where)}" if where else ""
            query = f"""
                SELECT l.location_id,
                       l.evse_uid,
                       l.backend_id,
                       l.status_code,
                       l.success,
                       l.error_message,
                       l.created_at,
                       b.name AS backend_name
                FROM op_ocpi_location_sync_log AS l
                LEFT JOIN op_ocpi_backends AS b ON b.backend_id = l.backend_id
                {where_clause}
                ORDER BY l.created_at DESC
                LIMIT %s
            """
            params.append(limit)
            with conn.cursor() as cur:
                cur.execute(query, params)
                rows = cur.fetchall()
        except Exception:
            logging.getLogger("location_repository").exception(
                "Failed to fetch sync status"
            )
            return {}
        finally:
            try:
                conn.close()
            except Exception:
                pass

        seen: set[tuple[str, Optional[int], Optional[str]]] = set()
        status_map: dict[str, list[dict[str, Any]]] = {}
        for row in rows or []:
            loc_id = self._normalize_station_id(row.get("location_id"))
            key = (loc_id, row.get("backend_id"), row.get("evse_uid"))
            if key in seen:
                continue
            seen.add(key)
            status_map.setdefault(loc_id, []).append(
                {
                    "backend_id": row.get("backend_id"),
                    "backend_name": row.get("backend_name"),
                    "evse_uid": row.get("evse_uid"),
                    "status_code": row.get("status_code"),
                    "success": bool(row.get("success", 0)),
                    "error_message": row.get("error_message"),
                    "created_at": timestamp_str(row.get("created_at")),
                }
            )
        return status_map
