from __future__ import annotations

import json
import logging
from datetime import datetime, timezone
from typing import Any, Mapping, Optional, Sequence

import pymysql

# Connection settings used for any key the caller's MySQL config omits.
DEFAULT_MYSQL_CONFIG: dict[str, Any] = dict(
    host="127.0.0.1",
    port=3306,
    user="root",
    password="",
    db="op",
    charset="utf8mb4",
)
# Currency assumed for tariffs that do not specify one.
DEFAULT_CURRENCY = "EUR"

logger = logging.getLogger(__name__)


def _mysql_config(cfg: Mapping[str, Any]) -> dict[str, Any]:
    """Return a fresh dict of DEFAULT_MYSQL_CONFIG overlaid with *cfg*.

    Keys present in *cfg* win; a falsy *cfg* (``None`` or empty) simply
    yields a copy of the defaults. The module-level defaults are never mutated.
    """
    settings = DEFAULT_MYSQL_CONFIG.copy()
    if cfg:
        settings.update(cfg)
    return settings


def _connect_db(mysql_cfg: Mapping[str, Any]):
    """Open a new PyMySQL connection described by *mysql_cfg*.

    Any key missing from *mysql_cfg* falls back to a root connection to the
    ``op`` database on 127.0.0.1:3306 using utf8mb4 (the same values as
    DEFAULT_MYSQL_CONFIG). ``port`` is coerced to ``int``. Cursors opened on
    the returned connection yield rows as dicts (``DictCursor``).
    """
    lookup = mysql_cfg.get
    connect_kwargs = {
        "host": lookup("host", "127.0.0.1"),
        "port": int(lookup("port", 3306)),
        "user": lookup("user", "root"),
        "password": lookup("password", ""),
        "db": lookup("db", "op"),
        "charset": lookup("charset", "utf8mb4"),
        "cursorclass": pymysql.cursors.DictCursor,
    }
    return pymysql.connect(**connect_kwargs)


def _timestamp_str(value: Optional[Any] = None) -> str:
    if isinstance(value, datetime):
        target = value.astimezone(timezone.utc) if value.tzinfo else value.replace(tzinfo=timezone.utc)
        return target.replace(microsecond=0).isoformat().replace("+00:00", "Z")
    return datetime.utcnow().replace(microsecond=0).isoformat() + "Z"


def _parse_datetime(value: Any) -> Optional[datetime]:
    if value is None or value == "":
        return None
    if isinstance(value, datetime):
        return value
    try:
        return datetime.fromisoformat(str(value).replace("Z", ""))
    except Exception:
        return None


def _normalize_tariff_id(value: Any) -> str:
    if value is None:
        return ""
    return str(value).strip().strip("/")


def _safe_json(value: Any, default: Any):
    if isinstance(value, str):
        try:
            return json.loads(value)
        except Exception:
            return default
    return value if value is not None else default


class TariffService:
    """Centralized tariff CRUD and assignment helper for OCPI.

    Tariffs live in the MySQL table ``op_ocpi_tariffs`` and their
    location/EVSE assignments in ``op_ocpi_tariff_assignments``. Each
    operation opens its own connection (``upsert_tariff`` can also reuse a
    caller's), creates the tables and any missing tariff columns if needed,
    and closes the connection afterwards.

    When the database cannot be reached, ``list_tariffs`` and ``get_tariff``
    serve the in-memory ``fallback_tariffs``; ``list_tariffs`` and
    ``upsert_tariff`` also use that list to seed an empty tariff table.

    Date/time values written to the DATETIME columns are naive UTC.
    """

    # Column list shared by every tariff SELECT so rows hydrate identically.
    _TARIFF_COLUMNS = (
        "tariff_id, name, currency, elements, emsp_surcharges, tax_included, "
        "valid_from, valid_until, payload, last_updated"
    )
    # NULL-safe match of a single assignment target, in the parameter order
    # (tariff_id, location_id, evse_uid, backend_id).
    _ASSIGNMENT_MATCH = (
        "tariff_id=%s AND location_id=%s AND IFNULL(evse_uid, '') = IFNULL(%s, '') "
        "AND IFNULL(backend_id, -1) = IFNULL(%s, -1)"
    )

    def __init__(
        self,
        mysql_cfg: Mapping[str, Any],
        *,
        fallback_tariffs: Optional[Sequence[Mapping[str, Any]]] = None,
        default_currency: str = DEFAULT_CURRENCY,
    ):
        """Create the service; no connection is opened here.

        Args:
            mysql_cfg: Connection overrides merged on top of DEFAULT_MYSQL_CONFIG.
            fallback_tariffs: Tariff dicts served when the database is
                unreachable and used to seed an empty tariff table. Entries
                that are not mappings are ignored; the rest are copied.
            default_currency: Currency for tariffs that do not specify one
                (upper-cased).
        """
        self.mysql_cfg = _mysql_config(mysql_cfg or {})
        self.fallback_tariffs: list[dict[str, Any]] = [
            dict(item) for item in (fallback_tariffs or []) if isinstance(item, Mapping)
        ]
        self.default_currency = (default_currency or DEFAULT_CURRENCY).upper()
        # True once the tariff table was found non-empty or has been seeded.
        self._seeded = False

    # ------------------------------------------------------------------ #
    # Small internal helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _close_quietly(conn) -> None:
        """Close *conn*, ignoring errors (the connection may already be broken)."""
        try:
            conn.close()
        except Exception:
            pass

    @staticmethod
    def _as_naive_utc(value: Optional[datetime]) -> Optional[datetime]:
        """Return *value* converted to UTC without tzinfo; naive values pass through.

        Converting (instead of ``replace(tzinfo=None)``) keeps the instant
        correct, and normalizing everything to naive avoids TypeError when
        aware and naive datetimes are compared.
        """
        if value is None or value.tzinfo is None:
            return value
        return value.astimezone(timezone.utc).replace(tzinfo=None)

    # ------------------------------------------------------------------ #
    # Schema helpers
    # ------------------------------------------------------------------ #
    def ensure_schema(self) -> None:
        """Create the tariff tables if needed (best effort).

        If the database cannot be reached this logs at debug level and
        returns without raising; errors while creating tables propagate.
        """
        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception:
            logger.debug("Skipping tariff schema setup; database unavailable", exc_info=True)
            return
        try:
            self._ensure_tables(conn)
        finally:
            self._close_quietly(conn)

    def _ensure_tables(self, conn) -> None:
        """Create/upgrade the tariff table, then the assignment table that references it."""
        self._ensure_tariff_table(conn)
        self._ensure_assignment_table(conn)

    def _ensure_tariff_table(self, conn) -> None:
        """Create ``op_ocpi_tariffs`` if absent, add any missing columns, and commit."""
        with conn.cursor() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS op_ocpi_tariffs (
                    tariff_id VARCHAR(255) PRIMARY KEY,
                    name VARCHAR(255) NULL,
                    currency VARCHAR(3) NOT NULL DEFAULT 'EUR',
                    elements JSON NULL,
                    emsp_surcharges JSON NULL,
                    tax_included TINYINT(1) NOT NULL DEFAULT 0,
                    valid_from DATETIME NULL,
                    valid_until DATETIME NULL,
                    payload JSON NOT NULL,
                    last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                ) CHARACTER SET utf8mb4
                """
            )
            self._ensure_tariff_columns(cur)
        conn.commit()

    def _ensure_tariff_columns(self, cur) -> None:
        """Add any expected column missing from an existing ``op_ocpi_tariffs`` table.

        Presumably this upgrades tables created by an earlier version of this
        schema; columns are added in declaration order so ``AFTER`` targets exist.
        """
        cur.execute("SHOW COLUMNS FROM op_ocpi_tariffs")
        existing = {row["Field"] for row in cur.fetchall()}
        alterations: dict[str, str] = {
            "name": "ALTER TABLE op_ocpi_tariffs ADD COLUMN name VARCHAR(255) NULL AFTER tariff_id",
            "currency": "ALTER TABLE op_ocpi_tariffs ADD COLUMN currency VARCHAR(3) NOT NULL DEFAULT 'EUR' AFTER name",
            "elements": "ALTER TABLE op_ocpi_tariffs ADD COLUMN elements JSON NULL AFTER currency",
            "emsp_surcharges": "ALTER TABLE op_ocpi_tariffs ADD COLUMN emsp_surcharges JSON NULL AFTER elements",
            "tax_included": "ALTER TABLE op_ocpi_tariffs ADD COLUMN tax_included TINYINT(1) NOT NULL DEFAULT 0 AFTER emsp_surcharges",
            "valid_from": "ALTER TABLE op_ocpi_tariffs ADD COLUMN valid_from DATETIME NULL AFTER tax_included",
            "valid_until": "ALTER TABLE op_ocpi_tariffs ADD COLUMN valid_until DATETIME NULL AFTER valid_from",
            "payload": "ALTER TABLE op_ocpi_tariffs ADD COLUMN payload JSON NOT NULL AFTER valid_until",
            "last_updated": "ALTER TABLE op_ocpi_tariffs ADD COLUMN last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP AFTER payload",
        }
        for column, statement in alterations.items():
            if column not in existing:
                cur.execute(statement)

    def _ensure_assignment_table(self, conn) -> None:
        """Create ``op_ocpi_tariff_assignments`` if absent and commit.

        Note: ``evse_uid`` and ``backend_id`` are nullable, and MySQL UNIQUE
        indexes accept any number of rows whose key contains NULL, so
        ``uniq_tariff_target`` alone does not prevent duplicates;
        ``assign_tariff`` checks explicitly.
        """
        with conn.cursor() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS op_ocpi_tariff_assignments (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    tariff_id VARCHAR(255) NOT NULL,
                    location_id VARCHAR(255) NOT NULL,
                    evse_uid VARCHAR(255) DEFAULT NULL,
                    backend_id INT DEFAULT NULL,
                    UNIQUE KEY uniq_tariff_target (tariff_id, location_id, evse_uid, backend_id),
                    KEY idx_location (location_id),
                    CONSTRAINT fk_tariff_assignment_tariff FOREIGN KEY (tariff_id)
                        REFERENCES op_ocpi_tariffs(tariff_id)
                        ON DELETE CASCADE
                ) CHARACTER SET utf8mb4
                """
            )
        conn.commit()

    # ------------------------------------------------------------------ #
    # Data hydration
    # ------------------------------------------------------------------ #
    def _hydrate_tariff_row(self, row: Mapping[str, Any]) -> dict[str, Any]:
        """Merge a tariff DB row into its stored JSON payload and return the result.

        Column values take precedence over the payload for currency, elements,
        emsp_surcharges and the validity dates; ``tax_included`` is true if
        either source says so; ``id`` is filled from the row only when the
        payload lacks one. ``last_updated`` comes from the row when set, else
        the payload, else the current time. A row ``name`` becomes an en/de
        ``tariff_alt_text`` when the payload has none. Datetimes are rendered
        with ``_timestamp_str`` (naive values treated as UTC).
        """
        payload = _safe_json(row.get("payload"), default={})
        if not isinstance(payload, Mapping):
            payload = {}
        elements = _safe_json(row.get("elements"), default=payload.get("elements") or []) or []
        payload = dict(payload)
        tariff_id = _normalize_tariff_id(row.get("tariff_id"))
        if tariff_id and not payload.get("id"):
            payload["id"] = tariff_id
        payload["currency"] = (row.get("currency") or payload.get("currency") or self.default_currency).upper()
        payload["elements"] = elements
        payload["tax_included"] = bool(row.get("tax_included")) or bool(payload.get("tax_included"))
        payload["emsp_surcharges"] = _safe_json(
            row.get("emsp_surcharges"),
            default=payload.get("emsp_surcharges") or [],
        ) or []

        valid_from = row.get("valid_from") or payload.get("valid_from")
        valid_until = row.get("valid_until") or payload.get("valid_until")
        if isinstance(valid_from, datetime):
            payload["valid_from"] = _timestamp_str(valid_from)
        elif valid_from:
            payload["valid_from"] = str(valid_from)
        if isinstance(valid_until, datetime):
            payload["valid_until"] = _timestamp_str(valid_until)
        elif valid_until:
            payload["valid_until"] = str(valid_until)

        if row.get("last_updated"):
            payload["last_updated"] = _timestamp_str(row.get("last_updated"))
        else:
            payload.setdefault("last_updated", _timestamp_str())

        if row.get("name") and not payload.get("tariff_alt_text"):
            payload["tariff_alt_text"] = [
                {"language": "en", "text": row["name"]},
                {"language": "de", "text": row["name"]},
            ]
        return payload

    def _seed_from_fallback(self, conn) -> None:
        """Insert ``fallback_tariffs`` into an empty tariff table, at most once per instance.

        If counting rows fails, this logs and returns without setting the
        flag, so a later call retries. Individual insert failures are logged
        and skipped.
        """
        if self._seeded or not self.fallback_tariffs:
            return
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT COUNT(*) AS cnt FROM op_ocpi_tariffs")
                row = cur.fetchone() or {}
                existing = int(row.get("cnt") or 0)
        except Exception:
            logger.debug("Skipping tariff seed; count failed", exc_info=True)
            return
        if existing:
            self._seeded = True
            return

        for tariff in self.fallback_tariffs:
            try:
                self.upsert_tariff(tariff, conn=conn, _skip_seed=True)
            except Exception:
                logger.debug("Failed to seed fallback tariff", exc_info=True)
        self._seeded = True

    # ------------------------------------------------------------------ #
    # CRUD operations
    # ------------------------------------------------------------------ #
    def list_tariffs(
        self,
        *,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
        offset: int = 0,
        limit: int = 50,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of tariffs, newest first, and the total number of matches.

        ``date_from``/``date_to`` are inclusive bounds on ``last_updated``.
        If the database cannot be reached, the page is built from
        ``fallback_tariffs`` instead.
        """
        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception:
            logger.debug("Tariff database unavailable; serving fallback tariffs", exc_info=True)
            return self._fallback_tariffs(date_from=date_from, date_to=date_to, offset=offset, limit=limit)

        try:
            self._ensure_tables(conn)
            self._seed_from_fallback(conn)
            return self._fetch_tariffs(conn, date_from=date_from, date_to=date_to, offset=offset, limit=limit)
        finally:
            self._close_quietly(conn)

    def _fetch_tariffs(
        self,
        conn,
        *,
        date_from: Optional[datetime],
        date_to: Optional[datetime],
        offset: int,
        limit: int,
    ) -> tuple[list[dict[str, Any]], int]:
        """Query one page of hydrated tariffs plus the total count of matching rows.

        Aware bounds are converted to naive UTC, consistent with
        ``_hydrate_tariff_row`` reporting stored timestamps as UTC.
        Negative ``offset``/``limit`` are clamped to 0.
        """
        lower = self._as_naive_utc(date_from)
        upper = self._as_naive_utc(date_to)
        conditions: list[str] = []
        params: list[Any] = []
        if lower is not None:
            conditions.append("last_updated >= %s")
            params.append(lower)
        if upper is not None:
            conditions.append("last_updated <= %s")
            params.append(upper)

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        count_query = f"SELECT COUNT(*) AS cnt FROM op_ocpi_tariffs {where_clause}"
        page_query = f"""
            SELECT {self._TARIFF_COLUMNS}
            FROM op_ocpi_tariffs
            {where_clause}
            ORDER BY last_updated DESC, tariff_id
            LIMIT %s OFFSET %s
        """
        page_params = [*params, max(limit, 0), max(offset, 0)]

        with conn.cursor() as cur:
            cur.execute(count_query, params)
            total_row = cur.fetchone() or {}
            total = int(total_row.get("cnt") or 0)

            cur.execute(page_query, page_params)
            rows = cur.fetchall() or []

        return [self._hydrate_tariff_row(row) for row in rows], total

    def _fallback_tariffs(
        self,
        *,
        date_from: Optional[datetime],
        date_to: Optional[datetime],
        offset: int,
        limit: int,
    ) -> tuple[list[dict[str, Any]], int]:
        """Page through ``fallback_tariffs`` when the database is unavailable.

        Entries are filtered on their ``last_updated`` (missing/unparseable
        counts as "now") against the inclusive bounds; everything is compared
        as naive UTC so aware and naive values can be mixed. Returns copies
        of the matching page, each with a ``last_updated`` key, and the total
        number of matches.
        """
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        lower = self._as_naive_utc(date_from)
        upper = self._as_naive_utc(date_to)
        filtered: list[dict[str, Any]] = []
        for tariff in self.fallback_tariffs:
            try:
                last_updated = self._as_naive_utc(_parse_datetime(tariff.get("last_updated")))
            except (TypeError, ValueError, OverflowError):
                last_updated = None
            if last_updated is None:
                last_updated = now
            if lower is not None and last_updated < lower:
                continue
            if upper is not None and last_updated > upper:
                continue
            payload = dict(tariff)
            payload.setdefault("last_updated", _timestamp_str(last_updated))
            filtered.append(payload)
        start = max(offset, 0)
        end = start + max(limit, 0)
        return filtered[start:end], len(filtered)

    def _find_fallback_tariff(self, normalized_id: str) -> Optional[dict[str, Any]]:
        """Return a copy of the fallback tariff whose normalized ``id`` matches, or None."""
        for item in self.fallback_tariffs:
            if _normalize_tariff_id(item.get("id")) == normalized_id:
                payload = dict(item)
                payload.setdefault("last_updated", _timestamp_str())
                return payload
        return None

    def _select_tariff(self, conn, normalized_id: str) -> Optional[dict[str, Any]]:
        """Fetch and hydrate a single tariff row by normalized ID, or return None."""
        with conn.cursor() as cur:
            cur.execute(
                f"SELECT {self._TARIFF_COLUMNS} FROM op_ocpi_tariffs WHERE tariff_id=%s",
                (normalized_id,),
            )
            row = cur.fetchone()
        return self._hydrate_tariff_row(row) if row else None

    def get_tariff(self, tariff_id: str) -> Optional[dict[str, Any]]:
        """Return the tariff with *tariff_id*, or None if it is unknown or the ID is blank.

        Falls back to ``fallback_tariffs`` when the database cannot be reached.
        """
        normalized = _normalize_tariff_id(tariff_id)
        if not normalized:
            return None

        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception:
            return self._find_fallback_tariff(normalized)

        try:
            self._ensure_tables(conn)
            return self._select_tariff(conn, normalized)
        finally:
            self._close_quietly(conn)

    def upsert_tariff(
        self,
        payload: Mapping[str, Any],
        *,
        conn=None,
        _skip_seed: bool = False,
    ) -> dict[str, Any]:
        """Validate *payload* and insert or replace the tariff it describes.

        Args:
            payload: Tariff data. Recognized keys: ``id``/``tariff_id``,
                ``currency``, ``elements``/``price_components`` (list or JSON
                array), ``emsp_surcharges``/``surcharges``, ``tax_included``,
                ``valid_from``, ``valid_until``, ``energy_mix`` and the name
                fields handled by ``_normalize_alt_texts``.
            conn: Optional open connection to reuse; it is committed but not
                closed. Without it a new connection is opened (connection
                errors propagate).
            _skip_seed: Internal flag used while seeding to avoid recursion.

        Returns:
            The stored tariff re-read from the database, or the prepared
            payload if it cannot be re-read after the commit.

        Raises:
            ValueError: If the payload fails validation.
        """
        tariff_id = _normalize_tariff_id(payload.get("id") or payload.get("tariff_id"))
        if not tariff_id:
            raise ValueError("Tariff ID is required.")
        if len(tariff_id) > 255:
            raise ValueError("Tariff ID must be 255 characters or fewer.")

        currency = str(payload.get("currency") or self.default_currency or DEFAULT_CURRENCY).strip().upper()
        if len(currency) != 3:
            raise ValueError("Currency must be a 3-letter code.")

        # Normalize to naive UTC before comparing/storing: an aware value would
        # otherwise lose its offset in the DATETIME column, and comparing an
        # aware with a naive value raises TypeError.
        valid_from = self._as_naive_utc(_parse_datetime(payload.get("valid_from")))
        if payload.get("valid_from") and valid_from is None:
            raise ValueError("Valid from must be a valid date/time.")
        valid_until = self._as_naive_utc(_parse_datetime(payload.get("valid_until")))
        if payload.get("valid_until") and valid_until is None:
            raise ValueError("Valid until must be a valid date/time.")
        if valid_from and valid_until and valid_until < valid_from:
            raise ValueError("Valid until must be after valid from.")

        elements = payload.get("elements") or payload.get("price_components") or []
        if isinstance(elements, str):
            try:
                elements = json.loads(elements)
            except ValueError as exc:
                raise ValueError("Price elements must be a list or JSON array.") from exc
        if not isinstance(elements, list):
            raise ValueError("Price elements must be a list or JSON array.")

        emsp_surcharges = self._parse_surcharges(payload.get("emsp_surcharges") or payload.get("surcharges") or [])

        alt_texts = self._normalize_alt_texts(payload)
        tax_included_raw = payload.get("tax_included")
        tax_included = str(tax_included_raw).lower() in {"1", "true", "yes", "on"} if tax_included_raw is not None else False

        prepared_payload: dict[str, Any] = {
            "id": tariff_id,
            "currency": currency,
            "elements": elements,
            "tax_included": tax_included,
            "emsp_surcharges": emsp_surcharges,
            "last_updated": _timestamp_str(),
        }
        if valid_from:
            prepared_payload["valid_from"] = _timestamp_str(valid_from)
        if valid_until:
            prepared_payload["valid_until"] = _timestamp_str(valid_until)
        if alt_texts:
            prepared_payload["tariff_alt_text"] = alt_texts
        if payload.get("energy_mix"):
            prepared_payload["energy_mix"] = payload.get("energy_mix")

        # The name column mirrors the first alt text, else the plain "name".
        name = alt_texts[0].get("text") if alt_texts else payload.get("name")
        row_params = (
            tariff_id,
            name,
            currency,
            json.dumps(elements, ensure_ascii=False),
            json.dumps(emsp_surcharges, ensure_ascii=False) if emsp_surcharges else None,
            1 if tax_included else 0,
            valid_from,
            valid_until,
            json.dumps(prepared_payload, ensure_ascii=False),
        )

        own_connection = conn is None
        if own_connection:
            conn = _connect_db(self.mysql_cfg)
        stored: Optional[dict[str, Any]] = None
        try:
            self._ensure_tables(conn)
            if not _skip_seed:
                self._seed_from_fallback(conn)
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO op_ocpi_tariffs (tariff_id, name, currency, elements, emsp_surcharges, tax_included, valid_from, valid_until, payload)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE
                        name=VALUES(name),
                        currency=VALUES(currency),
                        elements=VALUES(elements),
                        emsp_surcharges=VALUES(emsp_surcharges),
                        tax_included=VALUES(tax_included),
                        valid_from=VALUES(valid_from),
                        valid_until=VALUES(valid_until),
                        payload=VALUES(payload)
                    """,
                    row_params,
                )
            conn.commit()
            # Re-read on the same connection (no second connection per call).
            # The write is already committed, so a failed re-read falls back
            # to the prepared payload instead of reporting an error.
            try:
                stored = self._select_tariff(conn, tariff_id)
            except Exception:
                logger.debug("Could not re-read tariff %s after upsert", tariff_id, exc_info=True)
        finally:
            if own_connection:
                self._close_quietly(conn)

        return stored or prepared_payload

    def delete_tariff(self, tariff_id: str) -> bool:
        """Delete a tariff; return True if a row was removed.

        Assignments referencing the tariff are removed by the table's
        ON DELETE CASCADE foreign key.

        Raises:
            ValueError: If the ID is blank or the database is unreachable.
        """
        normalized = _normalize_tariff_id(tariff_id)
        if not normalized:
            raise ValueError("Tariff ID is required.")

        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception as exc:
            raise ValueError("Tariff storage is unavailable.") from exc

        try:
            self._ensure_tables(conn)
            with conn.cursor() as cur:
                cur.execute("DELETE FROM op_ocpi_tariffs WHERE tariff_id=%s", (normalized,))
                deleted = cur.rowcount > 0
            conn.commit()
            return deleted
        finally:
            self._close_quietly(conn)

    # ------------------------------------------------------------------ #
    # Assignments
    # ------------------------------------------------------------------ #
    def list_assignments(self, tariff_id: Optional[str] = None) -> list[dict[str, Any]]:
        """Return assignment rows, optionally only those of *tariff_id*.

        Rows are dicts with ``id``, ``tariff_id``, ``location_id``,
        ``evse_uid`` and ``backend_id``, ordered by location then EVSE.
        Returns an empty list when the database cannot be reached.
        """
        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception:
            return []

        try:
            self._ensure_tables(conn)
            query = """
                SELECT id, tariff_id, location_id, evse_uid, backend_id
                FROM op_ocpi_tariff_assignments
            """
            params: list[Any] = []
            if tariff_id:
                query += " WHERE tariff_id=%s"
                params.append(_normalize_tariff_id(tariff_id))
            query += " ORDER BY location_id, evse_uid"
            with conn.cursor() as cur:
                cur.execute(query, params)
                return list(cur.fetchall() or [])
        finally:
            self._close_quietly(conn)

    def assign_tariff(
        self,
        tariff_id: str,
        *,
        location_id: str,
        evse_uid: Optional[str] = None,
        backend_id: Optional[int] = None,
    ) -> None:
        """Assign an existing tariff to a location, optionally narrowed to an EVSE and/or backend.

        Re-assigning an identical target is a no-op (barring concurrent
        inserts of the same target).

        Raises:
            ValueError: If the tariff or location ID is blank, the tariff does
                not exist, or the database is unreachable.
        """
        normalized_tariff = _normalize_tariff_id(tariff_id)
        if not normalized_tariff:
            raise ValueError("Tariff ID is required.")
        normalized_location = _normalize_tariff_id(location_id)
        if not normalized_location:
            raise ValueError("Location ID is required.")

        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception as exc:
            raise ValueError("Tariff storage is unavailable.") from exc

        target = (normalized_tariff, normalized_location, evse_uid or None, backend_id)
        try:
            self._ensure_tables(conn)
            with conn.cursor() as cur:
                cur.execute("SELECT 1 FROM op_ocpi_tariffs WHERE tariff_id=%s", (normalized_tariff,))
                if not cur.fetchone():
                    raise ValueError("Tariff does not exist.")
                # The UNIQUE key cannot detect duplicates whose evse_uid or
                # backend_id is NULL, so look for an identical row first.
                cur.execute(
                    f"SELECT 1 FROM op_ocpi_tariff_assignments WHERE {self._ASSIGNMENT_MATCH} LIMIT 1",
                    target,
                )
                if cur.fetchone() is None:
                    cur.execute(
                        """
                        INSERT INTO op_ocpi_tariff_assignments (tariff_id, location_id, evse_uid, backend_id)
                        VALUES (%s, %s, %s, %s)
                        ON DUPLICATE KEY UPDATE tariff_id=VALUES(tariff_id)
                        """,
                        target,
                    )
            conn.commit()
        finally:
            self._close_quietly(conn)

    def remove_assignment(
        self,
        tariff_id: str,
        *,
        location_id: str,
        evse_uid: Optional[str] = None,
        backend_id: Optional[int] = None,
    ) -> None:
        """Delete the assignment matching exactly this target (NULL-safe on EVSE/backend).

        Raises:
            ValueError: If the tariff or location ID is blank or the database
                is unreachable.
        """
        normalized_tariff = _normalize_tariff_id(tariff_id)
        normalized_location = _normalize_tariff_id(location_id)
        if not normalized_tariff or not normalized_location:
            raise ValueError("Tariff assignment target is required.")

        try:
            conn = _connect_db(self.mysql_cfg)
        except Exception as exc:
            raise ValueError("Tariff storage is unavailable.") from exc

        try:
            self._ensure_tables(conn)
            with conn.cursor() as cur:
                cur.execute(
                    f"DELETE FROM op_ocpi_tariff_assignments WHERE {self._ASSIGNMENT_MATCH}",
                    (normalized_tariff, normalized_location, evse_uid or None, backend_id),
                )
            conn.commit()
        finally:
            self._close_quietly(conn)

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _first_present(entry: Mapping[str, Any], *keys: str) -> Any:
        """Return the first value under *keys* that is neither None nor a blank string."""
        for key in keys:
            value = entry.get(key)
            if value is None or (isinstance(value, str) and not value.strip()):
                continue
            return value
        return None

    def _parse_surcharges(self, raw: Any) -> list[dict[str, Any]]:
        """Validate eMSP surcharge entries into ``{"backend_id", "percent", "fixed"}`` dicts.

        *raw* may be a list or a JSON array string; a blank string or any
        non-sequence yields ``[]``. Non-mapping entries and entirely blank
        entries (every field None or ``""``, e.g. an empty form row) are
        skipped. ``emsp_id`` and ``flat`` are accepted as aliases of
        ``backend_id`` and ``fixed``; a backend of ``"none"`` means no backend.

        Raises:
            ValueError: On an invalid JSON string or non-numeric values.
        """
        if isinstance(raw, str):
            if not raw.strip():
                return []
            try:
                raw = json.loads(raw)
            except ValueError as exc:
                raise ValueError("Surcharges must be a list or JSON array.") from exc
        if raw is None or not isinstance(raw, Sequence) or isinstance(raw, (bytes, bytearray, str)):
            return []

        surcharges: list[dict[str, Any]] = []
        for entry in raw:
            if not isinstance(entry, Mapping):
                continue
            backend_raw = self._first_present(entry, "backend_id", "emsp_id")
            if isinstance(backend_raw, str) and backend_raw.strip().lower() == "none":
                backend_raw = None
            percent_raw = self._first_present(entry, "percent")
            fixed_raw = self._first_present(entry, "fixed", "flat")
            if backend_raw is None and percent_raw is None and fixed_raw is None:
                # Blank rows are skipped consistently, whether their fields are
                # None or "" (previously "" raised instead of being skipped).
                continue

            surcharge_entry: dict[str, Any] = {}
            if backend_raw is not None:
                try:
                    surcharge_entry["backend_id"] = int(backend_raw)
                except (TypeError, ValueError) as exc:
                    raise ValueError("Backend IDs must be integers.") from exc
            if percent_raw is not None:
                try:
                    surcharge_entry["percent"] = float(percent_raw)
                except (TypeError, ValueError) as exc:
                    raise ValueError("Percent markup must be numeric.") from exc
            if fixed_raw is not None:
                try:
                    surcharge_entry["fixed"] = float(fixed_raw)
                except (TypeError, ValueError) as exc:
                    raise ValueError("Fixed markup must be numeric.") from exc
            surcharges.append(surcharge_entry)
        return surcharges

    def _normalize_alt_texts(self, payload: Mapping[str, Any]) -> list[dict[str, Any]]:
        """Build the ``tariff_alt_text`` list from *payload*.

        Keeps existing ``tariff_alt_text`` entries that are mappings with a
        ``text`` (copied to plain dicts so they are JSON-serializable), then
        appends ``name_en``/``title_en`` and ``name_de``/``title_de``, and
        finally uses the generic ``name`` for en/de if still missing.
        """
        alt_texts: list[dict[str, Any]] = []
        existing = payload.get("tariff_alt_text")
        if isinstance(existing, list):
            alt_texts.extend(dict(entry) for entry in existing if isinstance(entry, Mapping) and entry.get("text"))
        name_en = payload.get("name_en") or payload.get("title_en")
        name_de = payload.get("name_de") or payload.get("title_de")
        generic_name = payload.get("name")
        if name_en:
            alt_texts.append({"language": "en", "text": str(name_en)})
        if name_de:
            alt_texts.append({"language": "de", "text": str(name_de)})
        if generic_name:
            if not any(entry.get("language") == "en" for entry in alt_texts):
                alt_texts.append({"language": "en", "text": str(generic_name)})
            if not any(entry.get("language") == "de" for entry in alt_texts):
                alt_texts.append({"language": "de", "text": str(generic_name)})
        return alt_texts
