import asyncio
import sys
import types
from contextlib import asynccontextmanager

import pytest


if "aiomysql" not in sys.modules:
    # Stand-in for the MySQL driver so the broker module can be imported
    # without it. The coroutines raise to flag any call that a test forgot
    # to patch.

    async def _dummy_connect(*args, **kwargs):  # pragma: no cover - safety fallback
        raise AssertionError("aiomysql.connect should be patched in tests")

    async def _dummy_create_pool(*args, **kwargs):  # pragma: no cover - safety fallback
        raise AssertionError("aiomysql.create_pool should be patched in tests")

    fake_aiomysql = types.ModuleType("aiomysql")
    vars(fake_aiomysql).update(
        DictCursor=object,
        Pool=object,
        connect=_dummy_connect,
        create_pool=_dummy_create_pool,
    )
    sys.modules["aiomysql"] = fake_aiomysql


if "aiohttp" not in sys.modules:

    class _DummyClientSession:  # pragma: no cover - simple stub
        """Async context manager that yields itself and performs no HTTP I/O."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            # Falsy result: exceptions raised in the ``async with`` body propagate.
            return False

    fake_aiohttp = types.ModuleType("aiohttp")
    fake_aiohttp.ClientSession = _DummyClientSession
    sys.modules["aiohttp"] = fake_aiohttp


if "paho.mqtt.client" not in sys.modules:
    # NOTE(review): this also replaces sys.modules["paho"] whenever the client
    # submodule has not been imported yet, even if a real ``paho`` is present.

    class _DummyCallbackAPIVersion:  # pragma: no cover - simple stub
        VERSION2 = object()

    class _DummyMQTTClient:  # pragma: no cover - simple stub
        """No-op MQTT client; constructor keyword arguments land in ``_callbacks``."""

        def __init__(self, *args, **kwargs):
            self._callbacks = kwargs

        def username_pw_set(self, *args, **kwargs):
            return None

        def connect(self, *args, **kwargs):
            return None

        def loop_start(self):
            return None

        def loop_stop(self):
            return None

    fake_mqtt_client = types.ModuleType("paho.mqtt.client")
    fake_mqtt_client.CallbackAPIVersion = _DummyCallbackAPIVersion
    fake_mqtt_client.Client = _DummyMQTTClient

    fake_mqtt = types.ModuleType("paho.mqtt")
    fake_mqtt.client = fake_mqtt_client

    fake_paho = types.ModuleType("paho")
    fake_paho.mqtt = fake_mqtt

    # Register every level of the package chain so that
    # ``import paho.mqtt.client`` resolves entirely to the stubs.
    sys.modules.update(
        {
            "paho": fake_paho,
            "paho.mqtt": fake_mqtt,
            "paho.mqtt.client": fake_mqtt_client,
        }
    )


if "requests" not in sys.modules:

    class _DummyResponse:  # pragma: no cover - simple stub
        """Canned successful response: HTTP 200 with an empty body."""

        status_code = 200
        text = ""
        ok = True

    def _dummy_post(*args, **kwargs):  # pragma: no cover - simple stub
        # Every POST "succeeds" without touching the network.
        return _DummyResponse()

    fake_requests = types.ModuleType("requests")
    vars(fake_requests).update(post=_dummy_post, Session=object)
    sys.modules["requests"] = fake_requests


import pipelet_ocpp_broker as broker


class DummyCursor:
    """Async cursor stub that replays canned query results.

    ``execute`` only records whether the statement was a ``SHOW COLUMNS``
    query. ``fetchall`` then returns the column metadata for that case and
    the row data for any other statement.
    """

    def __init__(self, columns, rows):
        self._columns = columns
        self._rows = rows
        # Before any execute() call, fetchall() yields the column metadata.
        self._phase = "columns"

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False

    async def execute(self, sql, *args, **kwargs):
        is_show_columns = sql.strip().upper().startswith("SHOW COLUMNS")
        self._phase = "columns" if is_show_columns else "select"

    async def fetchall(self):
        return self._columns if self._phase == "columns" else self._rows

    async def fetchone(self):
        # Single-row lookups always report "no match".
        return None


class DummyConnection:
    """Async connection stub whose cursors replay fixed column/row data."""

    def __init__(self, columns, rows):
        self._columns, self._rows = columns, rows

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Never suppress exceptions raised inside ``async with``.
        return False

    def cursor(self, *args, **kwargs):
        # Arguments are accepted and ignored; every call builds a fresh
        # cursor over the same canned data.
        return DummyCursor(self._columns, self._rows)


class DummyConnectWrapper:
    """Stand-in for the object returned by the faked ``aiomysql.connect``.

    Awaiting it (``conn = await connect(...)``) and entering it
    (``async with connect(...) as conn``) both produce the wrapped connection.
    """

    def __init__(self, connection):
        self._connection = connection

    async def _resolve(self):
        return self._connection

    def __await__(self):
        # Delegate to a throwaway coroutine so ``await`` yields the connection.
        return self._resolve().__await__()

    async def __aenter__(self):
        return self._connection

    async def __aexit__(self, exc_type, exc, tb):
        return False


class DummyPool:
    """Minimal stand-in for an aiomysql connection pool.

    Every ``acquire()`` hands out the same connection. ``release``, ``close``
    and ``wait_closed`` do nothing.
    """

    def __init__(self, connection):
        self._connection = connection

    def acquire(self):
        # Return an object that is both awaitable and an async context
        # manager. Then ``conn = await pool.acquire()`` and
        # ``async with pool.acquire() as conn`` both yield the connection,
        # like aiomysql's Pool.acquire(). A plain coroutine only supports
        # the first form.
        return DummyConnectWrapper(self._connection)

    def release(self, conn):
        return None

    def close(self):
        return None

    async def wait_closed(self):
        return None


def test_load_wallboxes_from_db_maps_ping_enabled_flag(monkeypatch):
    """``ping_enabled`` DB column drives each station's ping settings.

    A row with ``ping_enabled = 0`` must load with ping disabled (interval 0).
    A row with ``ping_enabled = 1`` must load with ``DEFAULT_PING_INTERVAL``.
    Both report ``ping_interval_source == "flag"``.
    """
    columns = [
        {"Field": "source_url"},
        {"Field": "ws_url"},
        {"Field": "activity"},
        {"Field": "auth_key"},
        {"Field": "measure_ping"},
        {"Field": "strict_availability"},
        {"Field": "charging_analytics"},
        {"Field": "mqtt_enabled"},
        {"Field": "backend_basic_user"},
        {"Field": "backend_basic_password"},
        {"Field": "ping_enabled"},
    ]
    rows = [
        {
            "source_url": "/flag-off",
            "ws_url": "ws://example.invalid/flag-off",
            "activity": "active",
            "auth_key": None,
            "measure_ping": 0,
            "strict_availability": 0,
            "charging_analytics": 0,
            "mqtt_enabled": 0,
            "backend_basic_user": None,
            "backend_basic_password": None,
            "ping_interval": None,
            "ping_enabled": 0,
            "ocpp_subprotocol": None,
        },
        {
            "source_url": "/flag-on",
            "ws_url": "ws://example.invalid/flag-on",
            "activity": "active",
            "auth_key": None,
            "measure_ping": 0,
            "strict_availability": 0,
            "charging_analytics": 0,
            "mqtt_enabled": 0,
            "backend_basic_user": None,
            "backend_basic_password": None,
            "ping_interval": None,
            "ping_enabled": 1,
            "ocpp_subprotocol": None,
        },
    ]

    # Accept positional arguments too, so the fakes do not depend on how the
    # loader spells its aiomysql calls.
    def fake_connect(*args, **kwargs):
        return DummyConnectWrapper(DummyConnection(columns, rows))

    async def fake_create_pool(*args, **kwargs):
        return DummyPool(DummyConnection(columns, rows))

    fake_aiomysql = types.SimpleNamespace(
        connect=fake_connect, create_pool=fake_create_pool, DictCursor=object
    )
    monkeypatch.setattr(broker, "aiomysql", fake_aiomysql)

    # Keep both the registry object and a copy of its contents. The loader
    # may mutate the dict in place or rebind the module global; the finally
    # block undoes either, so the original object is never left polluted.
    original_station_paths = broker.STATION_PATHS
    original_contents = original_station_paths.copy()

    async def _run_load():
        await broker.load_wallboxes_from_db()
        return broker.STATION_PATHS.copy()

    try:
        loaded_paths = asyncio.run(_run_load())
    finally:
        original_station_paths.clear()
        original_station_paths.update(original_contents)
        broker.STATION_PATHS = original_station_paths

    entry_off = loaded_paths.get("/flag-off")
    entry_on = loaded_paths.get("/flag-on")

    assert entry_off is not None
    assert entry_off["ping_enabled_flag"] is False
    assert entry_off["ping_interval"] == 0
    assert entry_off["ping_interval_source"] == "flag"

    assert entry_on is not None
    assert entry_on["ping_enabled_flag"] is True
    assert entry_on["ping_interval"] == broker.DEFAULT_PING_INTERVAL
    assert entry_on["ping_interval_source"] == "flag"


def test_resolve_ping_settings_disable_via_query_param():
    """``ping_enabled=0`` in the query string overrides an enabled entry flag."""
    enabled, interval = broker._resolve_ping_settings(
        {"ping_enabled": ["0"]},
        {"ping_interval": broker.DEFAULT_PING_INTERVAL, "ping_enabled_flag": True},
    )

    assert enabled is False
    assert interval == 0


def test_resolve_ping_settings_defaults_to_standard_interval():
    """No query override plus no stored interval falls back to the default."""
    station_entry = {"ping_interval": None, "ping_enabled_flag": True}

    enabled, interval = broker._resolve_ping_settings({}, station_entry)

    assert enabled is True
    assert interval == broker.DEFAULT_PING_INTERVAL


def test_resolve_ping_settings_respects_numeric_ping_override():
    """A numeric ``ping`` query value enables pinging at that interval."""
    enabled, interval = broker._resolve_ping_settings(
        {"ping": ["30"]},
        {"ping_interval": broker.DEFAULT_PING_INTERVAL, "ping_enabled_flag": False},
    )

    assert enabled is True
    assert interval == 30


def test_handle_client_prefers_listener_negotiated_subprotocol(monkeypatch):
    """The subprotocol negotiated on the inbound socket wins over the configured one.

    The station entry is configured for ``ocpp1.6``, but the inbound websocket
    negotiated ``ocpp2.0.1``. After ``handle_client`` runs, the entry must
    record ``ocpp2.0.1`` as both its configured and active subprotocol. The
    backend connection must be requested with ``subprotocols=["ocpp2.0.1"]``.
    """
    station_path = "/ocpp/test-station"
    station_id = "test-station"
    ws_url = "ws://backend.example/ocpp/test-station"
    entry = {
        "source_url": station_path,
        "ws_url": ws_url,
        "activity": "active",
        "configured_subprotocol": "ocpp1.6",
    }

    async def _noop(*args, **kwargs):
        return None

    async def _noop_keepalive(*args, **kwargs):
        return None

    # Replace side-effecting collaborators with no-op coroutines so that
    # handle_client can run to completion inside asyncio.run().
    monkeypatch.setattr(broker, "log_message", _noop)
    monkeypatch.setattr(broker, "update_last_connected", _noop)
    monkeypatch.setattr(broker, "notify_station_connected", _noop)
    monkeypatch.setattr(broker, "log_cp_auth_event", _noop)
    monkeypatch.setattr(broker, "keep_alive", _noop_keepalive)
    monkeypatch.setattr(broker, "forward_messages", _noop_keepalive)
    monkeypatch.setattr(broker, "schedule_disconnect_notification", lambda *args, **kwargs: None)

    # Background work is scheduled as a plain task; ``name`` is accepted but dropped.
    def _create_task(coro, name=None):
        return asyncio.create_task(coro)

    monkeypatch.setattr(broker, "create_background_task", _create_task)

    # Filled by the fake backend connection with the kwargs handle_client passed.
    connect_kwargs_used: dict[str, dict] = {}

    @asynccontextmanager
    async def _dummy_backend_connection(url, connect_kwargs, auth_headers, station_id_arg):
        """Fake backend connection: records ``connect_kwargs`` and yields a fake remote."""
        connect_kwargs_used["value"] = connect_kwargs

        class _DummyTransport:
            def get_extra_info(self, name):
                # Same address tuple for every requested attribute.
                return ("127.0.0.1", 9000)

        class _DummyRemote:
            def __init__(self):
                self.transport = _DummyTransport()
                self.closed = False

            async def close(self):
                self.closed = True

        remote = _DummyRemote()
        yield remote
        remote.closed = True

    monkeypatch.setattr(broker, "_backend_websocket_connection", _dummy_backend_connection)

    class _DummyRequest:
        def __init__(self, path):
            self.path = path
            self.headers = {"Sec-WebSocket-Protocol": "ocpp2.0.1"}

    class _DummyTransportInbound:
        def __init__(self, peer):
            self._peer = peer

        def get_extra_info(self, name):
            if name == "peername":
                return self._peer
            if name == "sockname":
                return ("0.0.0.0", 0)
            return None

    # Inbound websocket that already negotiated ocpp2.0.1 with the charge point,
    # although the server also offered ocpp1.6.
    class DummyWebSocket:
        def __init__(self, path):
            self.request = _DummyRequest(path)
            self.subprotocol = "ocpp2.0.1"
            self.client_offered_subprotocols = ["ocpp2.0.1"]
            self.server_offered_subprotocols = ["ocpp1.6", "ocpp2.0.1"]
            self.subprotocol_negotiation_failed = False
            self.subprotocol_rejection_reason = ""
            self.transport = _DummyTransportInbound(("127.0.0.1", 12345))
            self.closed = False

        async def close(self, code=None, reason=None):
            self.closed = True
            self.close_code = code
            self.close_reason = reason

    # Snapshot module-level registries before touching them; the finally block
    # restores each one in place (clear + update) so the original objects
    # keep their pre-test contents.
    state_snapshots = {
        "STATION_PATHS": broker.STATION_PATHS.copy(),
        "STATION_PATH_ALIASES": broker.STATION_PATH_ALIASES.copy(),
        "ACTIVE_CLIENTS": broker.ACTIVE_CLIENTS.copy(),
        "RECONNECTION_COUNTER": broker.RECONNECTION_COUNTER.copy(),
        "ACTIVE_CONNECTION_IDS": broker.ACTIVE_CONNECTION_IDS.copy(),
        "CONNECTION_SUBPROTOCOLS": broker.CONNECTION_SUBPROTOCOLS.copy(),
        "CONNECTION_START_TIMES": broker.CONNECTION_START_TIMES.copy(),
        "LAST_CONNECT": broker.LAST_CONNECT.copy(),
        "LAST_DISCONNECT": broker.LAST_DISCONNECT.copy(),
        "ACTIVE_BACKEND_CONNECTIONS": broker.ACTIVE_BACKEND_CONNECTIONS.copy(),
        "LAST_SEEN": broker.LAST_SEEN.copy(),
    }

    websocket = DummyWebSocket(station_path)

    # Make the test station the only registered path.
    broker.STATION_PATHS.clear()
    broker.STATION_PATH_ALIASES.clear()
    broker.STATION_PATHS[station_path] = entry

    try:
        asyncio.run(broker.handle_client(websocket))
        # Copy results out before the registries are restored below.
        updated_entry = broker.STATION_PATHS[station_path].copy()
        connect_kwargs_snapshot = dict(connect_kwargs_used.get("value", {}))
    finally:
        for name, snapshot in state_snapshots.items():
            container = getattr(broker, name)
            container.clear()
            container.update(snapshot)

    assert websocket.closed is False
    assert updated_entry["configured_subprotocol"] == "ocpp2.0.1"
    assert updated_entry["subprotocol"] == "ocpp2.0.1"
    assert connect_kwargs_snapshot.get("subprotocols") == ["ocpp2.0.1"]
