import asyncio
import json
import sys
import types

import pytest

if "pymysql" not in sys.modules:
    # Minimal stand-in for the PyMySQL driver so the server module can be
    # imported without it. Any attempt to open a real connection fails loudly.
    def _connect(*args, **kwargs):  # pragma: no cover - import shim
        raise RuntimeError("mysql unavailable in tests")

    fake_cursors = types.ModuleType("pymysql.cursors")
    fake_cursors.DictCursor = object

    fake_pymysql = types.ModuleType("pymysql")
    fake_pymysql.cursors = fake_cursors
    fake_pymysql.connect = _connect

    # Register both the package and its submodule so either import form works.
    sys.modules.update({"pymysql": fake_pymysql, "pymysql.cursors": fake_cursors})

if "aiohttp" not in sys.modules:
    # Import-time stand-in for ``aiohttp``: just enough of the client API and
    # of ``aiohttp.web`` for the server module to load without the real
    # package. Nothing here performs any network I/O.
    fake_aiohttp = types.ModuleType("aiohttp")
    fake_web = types.ModuleType("aiohttp.web")

    class _DummyClientSession:
        # Accept and ignore constructor arguments such as ``timeout=``. An
        # argument-less stub would raise TypeError for the common
        # ``ClientSession(timeout=ClientTimeout(...))`` construction.
        def __init__(self, *args, **kwargs):
            pass

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            # Never suppress exceptions raised inside ``async with``.
            return False

        async def close(self):
            return None

    class _DummyClientTimeout:
        def __init__(self, *args, **kwargs):
            pass

    class _DummyApplication:
        def __init__(self, *args, **kwargs):
            # A plain list stands in for the router; add_routes appends to it.
            self.router = []

        def add_routes(self, routes):
            self.router.extend(routes)
            return None

    class _DummyRouteTableDef(list):
        # Records ``(method, args, kwargs, handler)`` tuples instead of
        # building real routes; decorated handlers are returned unchanged.
        def _route(self, method, args, kwargs):
            def decorator(func):
                self.append((method, args, kwargs, func))
                return func

            return decorator

        def get(self, *args, **kwargs):
            return self._route("get", args, kwargs)

        def post(self, *args, **kwargs):
            return self._route("post", args, kwargs)

    fake_web.json_response = lambda *args, **kwargs: {}
    fake_web.Request = type("Request", (), {})
    fake_web.Response = type("Response", (), {})
    fake_web.Application = _DummyApplication
    fake_web.RouteTableDef = _DummyRouteTableDef

    fake_aiohttp.ClientSession = _DummyClientSession
    fake_aiohttp.ClientTimeout = _DummyClientTimeout
    # ``web`` is a real module object, so ``from aiohttp import web`` and
    # ``import aiohttp.web`` both resolve to the same module.
    fake_aiohttp.web = fake_web
    sys.modules["aiohttp"] = fake_aiohttp
    sys.modules["aiohttp.web"] = fake_web

if "websockets" not in sys.modules:
    # Import-time stand-in for the ``websockets`` package (legacy server API,
    # legacy HTTP helper and exception types) so the server module can load
    # without the real dependency.
    fake_websockets = types.ModuleType("websockets")
    fake_legacy = types.ModuleType("websockets.legacy")
    fake_server = types.ModuleType("websockets.legacy.server")
    fake_http = types.ModuleType("websockets.legacy.http")
    fake_exceptions = types.ModuleType("websockets.exceptions")

    class _DummyProtocol:
        pass

    def _fake_ws_exception(name, base):
        """Create an exception class that reports itself as ``websockets.exceptions.<name>``."""
        return type(name, (base,), {"__module__": "websockets.exceptions"})

    fake_server.WebSocketServerProtocol = _DummyProtocol
    fake_server.serve = lambda *args, **kwargs: None
    fake_http.read_request = lambda *args, **kwargs: None

    # Use distinct exception classes under a common WebSocketException base.
    # In the real package both derive, directly or indirectly, from
    # WebSocketException. Aliasing them all to ``Exception`` would make any
    # ``except ConnectionClosed`` handler in the server swallow every error,
    # hiding unrelated failures while the stub is active.
    fake_exceptions.WebSocketException = _fake_ws_exception(
        "WebSocketException", Exception
    )
    fake_exceptions.InvalidMessage = _fake_ws_exception(
        "InvalidMessage", fake_exceptions.WebSocketException
    )
    fake_exceptions.ConnectionClosed = _fake_ws_exception(
        "ConnectionClosed", fake_exceptions.WebSocketException
    )

    fake_legacy.server = fake_server
    fake_legacy.http = fake_http
    fake_websockets.legacy = fake_legacy
    fake_websockets.exceptions = fake_exceptions

    sys.modules["websockets"] = fake_websockets
    sys.modules["websockets.legacy"] = fake_legacy
    sys.modules["websockets.legacy.server"] = fake_server
    sys.modules["websockets.legacy.http"] = fake_http
    sys.modules["websockets.exceptions"] = fake_exceptions

import pipelet_ocpp_server as server


class _DummyCursor:
    """Fake DB cursor that answers queries about the ``pnc_enabled`` flag.

    SQL containing ``SHOW COLUMNS`` makes the next ``fetchone`` report that the
    ``pnc_enabled`` column exists. Any other SQL makes it return the configured
    flag as ``{"pnc_enabled": 1}`` or ``{"pnc_enabled": 0}``.
    """

    def __init__(self, pnc_enabled: bool):
        self._pnc_enabled = pnc_enabled
        # Before any statement runs, behave as if the column probe was issued.
        self._phase = "columns"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def execute(self, sql, *args, **kwargs):
        # Parameters are ignored; only the statement text picks the answer.
        self._phase = "columns" if "SHOW COLUMNS" in sql else "select"

    def fetchone(self):
        if self._phase != "columns":
            return {"pnc_enabled": int(bool(self._pnc_enabled))}
        return {"Field": "pnc_enabled"}


class _DummyConnection:
    """Fake DB connection whose cursors report a fixed ``pnc_enabled`` value."""

    def __init__(self, pnc_enabled: bool):
        self._pnc_enabled = pnc_enabled

    def cursor(self, *args, **kwargs):
        # Cursor options (e.g. a cursor class) are accepted and ignored.
        return _DummyCursor(self._pnc_enabled)

    def commit(self):
        """Do nothing; this fake never persists anything."""

    def close(self):
        """Do nothing; there is no underlying connection to release."""


@pytest.fixture(autouse=True)
def reset_pnc_state(monkeypatch):
    """Give each test clean PnC state and forbid unplanned database access.

    The module-level PnC globals are patched through ``monkeypatch``, so their
    pre-test values are restored afterwards. The flag cache is cleared on both
    setup and teardown. As a result, whatever the code under test stores in
    these globals during a test here does not leak into other test modules.
    ``get_db_conn`` raises unless a test patches it explicitly.
    """
    server._PNC_FLAG_CACHE.clear()
    # raising=False keeps the original "create if missing" behaviour of plain
    # assignment; monkeypatch deletes such attributes again on teardown.
    monkeypatch.setattr(server, "_PNC_COLUMN_PRESENT", None, raising=False)
    monkeypatch.setattr(server, "_PNC_ADAPTER", None, raising=False)
    monkeypatch.setattr(server, "_PNC_ADAPTER_ERROR", None, raising=False)

    def _unexpected_db_conn(*args, **kwargs):
        raise AssertionError("unexpected database call")

    monkeypatch.setattr(server, "get_db_conn", _unexpected_db_conn)
    yield
    # Drop per-station flags cached during the test.
    server._PNC_FLAG_CACHE.clear()


def test_authorize_uses_pnc_adapter_when_flag_enabled(monkeypatch):
    """With the station's PnC flag on, Authorize is decided by the PnC adapter."""
    recorded_payloads: list[dict[str, object]] = []

    async def _fake_pnc_authorize(payload):
        recorded_payloads.append(payload)
        return types.SimpleNamespace(authorized=True, status="Accepted")

    def _adapter():
        return types.SimpleNamespace(pnc_authorize=_fake_pnc_authorize)

    monkeypatch.setattr(server, "_load_pnc_adapter", _adapter)
    monkeypatch.setattr(server, "get_db_conn", lambda: _DummyConnection(True))

    authorize_request = {
        "contractCertificate": {"Identification": {"evcoId": "DE*PNC*E123"}}
    }
    pending = server.resolve_authorize_id_tag_info("CP-1", authorize_request)
    id_tag_info, decision = asyncio.run(pending)

    assert recorded_payloads, "PnC adapter should be invoked when station flag is enabled"
    assert id_tag_info["status"] == "Accepted"
    assert decision is not None


def test_authorize_skips_pnc_when_flag_disabled(monkeypatch):
    """With the station's PnC flag off, the result comes from the RFID check."""
    adapter_calls: list[dict[str, object]] = []

    async def _fake_pnc_authorize(payload):
        adapter_calls.append(payload)
        return types.SimpleNamespace(authorized=True, status="Accepted")

    async def _fake_rfid(*args, **kwargs):
        return "Blocked"

    monkeypatch.setattr(
        server,
        "_load_pnc_adapter",
        lambda: types.SimpleNamespace(pnc_authorize=_fake_pnc_authorize),
    )
    monkeypatch.setattr(server, "get_db_conn", lambda: _DummyConnection(False))
    monkeypatch.setattr(server, "check_rfid_authorization", _fake_rfid)

    authorize_request = {
        "contractCertificate": {"Identification": {"evcoId": "DE*PNC*E999"}}
    }
    id_tag_info, decision = asyncio.run(
        server.resolve_authorize_id_tag_info("CP-2", authorize_request)
    )

    assert not adapter_calls, "PnC adapter must not be used when station flag is disabled"
    assert id_tag_info == {"status": "Blocked"}
    assert decision is None


@pytest.mark.parametrize(
    ("decision_factory", "expected_error"),
    [
        pytest.param(
            lambda: types.SimpleNamespace(
                authorized=False, status="Rejected", error="certificate_expired"
            ),
            "certificate_expired",
        ),
        pytest.param(
            lambda: types.SimpleNamespace(
                authorized=False, status="Rejected", error="contract_blocked"
            ),
            "contract_blocked",
        ),
        pytest.param(lambda: RuntimeError("hubject down"), "pnc_adapter_failure"),
    ],
)
def test_authorize_returns_rejected_status_on_errors(
    decision_factory, expected_error, monkeypatch
):
    """A negative or failing PnC check yields Rejected plus a specific errorCode."""

    def _adapter_factory():
        outcome = decision_factory()

        async def _pnc_authorize(payload):
            # An exception instance from the factory simulates adapter failure.
            if not isinstance(outcome, Exception):
                return outcome
            raise outcome

        return types.SimpleNamespace(pnc_authorize=_pnc_authorize)

    monkeypatch.setattr(server, "get_db_conn", lambda: _DummyConnection(True))
    monkeypatch.setattr(server, "_load_pnc_adapter", _adapter_factory)

    pending = server.resolve_authorize_id_tag_info(
        "CP-3", {"contractCertificate": "DE*PNC*ERROR"}
    )
    id_tag_info, _ = asyncio.run(pending)

    assert id_tag_info["status"] == "Rejected"
    assert id_tag_info.get("errorCode") == expected_error


def test_log_ocpp_message_masks_sensitive_identifiers(monkeypatch):
    """The message persisted by log_ocpp_message must not contain the raw EMAID or VIN."""
    stored: list[tuple] = []

    class _RecordingCursor:
        """Cursor that records the parameters of every INSERT it executes."""

        def __init__(self, storage):
            self._storage = storage

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def execute(self, sql, params=None):
            if sql.strip().lower().startswith("insert"):
                self._storage.append(params)

    class _RecordingConnection:
        """Connection that hands out recording cursors; commit/close do nothing."""

        def __init__(self, storage):
            self._storage = storage

        def cursor(self, *args, **kwargs):
            return _RecordingCursor(self._storage)

        def commit(self):
            return None

        def close(self):
            return None

    monkeypatch.setattr(server, "get_db_conn", lambda: _RecordingConnection(stored))

    payload = {
        "idTag": "T-001",
        "contractCertificate": {"Identification": {"evcoId": "DE*MASK*EMAID"}},
        "vin": "WVWZZZABCDEFG1234",
    }

    server.log_ocpp_message("CP-9", "client_to_server", payload, "Authorize")

    # Fail with a clear message, not a bare IndexError, if nothing was logged.
    assert stored, "log_ocpp_message did not execute an INSERT"
    # Assumes the INSERT binds six parameters with the serialized message at
    # index 3 — TODO confirm against log_ocpp_message's SQL if it changes.
    _, _, _, message_json, _, _ = stored[-1]
    stored_message = json.loads(message_json)

    masked_emaid = stored_message["contractCertificate"]["Identification"]["evcoId"]
    assert masked_emaid != "DE*MASK*EMAID"
    assert masked_emaid.startswith("DE")
    assert stored_message["vin"] == "WV***34"

