import json
import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Iterable, Mapping, Sequence
from urllib.parse import urlparse, urlunparse

import aiomysql
import requests

from services.ocpi_utils import FailureNotifier, setup_logging

CONFIG_FILE = "config.json"

# Settings are read once, at import time, from a path relative to the process
# working directory. A missing file means "use the built-in defaults";
# malformed JSON is not caught here and aborts the import.
try:
    with open(CONFIG_FILE, encoding="utf-8") as f:
        _config = json.load(f)
except FileNotFoundError:
    _config = {}

# OCPI forwarding section of the config; every key below is optional.
OCPI_CFG = _config.get("ocpi", {})
OCPI_CDRS_ENDPOINT = OCPI_CFG.get("cdrs_endpoint", "http://localhost/ocpi/cdrs")
OCPI_RETRY_LIMIT = int(OCPI_CFG.get("retry_limit", 3))
OCPI_RETRY_DELAY_SECONDS = float(OCPI_CFG.get("retry_delay_seconds", 2))


def _as_enabled(value, default=False):
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return bool(value)


# Default for the forwarding switch; _is_forwarding_enabled falls back to it
# when the database flag is missing or cannot be read.
OCPI_FORWARD_ENABLED_DEFAULT = _as_enabled(OCPI_CFG.get("forward_enabled"), default=False)

# Connection settings used when config.json has no "mysql" section. Whichever
# dict wins is passed as keyword arguments to ``aiomysql.connect``.
# NOTE(review): MySQL's "utf8" is the 3-byte utf8mb3 charset; consider
# "utf8mb4" if 4-byte characters must round-trip.
_MYSQL_FALLBACK_CONFIG = {
    "host": "127.0.0.1",
    "user": "root",
    "password": "",
    "db": "op",
    "charset": "utf8",
}
MYSQL_CONFIG = _config.get("mysql", _MYSQL_FALLBACK_CONFIG)


@asynccontextmanager
async def db_connection():
    """Open a fresh aiomysql connection from ``MYSQL_CONFIG`` for one ``async with``.

    The connection is closed on exit, including when the body raises. No
    commit is issued here, so callers that write must commit themselves.
    """
    connection = await aiomysql.connect(**MYSQL_CONFIG)
    try:
        yield connection
    finally:
        # Close even when the caller's block raised.
        connection.close()


async def _is_forwarding_enabled() -> bool:
    """Return True when OCPI CDR forwarding should be executed.

    Reads ``op_config.config_value`` for key ``ocpi_backend_enabled``. A
    missing row, or any database error (which is logged with traceback),
    yields ``OCPI_FORWARD_ENABLED_DEFAULT``.
    """
    try:
        async with db_connection() as conn, conn.cursor(aiomysql.DictCursor) as cur:
            await cur.execute(
                "SELECT config_value FROM op_config WHERE config_key=%s",
                ("ocpi_backend_enabled",),
            )
            row = await cur.fetchone()
            return (
                _as_enabled(row.get("config_value"), OCPI_FORWARD_ENABLED_DEFAULT)
                if row
                else OCPI_FORWARD_ENABLED_DEFAULT
            )
    except Exception:
        logger.exception("Failed to read ocpi_backend_enabled flag from database")
        return OCPI_FORWARD_ENABLED_DEFAULT


async def _ensure_backend_tables(cur):
    """Create or migrate ``op_ocpi_backends`` and ``op_ocpi_wallbox_backends``.

    Creates the backend table if absent, then adds any columns that older
    deployments lack, then creates the station-to-backend link table. Every
    call re-runs the existence checks (it is invoked for each CDR lookup).

    Args:
        cur: An open aiomysql cursor; statements are executed on it directly.
    """
    await cur.execute(
        """
        CREATE TABLE IF NOT EXISTS op_ocpi_backends (
            backend_id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
            name VARCHAR(255) NOT NULL,
            url VARCHAR(1024) NOT NULL,
            remote_versions_url VARCHAR(1024) DEFAULT NULL,
            peer_versions_url VARCHAR(1024) DEFAULT NULL,
            active_version VARCHAR(16) DEFAULT NULL,
            token TEXT,
            peer_token TEXT,
            credentials_token TEXT,
            modules VARCHAR(255) NOT NULL DEFAULT 'cdrs',
            enabled TINYINT(1) NOT NULL DEFAULT 1,
            last_credentials_status VARCHAR(255) NULL,
            last_credentials_at TIMESTAMP NULL DEFAULT NULL,
            created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
            KEY idx_op_ocpi_backends_enabled (enabled)
        )
        """
    )

    # (column, DDL adding it) in the order the migrations must run, since
    # later columns are positioned AFTER earlier ones.
    column_migrations = (
        (
            "remote_versions_url",
            "ALTER TABLE op_ocpi_backends ADD COLUMN remote_versions_url VARCHAR(1024) DEFAULT NULL AFTER url",
        ),
        (
            "peer_versions_url",
            "ALTER TABLE op_ocpi_backends ADD COLUMN peer_versions_url VARCHAR(1024) DEFAULT NULL AFTER remote_versions_url",
        ),
        (
            "active_version",
            "ALTER TABLE op_ocpi_backends ADD COLUMN active_version VARCHAR(16) DEFAULT NULL AFTER peer_versions_url",
        ),
        (
            "peer_token",
            "ALTER TABLE op_ocpi_backends ADD COLUMN peer_token TEXT NULL AFTER token",
        ),
        (
            "credentials_token",
            "ALTER TABLE op_ocpi_backends ADD COLUMN credentials_token TEXT NULL AFTER peer_token",
        ),
        (
            "last_credentials_status",
            "ALTER TABLE op_ocpi_backends ADD COLUMN last_credentials_status VARCHAR(255) NULL AFTER enabled",
        ),
        (
            "last_credentials_at",
            "ALTER TABLE op_ocpi_backends ADD COLUMN last_credentials_at TIMESTAMP NULL DEFAULT NULL AFTER last_credentials_status",
        ),
    )
    for column_name, add_column_sql in column_migrations:
        await cur.execute(f"SHOW COLUMNS FROM op_ocpi_backends LIKE '{column_name}'")
        if await cur.fetchone():
            continue
        await cur.execute(add_column_sql)

    await cur.execute(
        """
        CREATE TABLE IF NOT EXISTS op_ocpi_wallbox_backends (
            station_id VARCHAR(50) NOT NULL,
            backend_id INT NOT NULL,
            enabled TINYINT(1) NOT NULL DEFAULT 1,
            priority INT NOT NULL DEFAULT 100,
            created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
            PRIMARY KEY (station_id, backend_id),
            KEY idx_op_ocpi_wallbox_backends_backend_id (backend_id)
        )
        """
    )


async def _ensure_export_table(cur):
    """Create ``op_ocpi_exports`` if absent and add columns older schemas lack.

    Args:
        cur: An open aiomysql cursor; statements are executed on it directly.
    """
    await cur.execute(
        """
        CREATE TABLE IF NOT EXISTS op_ocpi_exports (
            id INT AUTO_INCREMENT PRIMARY KEY,
            station_id VARCHAR(255),
            backend_id INT,
            backend_name VARCHAR(255),
            transaction_id VARCHAR(255),
            payload JSON,
            success TINYINT(1),
            response_status INT,
            response_body TEXT,
            retry_count INT DEFAULT 0,
            should_retry TINYINT(1) DEFAULT 0,
            record_type VARCHAR(32) DEFAULT 'cdr',
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """
    )
    # Ordered: each added column is positioned AFTER a previously ensured one.
    column_migrations = (
        (
            "backend_id",
            "ALTER TABLE op_ocpi_exports ADD COLUMN backend_id INT NULL AFTER station_id",
        ),
        (
            "backend_name",
            "ALTER TABLE op_ocpi_exports ADD COLUMN backend_name VARCHAR(255) NULL AFTER backend_id",
        ),
        (
            "retry_count",
            "ALTER TABLE op_ocpi_exports ADD COLUMN retry_count INT NULL DEFAULT 0 AFTER response_body",
        ),
        (
            "should_retry",
            "ALTER TABLE op_ocpi_exports ADD COLUMN should_retry TINYINT(1) NULL DEFAULT 0 AFTER retry_count",
        ),
        (
            "record_type",
            "ALTER TABLE op_ocpi_exports ADD COLUMN record_type VARCHAR(32) NULL DEFAULT 'cdr' AFTER should_retry",
        ),
    )
    for column_name, add_column_sql in column_migrations:
        await cur.execute(f"SHOW COLUMNS FROM op_ocpi_exports LIKE '{column_name}'")
        if await cur.fetchone():
            continue
        await cur.execute(add_column_sql)


async def _ensure_sync_run_table(cur):
    """Create the ``op_ocpi_sync_runs`` audit table (written by send_cdr) if absent.

    Args:
        cur: An open aiomysql cursor; the statement is executed on it directly.
    """
    create_sql = """
        CREATE TABLE IF NOT EXISTS op_ocpi_sync_runs (
            id INT AUTO_INCREMENT PRIMARY KEY,
            job_name VARCHAR(128) NOT NULL,
            module VARCHAR(64) NOT NULL,
            direction VARCHAR(32) NOT NULL,
            backend_id INT NULL,
            backend_name VARCHAR(255) NULL,
            record_type VARCHAR(32) NULL,
            duration_ms INT NULL,
            success TINYINT(1) NOT NULL DEFAULT 0,
            status_code INT NULL,
            detail TEXT NULL,
            created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            KEY idx_module_direction (module, direction),
            KEY idx_created_at (created_at)
        )
        """
    await cur.execute(create_sql)


def _truncate_detail(detail: Any) -> str:
    if detail is None:
        return ""
    text = str(detail)
    return text[:1000]


class ForwarderMetrics:
    """In-process counters for OCPI forwarding attempts, rendered as Prometheus text."""

    def __init__(self):
        # Request outcome counters.
        self.requests = 0
        self.successes = 0
        self.failures = 0
        self.retries = 0
        # Responses with HTTP 401/403.
        self.token_errors = 0
        # Latency aggregates in milliseconds.
        self.latency_ms_sum = 0.0
        self.latency_ms_max = 0.0
        # Unix time of the most recent failed attempt (0.0 = none yet).
        self.last_failure_timestamp = 0.0

    def record_attempt(
        self,
        *,
        duration_ms: float,
        success: bool,
        retry_count: int,
        status: int,
    ) -> None:
        """Fold one HTTP attempt into the counters.

        Args:
            duration_ms: Wall-clock duration of the attempt.
            success: Whether the attempt succeeded.
            retry_count: Added to the retry counter; negative values count as 0.
            status: HTTP status code (0 when no response was received).
        """
        self.requests += 1
        if retry_count > 0:
            self.retries += retry_count
        self.latency_ms_sum += duration_ms
        if duration_ms > self.latency_ms_max:
            self.latency_ms_max = duration_ms
        if not success:
            self.failures += 1
            self.last_failure_timestamp = time.time()
        else:
            self.successes += 1
        if status in (401, 403):
            self.token_errors += 1

    def as_prometheus(self) -> str:
        """Return all counters in Prometheus text format, one sample per line."""
        samples = (
            ("pipelet_ocpi_forward_requests_total", f"{self.requests}"),
            ("pipelet_ocpi_forward_success_total", f"{self.successes}"),
            ("pipelet_ocpi_forward_failure_total", f"{self.failures}"),
            ("pipelet_ocpi_forward_retry_total", f"{self.retries}"),
            ("pipelet_ocpi_forward_token_error_total", f"{self.token_errors}"),
            ("pipelet_ocpi_forward_latency_ms_sum", f"{self.latency_ms_sum:.2f}"),
            ("pipelet_ocpi_forward_latency_ms_max", f"{self.latency_ms_max:.2f}"),
            (
                "pipelet_ocpi_forward_last_failure_timestamp",
                f"{self.last_failure_timestamp:.0f}",
            ),
        )
        return "".join(f"{name} {value}\n" for name, value in samples)


# Port for the Prometheus metrics endpoint. The OCPI_FORWARDER_METRICS_PORT
# environment variable takes precedence over ``ocpi.metrics_port`` from
# config.json; a resulting 0 leaves the endpoint disabled (_ensure_metrics_server
# returns early). A non-numeric value raises ValueError at import.
FORWARDER_METRICS_PORT = int(
    os.environ.get("OCPI_FORWARDER_METRICS_PORT") or OCPI_CFG.get("metrics_port", 0) or 0
)
# Configure logging via the project helper before the module logger is obtained.
setup_logging(_config, logger_name="ocpi_forwarder")
logger = logging.getLogger("ocpi_forwarder")
# Process-wide counters served by the metrics endpoint and updated by send_cdr.
forwarder_metrics = ForwarderMetrics()
# Alert hook built from the optional ``ocpi.alerts`` config section; send_cdr
# calls record_success()/record_failure() on it after each backend push.
alert_notifier = FailureNotifier.from_config(OCPI_CFG.get("alerts"), logger=logger)
# Metrics listener state, managed by _ensure_metrics_server().
_metrics_server_started = False
_metrics_server: asyncio.AbstractServer | None = None


async def _handle_metrics_request(reader, writer):
    """Answer one HTTP scrape with the current forwarder metrics, then close.

    The request is read only to drain it; its method and path are ignored. Every
    request gets ``200 OK`` with ``forwarder_metrics.as_prometheus()`` as body.

    Args:
        reader: ``asyncio.StreamReader`` of the client connection.
        writer: ``asyncio.StreamWriter`` of the client connection.
    """
    try:
        # Consume the request head before replying. Closing a socket while the
        # client's request bytes are still unread makes the kernel reset the
        # connection (TCP RST), which can make the scraper lose the response.
        # A client that never sends a full head is answered after a timeout.
        try:
            await asyncio.wait_for(reader.readuntil(b"\r\n\r\n"), timeout=5)
        except (
            asyncio.IncompleteReadError,
            asyncio.LimitOverrunError,
            asyncio.TimeoutError,
        ):
            pass

        body = forwarder_metrics.as_prometheus().encode("utf-8")
        writer.write(
            b"HTTP/1.1 200 OK\r\n"
            b"Content-Type: text/plain; version=0.0.4\r\n"
            # The connection is always closed after one response; say so.
            b"Connection: close\r\n"
            + f"Content-Length: {len(body)}\r\n\r\n".encode()
            + body
        )
        await writer.drain()
    except ConnectionError:
        # Scraper went away mid-exchange; nothing to recover.
        logger.debug("Metrics client disconnected early", exc_info=True)
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except ConnectionError:
            pass


async def _ensure_metrics_server():
    """Start the metrics listener on ``FORWARDER_METRICS_PORT`` once, if configured.

    Returns immediately when no port is configured or the listener is already
    running. A failure to start is logged at DEBUG only. The started flag stays
    False in that case, so the next call tries again.
    """
    global _metrics_server_started, _metrics_server
    if not FORWARDER_METRICS_PORT or _metrics_server_started:
        return
    try:
        _metrics_server = await asyncio.start_server(
            _handle_metrics_request,
            host="0.0.0.0",
            port=FORWARDER_METRICS_PORT,
        )
    except Exception:
        logger.debug("Failed to start metrics endpoint", exc_info=True)
        return
    _metrics_server_started = True
    logger.info(
        "Started forwarder metrics endpoint",
        extra={"event": "ocpi_forward_metrics", "port": FORWARDER_METRICS_PORT},
    )


def _ensure_timestamp(value: Any) -> str | None:
    if not value:
        return None
    if isinstance(value, str):
        return value
    if isinstance(value, datetime):
        return (
            value.replace(microsecond=0).isoformat().replace("+00:00", "Z")
            if value.tzinfo
            else value.isoformat()
        )
    return None


def _unique_strings(items: Iterable[Any]) -> list[str]:
    seen: set[str] = set()
    ordered: list[str] = []
    for item in items:
        if item is None:
            continue
        value = str(item)
        if value not in seen and value.strip():
            ordered.append(value)
            seen.add(value)
    return ordered


def _extract_tariffs(start_info: Mapping[str, Any], stop_info: Mapping[str, Any]) -> list[str]:
    """Collect tariff identifiers from the start and stop payloads.

    Looks at ``tariffId``, ``tariff_id``, ``tariffIds`` and ``tariffs`` in each
    mapping (start first). List/tuple/set values are flattened one level.
    Non-mapping payloads are ignored. The result is de-duplicated via
    ``_unique_strings``.
    """
    tariff_keys = ("tariffId", "tariff_id", "tariffIds", "tariffs")
    collected: list[Any] = []
    for source in (start_info, stop_info):
        if not isinstance(source, Mapping):
            continue
        for raw in (source.get(key) for key in tariff_keys):
            if raw is None:
                continue
            if isinstance(raw, (list, tuple, set)):
                collected.extend(raw)
            else:
                collected.append(raw)
    return _unique_strings(collected)


def _extract_meter_values(stop_info: Mapping[str, Any]) -> list[dict[str, Any]]:
    meter_values: list[dict[str, Any]] = []
    entries: Sequence[Any] = []
    if isinstance(stop_info, Mapping):
        transaction_data = stop_info.get("transactionData")
        if isinstance(transaction_data, Sequence) and not isinstance(
            transaction_data, (str, bytes)
        ):
            entries = transaction_data

    for entry in entries:
        if not isinstance(entry, Mapping):
            continue
        timestamp = entry.get("timestamp")
        sampled_values = []
        for value in entry.get("sampledValue", []) or []:
            if not isinstance(value, Mapping):
                continue
            mapped = {
                "measurand": value.get("measurand"),
                "value": value.get("value"),
                "unit": value.get("unit"),
                "context": value.get("context"),
                "format": value.get("format"),
                "location": value.get("location"),
                "phase": value.get("phase"),
            }
            sampled_values.append({k: v for k, v in mapped.items() if v is not None})
        if sampled_values:
            meter_values.append(
                {"timestamp": timestamp, "sampled_values": sampled_values}
            )
    return meter_values


def _derive_auth_method(start_info: Mapping[str, Any], stop_info: Mapping[str, Any]) -> str | None:
    for payload in (start_info, stop_info):
        if isinstance(payload, Mapping):
            explicit_method = payload.get("authMethod") or payload.get(
                "authorizationType"
            )
            if isinstance(explicit_method, str):
                return explicit_method.upper()

    is_remote_start = False
    if isinstance(start_info, Mapping):
        is_remote_start = bool(
            start_info.get("isRemoteStart")
            or start_info.get("remoteStartIdTag")
            or start_info.get("remoteStartId")
        )

    if is_remote_start:
        return "COMMAND"

    return "AUTH_REQUEST"


def _first_present(*candidates: Any) -> Any:
    """Return the first candidate that is neither ``None`` nor ``""``.

    Used instead of an ``a or b`` chain where ``0`` is a legitimate value
    (transaction ids, meter readings) that ``or`` would silently skip.
    Returns ``None`` when every candidate is missing.
    """
    for candidate in candidates:
        if candidate is not None and candidate != "":
            return candidate
    return None


def _build_cdr(station_id: str, start_info: Mapping[str, Any], stop_info: Mapping[str, Any]) -> dict[str, Any]:
    """Assemble the CDR payload forwarded to OCPI backends.

    Args:
        station_id: Charging station identifier, copied into the payload.
        start_info: Transaction-start data (must support ``.get``).
        stop_info: Transaction-stop data (must support ``.get``).

    Returns:
        A dict with ``id``, ``start_date_time``, ``end_date_time``,
        ``meter_start``, ``meter_stop``, ``home_charging_compensation``,
        ``auth_id``, ``auth_method`` and ``station_id``. Keys whose value is
        ``None`` are omitted. ``tariffs`` and ``meter_values`` are added only
        when non-empty.
    """
    # Stop data wins over start data. Only a genuinely missing id falls back,
    # so a transaction id of 0 is kept.
    cdr_id = _first_present(stop_info.get("transactionId"), start_info.get("transactionId"))
    tariffs = _extract_tariffs(start_info, stop_info)
    meter_values = _extract_meter_values(stop_info)
    cdr: dict[str, Any] = {
        "id": str(cdr_id) if cdr_id is not None else None,
        "start_date_time": _ensure_timestamp(start_info.get("sessionStartTimestamp")),
        "end_date_time": _ensure_timestamp(stop_info.get("timestamp")),
        "meter_start": start_info.get("meterStartWh"),
        # A stop reading of 0 Wh is valid and must not be replaced by the
        # alias key (an ``or`` chain would drop it).
        "meter_stop": _first_present(stop_info.get("meterStop"), stop_info.get("meterStopWh")),
        # NOTE(review): sent as True for every session; confirm this matches
        # the OCPI CDR semantics the receiving backends expect.
        "home_charging_compensation": True,
        "auth_id": start_info.get("idToken") or start_info.get("idTag"),
        "auth_method": _derive_auth_method(start_info, stop_info),
        "station_id": station_id,
    }

    if tariffs:
        cdr["tariffs"] = tariffs
    if meter_values:
        cdr["meter_values"] = meter_values

    # Omit unset fields rather than sending explicit nulls.
    return {key: value for key, value in cdr.items() if value is not None}


async def _get_backends_for_station(station_id: str) -> list[dict]:
    """Return the enabled OCPI backends linked to *station_id*.

    Ensures the backend tables exist first. It then selects rows where both the
    station link and the backend are enabled, ordered by link priority and then
    backend id. Rows come back as dicts (``DictCursor``).
    """
    async with db_connection() as conn, conn.cursor(aiomysql.DictCursor) as cur:
        await _ensure_backend_tables(cur)
        await cur.execute(
            """
                SELECT
                    wb.backend_id,
                    wb.priority,
                    wb.enabled,
                    b.name,
                    b.url,
                    b.remote_versions_url,
                    b.peer_versions_url,
                    b.active_version,
                    b.token,
                    b.peer_token,
                    b.modules
                FROM op_ocpi_wallbox_backends AS wb
                JOIN op_ocpi_backends AS b ON wb.backend_id = b.backend_id
                WHERE wb.station_id=%s AND wb.enabled = 1 AND b.enabled = 1
                ORDER BY wb.priority, wb.backend_id
                """,
            (station_id,),
        )
        return await cur.fetchall()


def _build_cdr_endpoint(backend: Mapping[str, Any]) -> str:
    raw_base_url = str(
        backend.get("remote_versions_url")
        or backend.get("peer_versions_url")
        or backend.get("url")
        or OCPI_CDRS_ENDPOINT
    ).strip()
    if not raw_base_url:
        return OCPI_CDRS_ENDPOINT

    suffix = "cdrs"
    version = str(backend.get("active_version") or "").strip().strip("/")

    parsed = urlparse(raw_base_url)
    path = parsed.path.rstrip("/")
    if path.endswith("/versions"):
        path = path[: -len("/versions")]

    if path.endswith(f"/{suffix}"):
        return urlunparse(parsed._replace(path=path))

    segments = [segment for segment in path.split("/") if segment]

    if not version:
        if "ocpi" in segments:
            target_segments = segments + [suffix]
        else:
            target_segments = segments + ["ocpi", suffix]
    else:
        if "ocpi" in segments:
            ocpi_index = segments.index("ocpi")
            target_segments = segments[: ocpi_index + 1] + [version, suffix]
        else:
            target_segments = segments + ["ocpi", version, suffix]

    target_path = "/" + "/".join(target_segments)
    return urlunparse(parsed._replace(path=target_path))


def build_cdr_endpoint(backend: Mapping[str, Any]) -> str:
    """Return the OCPI CDRs URL for *backend*.

    Public entry point for callers outside this module; it delegates to
    ``_build_cdr_endpoint`` without modification.
    """
    endpoint = _build_cdr_endpoint(backend)
    return endpoint


async def send_cdr(station_id: str, start_info: dict, stop_info: dict):
    """Forward CDR data to OCPI cdrs endpoint and log result.

    Processing steps:

    1. Start the metrics endpoint if one is configured.
    2. Stop early when forwarding is switched off or no backend is linked to
       the station.
    3. Build the CDR once.
    4. For each backend whose ``modules`` include ``cdrs``:
       * POST the CDR, making up to ``max(OCPI_RETRY_LIMIT, 1)`` attempts.
       * Record per-attempt metrics.
       * Write an audit row to ``op_ocpi_exports`` / ``op_ocpi_sync_runs``
         (best effort).
       * Notify ``alert_notifier``.

    Args:
        station_id: Charging station whose backends receive the CDR.
        start_info: Transaction-start data passed to ``_build_cdr``.
        stop_info: Transaction-stop data passed to ``_build_cdr``.

    Returns:
        ``(overall_success, last_status)``.

        * ``overall_success`` is True when at least one backend returned a
          ``response.ok`` status.
        * ``last_status`` is the HTTP status of the final attempt against the
          last backend tried (``0`` if that attempt got no response).
        * ``(False, 0)`` is returned when forwarding is disabled, no backend is
          configured, or every backend was skipped by its module filter.
    """

    await _ensure_metrics_server()

    if not await _is_forwarding_enabled():
        logger.info(
            "Skipping OCPI CDR forwarding because ocpi_backend_enabled is disabled for station_id=%s",
            station_id,
            extra={"event": "ocpi_forward_disabled", "station_id": station_id},
        )
        return False, 0

    cdr = _build_cdr(station_id, start_info, stop_info)

    backends = await _get_backends_for_station(station_id)
    if not backends:
        logger.info(
            "Skipping OCPI CDR forwarding because no backend is configured for station_id=%s",
            station_id,
            extra={"event": "ocpi_forward_missing_backend", "station_id": station_id},
        )
        return False, 0

    loop = asyncio.get_running_loop()
    # A retry_limit <= 0 in the config still yields exactly one attempt; use
    # this value consistently for the loop, the log line and should_retry.
    max_attempts = max(OCPI_RETRY_LIMIT, 1)
    overall_success = False
    last_status = 0

    for backend in backends:
        backend_name = backend.get("name") or f"Backend {backend.get('backend_id')}"
        # modules: comma/semicolon separated, case-insensitive; empty -> "cdrs".
        modules = (backend.get("modules") or "cdrs").lower().replace(";", ",")
        module_set = {m.strip() for m in modules.split(",") if m.strip()}
        if module_set and "cdrs" not in module_set:
            logger.debug(
                "Skipping backend %s for station_id=%s because CDR module is disabled",
                backend_name,
                station_id,
            )
            continue

        headers = {}
        # peer_token takes precedence over token when both are set.
        token = backend.get("peer_token") or backend.get("token")
        if token:
            headers["Authorization"] = f"Bearer {token}"

        target_url = _build_cdr_endpoint(backend)
        logger.info(
            "Sending OCPI CDR to %s for station_id=%s backend=%s payload=%s",
            target_url,
            station_id,
            backend_name,
            json.dumps(cdr, ensure_ascii=False),
            extra={
                "event": "ocpi_forward_start",
                "station_id": station_id,
                "backend_id": backend.get("backend_id"),
                "backend": backend_name,
                "target_url": target_url,
            },
        )
        response_status = 0
        response_body = ""
        success = False
        retry_count = 0
        should_retry = False
        duration_ms = 0.0

        for attempt in range(1, max_attempts + 1):
            retry_count = attempt - 1
            # Reset per attempt. Otherwise a request that raises would report
            # the previous attempt's HTTP status, which would also inflate the
            # 401/403 token-error metric.
            response_status = 0
            attempt_start = time.perf_counter()
            try:
                logger.info(
                    "Sending OCPI CDR attempt %s/%s to %s",
                    attempt,
                    max_attempts,
                    backend_name,
                    extra={
                        "station_id": station_id,
                        "backend_id": backend.get("backend_id"),
                        "transaction_id": cdr.get("id"),
                        "attempt": attempt,
                    },
                )
                # requests is blocking; run it in the default executor so the
                # event loop stays responsive.
                response = await loop.run_in_executor(
                    None,
                    lambda: requests.post(
                        target_url,
                        json=cdr,
                        timeout=10,
                        headers=headers or None,
                    ),
                )
                response_status = response.status_code
                response_body = response.text[:1000]
                success = response.ok
                logger.info(
                    "OCPI CDR response for station_id=%s backend=%s: status=%s",
                    station_id,
                    backend_name,
                    response_status,
                    extra={
                        "station_id": station_id,
                        "backend_id": backend.get("backend_id"),
                        "status": response_status,
                        "transaction_id": cdr.get("id"),
                    },
                )
            except Exception as e:
                response_body = str(e)
                logger.exception(
                    "OCPI CDR request failed for station_id=%s backend=%s",
                    station_id,
                    backend_name,
                    extra={
                        "station_id": station_id,
                        "backend_id": backend.get("backend_id"),
                        "transaction_id": cdr.get("id"),
                        "attempt": attempt,
                    },
                )

            duration_ms = (time.perf_counter() - attempt_start) * 1000
            forwarder_metrics.record_attempt(
                duration_ms=duration_ms,
                success=success,
                # record_attempt adds this value on every call. Passing the
                # cumulative attempt index would count the n-th retry n times
                # (0+1+2 for two retries), so report 1 per retry attempt.
                retry_count=1 if attempt > 1 else 0,
                status=response_status,
            )

            if success:
                # Delivered (possibly after retries): nothing left to retry.
                should_retry = False
                break

            should_retry = attempt < max_attempts
            if should_retry:
                await asyncio.sleep(OCPI_RETRY_DELAY_SECONDS)

        # NOTE(review): after the loop should_retry is always False. Success
        # clears it, and the final failed attempt computes
        # attempt < max_attempts == False. Confirm whether exhausted failures
        # should instead be persisted with should_retry=1 for a later job.
        try:
            async with db_connection() as conn:
                async with conn.cursor() as cur:
                    await _ensure_export_table(cur)
                    await _ensure_sync_run_table(cur)
                    await cur.execute(
                        """
                        INSERT INTO op_ocpi_exports
                        (station_id, backend_id, backend_name, transaction_id, payload, success, response_status, response_body, retry_count, should_retry, record_type)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                        """,
                        (
                            station_id,
                            backend.get("backend_id"),
                            backend_name,
                            cdr.get("id"),
                            json.dumps(cdr),
                            1 if success else 0,
                            response_status,
                            response_body,
                            retry_count,
                            1 if should_retry else 0,
                            "cdr",
                        ),
                    )
                    await cur.execute(
                        """
                        INSERT INTO op_ocpi_sync_runs (
                            job_name, module, direction, backend_id, backend_name, record_type,
                            duration_ms, success, status_code, detail
                        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                        """,
                        (
                            "cdr_forwarder",
                            "cdrs",
                            "push",
                            backend.get("backend_id"),
                            backend_name,
                            "cdr",
                            # Duration of the final attempt only, not the whole
                            # retry sequence.
                            int(duration_ms),
                            1 if success else 0,
                            response_status,
                            _truncate_detail(response_body),
                        ),
                    )
                    await conn.commit()
        except Exception:
            # Best effort: the push result stands, but a missing audit row is
            # worth surfacing above DEBUG.
            logger.warning(
                "Failed to log OCPI export for station_id=%s", station_id, exc_info=True
            )

        if success:
            alert_notifier.record_success()
        else:
            alert_notifier.record_failure(
                "OCPI forwarder failed",
                {
                    "station_id": station_id,
                    "backend": backend_name,
                    "status": response_status,
                    "response": response_body,
                    "retry_count": retry_count,
                },
            )

        overall_success = overall_success or success
        last_status = response_status

    return overall_success, last_status
