import asyncio
import logging
import ssl
import sys
import types

import pytest

if "aiomysql" not in sys.modules:
    # Minimal in-memory stand-in for aiomysql so the broker module can be
    # imported without the real driver. Nothing here talks to a database.
    fake_aiomysql = types.ModuleType("aiomysql")
    fake_aiomysql.DictCursor = object
    fake_aiomysql.Pool = object

    class _AwaitableContext:  # pragma: no cover - simple stub
        """Wrap a value so callers may use either ``await x`` or ``async with x``.

        aiomysql's ``connect()``, ``create_pool()`` and ``Pool.acquire()``
        return helpers usable both ways. Mirroring that keeps the stub
        compatible with either calling style in the code under test.
        Entering/exiting the context is a no-op here.
        """

        def __init__(self, value):
            self._value = value

        def __await__(self):
            async def _resolve():
                return self._value

            return _resolve().__await__()

        async def __aenter__(self):
            return self._value

        async def __aexit__(self, exc_type, exc, tb):
            return False

    class _DummyCursor:  # pragma: no cover - simple stub
        """Cursor whose queries do nothing and return no rows."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        async def execute(self, *args, **kwargs):
            return None

        async def fetchone(self):
            return None

        async def fetchall(self):
            return []

    class _DummyConnection:  # pragma: no cover - simple stub
        """Connection that hands out empty cursors."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def cursor(self, *args, **kwargs):
            return _DummyCursor()

        def close(self):
            return None

    class _DummyPool:  # pragma: no cover - simple stub
        """Pool that always hands out the same single connection."""

        def __init__(self):
            self._conn = _DummyConnection()

        def acquire(self):
            # Supports both `await pool.acquire()` and
            # `async with pool.acquire() as conn`.
            return _AwaitableContext(self._conn)

        def release(self, conn):
            return None

        def close(self):
            return None

        async def wait_closed(self):
            return None

    def _dummy_connect(*args, **kwargs):  # pragma: no cover - simple stub
        return _AwaitableContext(_DummyConnection())

    def _dummy_create_pool(*args, **kwargs):  # pragma: no cover - simple stub
        return _AwaitableContext(_DummyPool())

    fake_aiomysql.connect = _dummy_connect
    fake_aiomysql.create_pool = _dummy_create_pool
    sys.modules["aiomysql"] = fake_aiomysql

if "aiohttp" not in sys.modules:
    # Minimal aiohttp stand-in: only ClientSession, usable as an async
    # context manager. No HTTP traffic is ever produced.
    fake_aiohttp = types.ModuleType("aiohttp")

    class _DummyClientSession:  # pragma: no cover - simple stub
        def __init__(self, *args, **kwargs):
            # Accept and ignore constructor options (e.g. ``timeout=``) so that
            # configured sessions can be built, matching the other stubs here.
            pass

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        async def close(self):
            # Explicit close for sessions not used via ``async with``.
            return None

    fake_aiohttp.ClientSession = _DummyClientSession
    sys.modules["aiohttp"] = fake_aiohttp

if "paho.mqtt.client" not in sys.modules:
    # Stand-in for the paho-mqtt package tree (paho -> paho.mqtt ->
    # paho.mqtt.client) exposing only what the broker touches at import/setup.

    class _DummyCallbackAPIVersion:  # pragma: no cover - simple stub
        VERSION2 = object()

    class _DummyMQTTClient:  # pragma: no cover - simple stub
        def __init__(self, *args, **kwargs):
            # Keyword arguments given at construction are kept as-is.
            self._callbacks = kwargs

        # Every client operation below is a no-op returning None.
        def username_pw_set(self, *args, **kwargs):
            return None

        def connect(self, *args, **kwargs):
            return None

        def loop_start(self):
            return None

        def loop_stop(self):
            return None

    # Build the module hierarchy from the leaf upwards.
    fake_mqtt_client = types.ModuleType("paho.mqtt.client")
    fake_mqtt_client.CallbackAPIVersion = _DummyCallbackAPIVersion
    fake_mqtt_client.Client = _DummyMQTTClient

    fake_mqtt = types.ModuleType("paho.mqtt")
    fake_mqtt.client = fake_mqtt_client

    fake_paho = types.ModuleType("paho")
    fake_paho.mqtt = fake_mqtt

    sys.modules.update(
        {
            "paho": fake_paho,
            "paho.mqtt": fake_mqtt,
            "paho.mqtt.client": fake_mqtt_client,
        }
    )

if "requests" not in sys.modules:
    # Stand-in for `requests`: `post` always answers with an empty 200 OK and
    # `Session` is a bare placeholder type.

    class _DummyResponse:  # pragma: no cover - simple stub
        ok = True
        status_code = 200
        text = ""

    def _dummy_post(*args, **kwargs):  # pragma: no cover - simple stub
        return _DummyResponse()

    fake_requests = types.ModuleType("requests")
    fake_requests.__dict__.update(post=_dummy_post, Session=object)
    sys.modules["requests"] = fake_requests

import pipelet_ocpp_broker as broker


def test_read_request_optional_crlf_handles_carriage_return_only_headers():
    """A request head terminated by bare CR (no LF) still parses and is flagged."""
    request_lines = [
        b"GET /ocpp HTTP/1.1",
        b"Host: charger01.example",
        b"Upgrade: websocket",
        b"Connection: Upgrade",
        b"Sec-WebSocket-Key: abcdefghijklmnop",
        b"Sec-WebSocket-Version: 13",
    ]
    # Each line ends in a lone "\r"; one extra "\r" is the empty line ending the head.
    raw_request = b"\r".join(request_lines) + b"\r\r"

    async def _parse():
        # The StreamReader is created inside the running loop.
        stream = asyncio.StreamReader()
        stream.feed_data(raw_request)
        stream.feed_eof()
        return await broker.read_request_optional_crlf(stream)

    path, headers, metadata = asyncio.run(_parse())

    assert path == "/ocpp"
    assert headers["Host"] == "charger01.example"
    assert headers["Upgrade"].lower() == "websocket"
    assert metadata["missing_line_endings"] is True


def test_build_backend_ssl_context_disables_verification_by_default():
    """With verification off and no CA file, the backend context checks nothing.

    NOTE(review): the flag is forced to False here rather than read from the
    module's configured default, so this test does not itself prove that
    False is the default.
    """
    # MonkeyPatch.context() restores both module globals on exit, even on failure.
    with pytest.MonkeyPatch.context() as patcher:
        patcher.setattr(broker, "BACKEND_SSL_VERIFY", False)
        patcher.setattr(broker, "BACKEND_SSL_CA_FILE", None)

        ctx = broker._build_backend_ssl_context()

        assert isinstance(ctx, ssl.SSLContext)
        assert ctx.verify_mode == ssl.CERT_NONE
        assert ctx.check_hostname is False


def test_build_backend_ssl_context_with_verification_enabled():
    """With verification on (and no CA file), peer certificates are required."""
    # MonkeyPatch.context() restores both module globals on exit, even on failure.
    with pytest.MonkeyPatch.context() as patcher:
        patcher.setattr(broker, "BACKEND_SSL_VERIFY", True)
        patcher.setattr(broker, "BACKEND_SSL_CA_FILE", None)

        ctx = broker._build_backend_ssl_context()

        assert isinstance(ctx, ssl.SSLContext)
        assert ctx.verify_mode == ssl.CERT_REQUIRED


def test_backend_websocket_connection_uses_unverified_context_for_wss():
    """A ``wss://`` backend URL gets an unverified SSL context when verification is off.

    ``broker.websockets.connect`` is replaced by a recorder, so no network
    connection is made. The test inspects the ``ssl`` keyword argument the
    broker passed to it.
    """
    # One entry (the keyword arguments) per call to the patched connect().
    connect_calls = []

    # Named distinctly from the module-level aiomysql stubs
    # (_DummyConnection/_dummy_connect) to avoid confusing shadowing.
    class _FakeBackendConnection:
        """Async context manager standing in for a websockets client connection."""

        def __init__(self):
            # The transport reports no extra info for any key.
            self.transport = types.SimpleNamespace(
                get_extra_info=lambda *args, **kwargs: None
            )

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

    def _recording_connect(*args, **kwargs):
        connect_calls.append(kwargs)
        return _FakeBackendConnection()

    async def _open_backend():
        async with broker._backend_websocket_connection(
            "wss://example.com/ocpp",
            {},
            None,
            "station-1",
        ) as remote_ws:
            assert remote_ws is not None

    # MonkeyPatch.context() restores every patched attribute on exit, even on failure.
    with pytest.MonkeyPatch.context() as patcher:
        patcher.setattr(broker, "BACKEND_SSL_VERIFY", False)
        patcher.setattr(broker, "BACKEND_SSL_CA_FILE", None)
        patcher.setattr(broker.websockets, "connect", _recording_connect)

        asyncio.run(_open_backend())

    # Fail with a clear message if the broker never reached websockets.connect.
    assert connect_calls, "websockets.connect was never called"
    ssl_context = connect_calls[-1].get("ssl")
    assert isinstance(ssl_context, ssl.SSLContext)
    assert ssl_context.verify_mode == ssl.CERT_NONE
    assert ssl_context.check_hostname is False


def test_read_http_request_logs_masked_frame_before_handshake(monkeypatch, caplog):
    """A masked WebSocket frame sent before the HTTP upgrade is logged unmasked.

    The HTTP parser is forced to fail. The test then checks that
    ``read_http_request`` raises ``InvalidMessage`` and that a log record
    contains the unmasked payload text.
    """
    protocol_cls = broker.LoggingWebSocketServerProtocol
    # Bypass __init__ so no real transport or event loop is required.
    protocol = protocol_cls.__new__(protocol_cls)

    payload = b"HELLO"
    key = b"\x01\x02\x03\x04"
    # The payload is XOR-ed with the 4-byte key, repeating it as needed.
    masked = bytes(byte ^ key[idx % 4] for idx, byte in enumerate(payload))
    # 0x81 = FIN + text opcode; second byte = MASK bit | 7-bit payload length.
    header = bytes([0x81, 0x80 | len(payload)])
    frame = header + key + masked

    extra_info = {"peername": ("127.0.0.1", 443), "ssl_object": None}

    def _get_extra_info(name, default=None):
        # Known keys return their fixed value (even None); anything else -> default.
        return extra_info.get(name, default)

    protocol.logger = logging.getLogger("pipelet_ocpp_broker.test")
    protocol.reader = types.SimpleNamespace(_buffer=bytearray(frame))
    protocol.transport = types.SimpleNamespace(get_extra_info=_get_extra_info)

    async def _failing_parser(_reader):
        raise RuntimeError("boom")

    monkeypatch.setattr(broker, "read_request_optional_crlf", _failing_parser)
    caplog.set_level(logging.DEBUG)

    async def _drive():
        await protocol.read_http_request()

    with pytest.raises(broker.websockets.exceptions.InvalidMessage):
        asyncio.run(_drive())

    fragment = "masked WebSocket payload before completing the HTTP upgrade"
    logged = (record.getMessage() for record in caplog.records)
    assert any(fragment in text and "HELLO" in text for text in logged)
