import asyncio
import socket
import ssl
from pathlib import Path

import websockets

import test_websocket_handshake_tolerance as compat

# Alias the broker module exposed by the handshake-tolerance test module, so
# this file patches and drives the very same module object (``compat.broker``).
broker = compat.broker


def _get_free_port(host: str = "127.0.0.1") -> int:
    """Return a TCP port on *host* that was free at the moment of the call.

    The OS chooses the port (bind to port 0). The probe socket is closed
    before returning so the caller can bind the port itself. Another process
    may claim the port in between, so this is only suitable for tests.

    Args:
        host: Interface address to probe; defaults to IPv4 loopback.

    Returns:
        The port number the OS assigned.
    """
    # The context manager closes the probe socket even if bind() fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


def test_multi_proxy_listeners_accept_connections(monkeypatch):
    """Start one TLS and one plain proxy listener and connect to each of them.

    ``broker.handle_client`` is replaced with a stub that records the request
    path and closes the connection. The test therefore only exercises
    listener start-up, TLS termination and path delivery.
    ``broker.STATION_PATHS`` and ``broker.STATION_PATH_ALIASES`` are mutated
    in place for the duration of the test and restored afterwards.
    """
    handled_paths: list[str] = []

    async def _dummy_handler(websocket, path):
        # Record which path reached the handler, then end the session.
        handled_paths.append(path)
        await websocket.close()

    monkeypatch.setattr(broker, "handle_client", _dummy_handler)

    async def _run():
        secure_port = _get_free_port()
        plain_port = _get_free_port()
        # Two independent probes can, rarely, yield the same port; the second
        # listener would then fail to bind.
        while plain_port == secure_port:
            plain_port = _get_free_port()

        # NOTE(review): these paths are resolved against the current working
        # directory, so the test assumes it is run from the directory that
        # contains ``certs/``.
        certfile = str(Path("certs/walle_QA_signed.pem").resolve())
        keyfile = str(Path("certs/walle_private.key").resolve())

        secure_listener = broker.ProxyListenerSettings(
            name="secure-test",
            ip="127.0.0.1",
            port=secure_port,
            use_ssl=True,
            certfile=certfile,
            keyfile=keyfile,
            force_tls12=False,
            cert_identifiers=tuple(
                broker._collect_server_certificate_identifiers(certfile)
            ),
        )
        plain_listener = broker.ProxyListenerSettings(
            name="plain-test",
            ip="127.0.0.1",
            port=plain_port,
            use_ssl=False,
            certfile=None,
            keyfile=None,
            force_tls12=False,
            cert_identifiers=tuple(),
        )

        # The dicts are mutated in place rather than rebound, so any code
        # holding a reference to them sees the test entries.
        original_paths = broker.STATION_PATHS.copy()
        original_aliases = broker.STATION_PATH_ALIASES.copy()
        try:
            broker.STATION_PATHS.clear()
            broker.STATION_PATH_ALIASES.clear()
            broker.STATION_PATHS["/ocpp-secure/test"] = {"source_url": "/ocpp-secure/test"}
            broker.STATION_PATHS["/ocpp-plain/test"] = {"source_url": "/ocpp-plain/test"}

            servers = []
            configured = ["ocpp2.0.1", "ocpp1.6"]
            try:
                # Start the listeners inside the try, so an already-started
                # listener is still closed if a later one fails to start.
                for listener in (secure_listener, plain_listener):
                    servers.append(
                        await broker._start_proxy_listener(listener, configured)
                    )

                # The test certificate is not trusted by the default store,
                # so disable verification for the client side.
                tls_context = ssl.create_default_context()
                tls_context.check_hostname = False
                tls_context.verify_mode = ssl.CERT_NONE

                async with websockets.connect(
                    f"wss://127.0.0.1:{secure_port}/ocpp-secure/test",
                    ssl=tls_context,
                    subprotocols=["ocpp1.6"],
                ):
                    pass

                async with websockets.connect(
                    f"ws://127.0.0.1:{plain_port}/ocpp-plain/test",
                    subprotocols=["ocpp1.6"],
                ):
                    pass
            finally:
                for server in servers:
                    server.close()
                await asyncio.gather(*(server.wait_closed() for server in servers))
        finally:
            broker.STATION_PATHS.clear()
            broker.STATION_PATHS.update(original_paths)
            broker.STATION_PATH_ALIASES.clear()
            broker.STATION_PATH_ALIASES.update(original_aliases)

    asyncio.run(_run())

    # Connections are made sequentially, so the handler sees them in order.
    assert handled_paths == ["/ocpp-secure/test", "/ocpp-plain/test"]
