import asyncio
import contextlib
import json
import logging
import signal
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Set

import pymysql
import websockets
from websockets.client import WebSocketClientProtocol
import paho.mqtt.client as mqtt

# Path of the JSON config file; relative paths resolve against the working directory.
CONFIG_FILE = "config.json"
# Seconds refresh_loop sleeps between reloads of runtime config and routing rules.
REFRESH_INTERVAL_SECONDS = 30
# Capacity of the MQTT hand-off queue and of each twin's outgoing message queue.
MQTT_QUEUE_MAXSIZE = 1000

# Baseline logging setup; configure_logging_from_config() adjusts the level later.
logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def load_config_file() -> Dict[str, Any]:
    """Parse CONFIG_FILE as UTF-8 JSON and return the result.

    A missing file is not an error: a warning is logged and an empty dict is
    returned. Malformed JSON is not caught and propagates to the caller.
    """
    try:
        with open(CONFIG_FILE, encoding="utf-8") as handle:
            loaded = json.load(handle)
    except FileNotFoundError:
        logger.warning("Config file %s not found, using defaults", CONFIG_FILE)
        loaded = {}
    return loaded


# Parsed contents of CONFIG_FILE, loaded once at import time ({} if the file is missing).
CONFIG: Dict[str, Any] = load_config_file()


def configure_logging_from_config(config: Dict[str, Any]) -> None:
    """Set the process log level from *config*.

    The level is taken from ``config["log_levels"]["virtual_charger"]``, then
    ``config["virtual_charger_log_level"]``, defaulting to ``"INFO"``. Level
    names match case-insensitively, ignoring surrounding whitespace. Numeric
    levels (``10`` or ``"10"``) are used as-is. Any other value logs a warning
    and falls back to INFO.

    The level is applied to the root logger and to the ``websockets.client``
    and ``websockets.protocol`` loggers.
    """
    # ``or {}`` also covers an explicit ``"log_levels": null`` in the JSON.
    level_name = (
        (config.get("log_levels") or {}).get("virtual_charger")
        or config.get("virtual_charger_log_level")
        or "INFO"
    )
    level: Optional[int] = None
    if isinstance(level_name, int) and not isinstance(level_name, bool):
        level = level_name
    else:
        text = str(level_name).strip()
        if text.isdecimal():
            level = int(text)
        else:
            # Accept only integer constants: getattr(logging, "BASIC_FORMAT")
            # returns a str, which Logger.setLevel() rejects with ValueError.
            candidate = getattr(logging, text.upper(), None)
            if isinstance(candidate, int) and not isinstance(candidate, bool):
                level = candidate
    if level is None:
        logger.warning(
            "Unknown log level '%s' in config, falling back to INFO", level_name
        )
        level = logging.INFO
    logging.getLogger().setLevel(level)
    logging.getLogger("websockets.client").setLevel(level)
    logging.getLogger("websockets.protocol").setLevel(level)
    logger.debug("Logging configured with level %s", level)


# Apply the log level from CONFIG at import time, right after the config is loaded.
configure_logging_from_config(CONFIG)

# Sections of the JSON config; a missing section becomes an empty dict.
mysql_cfg: Dict[str, Any] = CONFIG.get("mysql", {})
ws_cfg: Dict[str, Any] = CONFIG.get("websocket", {})
# Keepalive settings passed to websockets.connect() for every twin.
# A missing or falsy value (None, 0, "") falls back to the default via ``or``.
# Note that the string "0" is truthy, so it is kept and becomes int 0.
DEFAULT_PING_INTERVAL = int(ws_cfg.get("ping_interval", 60) or 60)
DEFAULT_PING_TIMEOUT = int(ws_cfg.get("ping_timeout", 120) or 120)


@dataclass(frozen=True)
class MQTTConfig:
    """Immutable MQTT broker settings.

    Instances compare by value, so a changed configuration can be detected
    with ``==``. The default instance (empty broker, port 0) means
    "not configured".
    """

    broker: str = ""
    port: int = 0
    user: Optional[str] = None
    password: Optional[str] = None
    topic_prefix: str = ""

    @property
    def is_configured(self) -> bool:
        """True when both a broker host and a non-zero port are set."""
        return all((self.broker, self.port))


@dataclass
class ChargepointRule:
    """Aggregated forwarding rule for one chargepoint.

    ``enabled_messages`` holds the OCPP actions that are forwarded for
    frame type 2. ``all_messages`` forwards every frame regardless of its type
    or action.
    """

    chargepoint_id: str
    backend_url: str
    protocol: Optional[str]
    enabled_messages: Set[str] = field(default_factory=set)
    all_messages: bool = False

    def cleaned_backend_url(self) -> str:
        """Return the backend URL without surrounding whitespace ("" when unset)."""
        raw_url = self.backend_url if self.backend_url else ""
        return raw_url.strip()

    def subprotocols(self) -> Optional[list[str]]:
        """Map ``protocol`` to the WebSocket subprotocol list to offer.

        "OCPP 1.6" and "OCPP 2.0.1" use fixed names. Any other label is
        lower-cased with spaces removed (e.g. "OCPP 2.1" -> ["ocpp2.1"]).
        Returns ``None`` when no protocol is set. A fresh list is returned
        on every call.
        """
        label = self.protocol
        if not label:
            return None
        if label == "OCPP 1.6":
            return ["ocpp1.6", "ocpp1.6j"]
        if label == "OCPP 2.0.1":
            return ["ocpp2.0.1"]
        return [label.lower().replace(" ", "")]

    def should_forward(self, message_type: Optional[str], frame_type: Optional[int]) -> bool:
        """Decide whether a frame for this chargepoint goes to the backend.

        Nothing is forwarded without a backend URL. With ``all_messages``
        every frame is forwarded. Otherwise only frames of type 2 whose action
        is listed in ``enabled_messages`` are forwarded.
        """
        if not self.cleaned_backend_url():
            return False
        if self.all_messages:
            return True
        return frame_type == 2 and bool(message_type) and message_type in self.enabled_messages


@dataclass
class MQTTMessage:
    """One MQTT message passed from MQTTBridge._on_message to the asyncio consumer."""

    # Full MQTT topic the message arrived on.
    topic: str
    # Payload as text; MQTTBridge._on_message decodes it as UTF-8.
    payload: str


def normalize_chargepoint_id(raw: Any) -> Optional[str]:
    """Turn a raw chargepoint id into a clean string, or ``None``.

    Bytes are decoded as UTF-8, dropping invalid sequences. Any other value
    goes through ``str()``. Surrounding whitespace and embedded CR, LF and TAB
    characters are removed. Returns ``None`` for ``None`` input or when
    nothing is left.
    """
    if raw is None:
        return None
    decoded = raw.decode("utf-8", errors="ignore") if isinstance(raw, (bytes, bytearray)) else raw
    stripped = str(decoded).strip()
    if stripped == "":
        return None
    without_controls = stripped.translate(str.maketrans("", "", "\r\n\t"))
    return without_controls or None


def get_db_conn() -> Optional[pymysql.connections.Connection]:
    """Open a new autocommit MySQL connection from the ``mysql`` config section.

    Rows are returned as dicts (DictCursor). Returns ``None`` after logging
    when no host is configured or the connection attempt raises. The caller
    must close the returned connection.
    """
    host = mysql_cfg.get("host")
    if not host:
        logger.error("MySQL configuration missing in config.json")
        return None
    try:
        connect_kwargs = {
            "host": host,
            "user": mysql_cfg.get("user"),
            "password": mysql_cfg.get("password"),
            "db": mysql_cfg.get("db"),
            "charset": mysql_cfg.get("charset", "utf8mb4"),
            "autocommit": True,
            "cursorclass": pymysql.cursors.DictCursor,
        }
        return pymysql.connect(**connect_kwargs)
    except Exception:
        logger.exception("Failed to connect to MySQL")
        return None


def load_runtime_config() -> Dict[str, str]:
    """Read every ``op_config`` row into a ``{config_key: config_value}`` dict.

    Returns ``{}`` when no database connection could be opened. Query errors
    propagate. The connection is closed in every case.
    """
    conn = get_db_conn()
    if conn is None:
        return {}
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT config_key, config_value FROM op_config")
            settings: Dict[str, str] = {}
            for record in cur.fetchall():
                settings[record["config_key"]] = record["config_value"]
            return settings
    finally:
        conn.close()


def extract_mqtt_config(runtime_cfg: Dict[str, str]) -> MQTTConfig:
    """Build an :class:`MQTTConfig` from the ``mqtt_*`` keys of *runtime_cfg*.

    String values are stripped. Empty user/password become ``None``. A
    missing or non-numeric ``mqtt_port`` yields port 0, which makes the
    config "not configured".
    """

    def _clean(key: str) -> str:
        return (runtime_cfg.get(key) or "").strip()

    raw_port = runtime_cfg.get("mqtt_port", "") or 0
    try:
        port = int(raw_port)
    except ValueError:
        port = 0
    return MQTTConfig(
        broker=_clean("mqtt_broker"),
        port=port,
        user=_clean("mqtt_user") or None,
        password=_clean("mqtt_password") or None,
        topic_prefix=_clean("mqtt_topic_prefix"),
    )


def _is_rule_enabled(value: Any) -> bool:
    """Interpret the ``enabled`` column of an ``op_ocpp_routing_rules`` row.

    Accepts the following and never raises:
    - ints and bools;
    - numeric strings ("0", "1");
    - the textual booleans true/false, yes/no, on/off (any case);
    - bytes, as pymysql returns for BIT columns, where a zero byte means
      disabled.

    Any other string is logged and treated as disabled, so one bad row
    cannot abort the whole rule load.
    """
    if value is None:
        return False
    if isinstance(value, (bytes, bytearray)):
        # bool(b"\x00") is True, so compare the numeric value instead.
        return int.from_bytes(value, "big") != 0
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "yes", "on"):
            return True
        if text in ("", "false", "no", "off"):
            return False
        try:
            return int(text) != 0
        except ValueError:
            logger.warning(
                "Unrecognised routing rule 'enabled' value %r; treating as disabled", value
            )
            return False
    return bool(value)


def load_routing_rules() -> Dict[str, ChargepointRule]:
    """Load forwarding rules from ``op_ocpp_routing_rules``, one per chargepoint.

    Rows are grouped by normalized chargepoint id. The first non-empty
    ``ocpp_backend_url`` / ``ocpp_protocol`` seen for a chargepoint wins.
    NOTE(review): the query orders only by ``chargepoint_id``, so which row
    is "first" is undefined when rows for one chargepoint disagree.

    Enabled rows contribute their ``message_type``; the value ``"__all__"``
    forwards every message. Chargepoints without a backend URL or without any
    enabled message are omitted.

    Returns ``{}`` when no database connection could be opened. That is
    indistinguishable from "no rules configured". Query errors propagate.
    """
    conn = get_db_conn()
    if conn is None:
        return {}
    try:
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT chargepoint_id, message_type, enabled, ocpp_backend_url, ocpp_protocol
                FROM op_ocpp_routing_rules
                ORDER BY chargepoint_id
                """
            )
            rows = cur.fetchall()
    finally:
        conn.close()

    grouped: Dict[str, Dict[str, Any]] = {}
    for row in rows:
        cp_id = normalize_chargepoint_id(row.get("chargepoint_id"))
        if not cp_id:
            continue
        entry = grouped.setdefault(
            cp_id,
            {
                "backend_url": "",
                "protocol": None,
                "enabled_messages": set(),
                "all_messages": False,
            },
        )
        # Backend/protocol are taken from any row, enabled or not.
        backend_url = (row.get("ocpp_backend_url") or "").strip()
        if backend_url and not entry["backend_url"]:
            entry["backend_url"] = backend_url
        protocol = (row.get("ocpp_protocol") or "").strip() or None
        if protocol and not entry["protocol"]:
            entry["protocol"] = protocol
        if not _is_rule_enabled(row.get("enabled")):
            continue
        message_type = row.get("message_type")
        if message_type == "__all__":
            entry["all_messages"] = True
        elif message_type:
            entry["enabled_messages"].add(str(message_type))

    rules: Dict[str, ChargepointRule] = {}
    for cp_id, data in grouped.items():
        rule = ChargepointRule(
            chargepoint_id=cp_id,
            backend_url=data["backend_url"],
            protocol=data["protocol"],
            enabled_messages=set(data["enabled_messages"]),
            all_messages=bool(data["all_messages"]),
        )
        if rule.cleaned_backend_url() and (rule.all_messages or rule.enabled_messages):
            rules[cp_id] = rule
    return rules


def parse_ocpp_payload(payload: str) -> tuple[Optional[int], Optional[str]]:
    """Extract ``(message_type_id, action)`` from an OCPP-J frame.

    ``action`` is only set for frame type 2 (CALL) when the third element is
    a string. Returns ``(None, None)`` when the payload is not a JSON array of
    at least three elements whose first element is an integer. Never raises
    for malformed input.
    """
    try:
        data = json.loads(payload)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError and also the int-digit limit
        # error that huge number literals trigger on Python 3.11+.
        # RecursionError comes from deeply nested input such as "[[[[...".
        # Neither may escape: the caller's consumer loop serves every
        # chargepoint.
        logger.debug("Failed to decode OCPP payload as JSON")
        return None, None
    if not isinstance(data, list) or len(data) < 3:
        return None, None
    frame_type = data[0]
    # bool is a subclass of int, but JSON true/false is not a message type id.
    if isinstance(frame_type, bool) or not isinstance(frame_type, int):
        return None, None
    action: Optional[str] = None
    if frame_type == 2 and isinstance(data[2], str):
        action = data[2]
    return frame_type, action


def parse_mqtt_topic(topic: str, prefix: str) -> tuple[Optional[str], Optional[str]]:
    """Split ``[<prefix>/]wallbox/<station>/<direction>`` into its parts.

    *prefix* may be given with or without a trailing slash. ``<station>`` is
    everything between ``wallbox/`` and the last slash, normalized with
    normalize_chargepoint_id. Returns ``(None, None)`` when the topic does not
    have this shape.
    """
    remainder = topic
    if prefix:
        head = prefix.rstrip("/") + "/"
        if not remainder.startswith(head):
            return None, None
        remainder = remainder[len(head) :]
    marker = "wallbox/"
    if not remainder.startswith(marker):
        return None, None
    station_part, separator, direction = remainder[len(marker) :].rpartition("/")
    if not separator:
        return None, None
    return normalize_chargepoint_id(station_part), direction.strip()


class ChargepointTwin:
    """Forward one chargepoint's OCPP frames to its configured backend over WebSocket.

    Payloads passed to :meth:`enqueue` are buffered in a bounded queue. A
    background task (:meth:`_run`) sends them and keeps (re)connecting to the
    rule's backend URL. Frames received from the backend are only logged at
    DEBUG level.
    """

    def __init__(self, rule: ChargepointRule):
        self.rule = rule
        # Outgoing payloads; ``None`` is the shutdown sentinel for the sender.
        self.queue: asyncio.Queue[Optional[str]] = asyncio.Queue(maxsize=MQTT_QUEUE_MAXSIZE)
        self._task: Optional[asyncio.Task[None]] = None
        self._stop_event = asyncio.Event()
        # Current backend connection, or None while disconnected.
        self.ws: Optional[WebSocketClientProtocol] = None
        # Payload already taken off the queue whose send has not completed.
        # It is retried first on the next connection, so a failed send neither
        # loses the message nor moves it behind newer ones. Delivery is
        # therefore at-least-once: a send interrupted after reaching the socket
        # may be repeated.
        self._pending: Optional[str] = None

    async def start(self) -> None:
        """Launch the connection task unless one is already running."""
        if self._task is None or self._task.done():
            self._stop_event.clear()
            self._task = asyncio.create_task(self._run(), name=f"twin-{self.rule.chargepoint_id}")
            logger.info("Twin for %s started", self.rule.chargepoint_id)

    async def stop(self) -> None:
        """Request shutdown, close the socket and wait for the connection task to end."""
        self._stop_event.set()
        try:
            # Wake a sender that is blocked in queue.get().
            self.queue.put_nowait(None)
        except asyncio.QueueFull:
            logger.warning(
                "Message queue full while stopping twin %s; dropping pending messages",
                self.rule.chargepoint_id,
            )
            while not self.queue.empty():
                try:
                    self.queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
            try:
                self.queue.put_nowait(None)
            except asyncio.QueueFull:
                logger.error(
                    "Unable to signal sender shutdown for %s due to queue overflow",
                    self.rule.chargepoint_id,
                )
        if self.ws is not None:
            await self.ws.close()
        if self._task:
            with contextlib.suppress(asyncio.CancelledError):
                await self._task
        self._task = None
        logger.info("Twin for %s stopped", self.rule.chargepoint_id)

    async def update_rule(self, rule: ChargepointRule) -> None:
        """Swap in *rule*; close the socket if the backend URL or subprotocols changed.

        Closing the socket ends the current connection, and :meth:`_run` then
        reconnects using the new rule.
        """
        reconnect_required = (
            rule.cleaned_backend_url() != self.rule.cleaned_backend_url()
            or rule.subprotocols() != self.rule.subprotocols()
        )
        self.rule = rule
        if reconnect_required and self.ws is not None:
            logger.info(
                "Backend configuration for %s changed, reconnecting", self.rule.chargepoint_id
            )
            await self.ws.close()

    async def enqueue(self, payload: str) -> None:
        """Buffer *payload* for delivery without ever blocking the caller.

        A single MQTT consumer feeds every twin. Awaiting a full queue here
        would stall forwarding for all chargepoints while this one backend is
        unreachable. Instead, when the buffer is full the new message is
        dropped and a warning is logged.
        """
        try:
            self.queue.put_nowait(payload)
        except asyncio.QueueFull:
            logger.warning(
                "Message queue full for %s; dropping message", self.rule.chargepoint_id
            )

    async def _run(self) -> None:
        """Connect, serve and reconnect until :meth:`stop` is called.

        Waits 15 s while no backend URL is configured and 10 s after a
        connection error. Both waits end early once stop is requested.
        """
        while not self._stop_event.is_set():
            backend_url = self.rule.cleaned_backend_url()
            if not backend_url:
                logger.warning(
                    "No backend URL configured for %s, retrying in 15 seconds",
                    self.rule.chargepoint_id,
                )
                try:
                    await asyncio.wait_for(self._stop_event.wait(), timeout=15)
                except asyncio.TimeoutError:
                    continue
                else:
                    break
            try:
                subprotocols = self.rule.subprotocols()
                logger.info(
                    "Connecting twin %s to %s", self.rule.chargepoint_id, backend_url
                )
                async with websockets.connect(
                    backend_url,
                    subprotocols=subprotocols,
                    ping_interval=DEFAULT_PING_INTERVAL,
                    ping_timeout=DEFAULT_PING_TIMEOUT,
                ) as ws:
                    self.ws = ws
                    await self._handle_connection(ws)
            except Exception:
                logger.exception(
                    "Connection error for %s, retrying in 10 seconds",
                    self.rule.chargepoint_id,
                )
                try:
                    await asyncio.wait_for(self._stop_event.wait(), timeout=10)
                except asyncio.TimeoutError:
                    continue
                else:
                    break
            finally:
                self.ws = None

    async def _handle_connection(self, ws: WebSocketClientProtocol) -> None:
        """Run sender and receiver on *ws* until one ends or stop is requested.

        The remaining tasks are cancelled. An exception raised by the sender
        propagates to the caller.
        """
        logger.info("Twin %s connected", self.rule.chargepoint_id)
        sender = asyncio.create_task(self._sender(ws))
        receiver = asyncio.create_task(self._receiver(ws))
        stop_waiter = asyncio.create_task(self._stop_event.wait())
        tasks = {sender, receiver, stop_waiter}
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        if stop_waiter in done:
            logger.info("Stop requested for twin %s", self.rule.chargepoint_id)
        for task in pending:
            task.cancel()
        for task in tasks:
            with contextlib.suppress(asyncio.CancelledError):
                await task
        logger.info("Twin %s disconnected", self.rule.chargepoint_id)

    async def _sender(self, ws: WebSocketClientProtocol) -> None:
        """Send queued payloads over *ws*, oldest first.

        Returns on the ``None`` sentinel. A send failure is re-raised so the
        connection is torn down; the failed payload stays in ``_pending`` and
        is sent first after reconnecting.
        """
        while not self._stop_event.is_set():
            payload = self._pending
            if payload is None:
                payload = await self.queue.get()
                if payload is None:
                    return
                # Track the payload until the send completes, so a failure or
                # a cancellation mid-send does not lose it.
                self._pending = payload
            try:
                await ws.send(payload)
            except Exception:
                logger.exception(
                    "Failed to send message for %s, will retry after reconnect",
                    self.rule.chargepoint_id,
                )
                raise
            self._pending = None
            logger.debug(
                "Forwarded MQTT message to backend for %s: %s",
                self.rule.chargepoint_id,
                payload,
            )

    async def _receiver(self, ws: WebSocketClientProtocol) -> None:
        """Log every frame received from the backend until the socket closes."""
        try:
            async for message in ws:
                logger.debug(
                    "Received backend message for %s: %s",
                    self.rule.chargepoint_id,
                    message,
                )
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("Backend listener failed for %s", self.rule.chargepoint_id)


class RoutingManager:
    """Keep one :class:`ChargepointTwin` running per chargepoint that has a rule."""

    def __init__(self) -> None:
        self.twins: Dict[str, ChargepointTwin] = {}

    async def apply_rules(self, rules: Dict[str, ChargepointRule]) -> None:
        """Reconcile running twins with *rules*.

        Existing twins get the new rule, new chargepoints get a started twin,
        and twins whose chargepoint is no longer in *rules* are stopped.
        """
        wanted = set(rules.keys())
        # Snapshot before starting new twins so they are never treated as stale.
        obsolete = set(self.twins.keys()) - wanted

        for cp_id in wanted:
            rule = rules[cp_id]
            current = self.twins.get(cp_id)
            if current is not None:
                await current.update_rule(rule)
                continue
            fresh = ChargepointTwin(rule)
            self.twins[cp_id] = fresh
            await fresh.start()

        for cp_id in obsolete:
            retired = self.twins.pop(cp_id, None)
            if retired is not None:
                await retired.stop()

    async def forward(self, cp_id: str, frame_type: Optional[int], action: Optional[str], payload: str) -> bool:
        """Enqueue *payload* on *cp_id*'s twin if its rule accepts the frame.

        Returns True when the payload was handed to the twin, False when
        there is no twin or its rule rejects the frame.
        """
        twin = self.twins.get(cp_id)
        if twin is None or not twin.rule.should_forward(action, frame_type):
            return False
        await twin.enqueue(payload)
        return True

    async def stop_all(self) -> None:
        """Stop every twin, removing each one after it has stopped."""
        snapshot = list(self.twins.items())
        for cp_id, twin in snapshot:
            await twin.stop()
            self.twins.pop(cp_id, None)


class MQTTBridge:
    """Run a paho-mqtt client and hand received messages to an asyncio queue.

    paho's network loop runs in its own thread (``loop_start``), so the
    ``_on_*`` callbacks execute off the event loop. Messages are therefore
    passed to the loop thread with ``call_soon_threadsafe``.
    """

    def __init__(self, queue: "asyncio.Queue[MQTTMessage]") -> None:
        self.queue = queue
        self.client: Optional[mqtt.Client] = None
        # Active configuration; ensure_running() compares against it to detect changes.
        self.config = MQTTConfig()
        self.loop: Optional[asyncio.AbstractEventLoop] = None

    async def ensure_running(self, config: MQTTConfig) -> None:
        """Make the client match *config*, restarting it when the config changed.

        An unconfigured *config* (e.g. ``MQTTConfig()``) stops the client.
        """
        if config == self.config:
            return
        if self.client:
            await self._stop_client()
        self.config = config
        if config.is_configured:
            await self._start_client()
        else:
            logger.warning("MQTT configuration incomplete; waiting for valid settings")

    async def _start_client(self) -> None:
        """Create, connect and start a client for ``self.config``.

        If the connection fails, the remembered config is reset. The next
        ensure_running() call with the same settings then retries instead of
        returning early as "unchanged".
        """
        assert self.config.is_configured
        self.loop = asyncio.get_running_loop()
        client = mqtt.Client(callback_api_version=mqtt.CallbackAPIVersion.VERSION2)
        if self.config.user:
            client.username_pw_set(self.config.user, self.config.password)
        client.on_connect = self._on_connect
        client.on_message = self._on_message
        self.client = client
        try:
            # connect() does blocking DNS/TCP work; keep it off the event loop.
            await asyncio.to_thread(client.connect, self.config.broker, self.config.port)
        except Exception:
            logger.exception(
                "Failed to connect to MQTT broker %s:%s",
                self.config.broker,
                self.config.port,
            )
            self.client = None
            self.config = MQTTConfig()
            return
        logger.info(
            "Connected MQTT client to %s:%s (topic prefix '%s')",
            self.config.broker,
            self.config.port,
            self.config.topic_prefix,
        )
        client.loop_start()

    async def _stop_client(self) -> None:
        """Disconnect and stop the network thread; the client is always forgotten."""
        if not self.client:
            return
        logger.info("Stopping MQTT client")
        try:
            # Disconnect while the network thread still runs so the DISCONNECT
            # packet is actually sent; loop_stop() then joins that thread.
            self.client.disconnect()
            self.client.loop_stop()
        except Exception:
            logger.exception("Error while stopping MQTT client")
        finally:
            self.client = None

    def _topic_filter(self) -> str:
        """Return the subscription filter ``[<prefix>/]wallbox/+/client_to_server``."""
        prefix = self.config.topic_prefix.rstrip("/") if self.config.topic_prefix else ""
        base = "wallbox/+/client_to_server"
        if prefix:
            return f"{prefix}/{base}"
        return base

    def _on_connect(
        self,
        client: mqtt.Client,
        _userdata: Any,
        flags: Any,
        rc: Any,
        properties: Any = None,
    ) -> None:
        """Subscribe after every successful (re)connect.

        With ``CallbackAPIVersion.VERSION2``, paho passes five arguments
        (client, userdata, flags, reason_code, properties), and ``rc`` is a
        ``ReasonCode`` that exposes ``is_failure``. Plain integer codes are
        still handled.
        """
        failed = getattr(rc, "is_failure", None)
        if failed is None:
            failed = rc != 0
        if failed:
            logger.error("MQTT connection failed with code %s", rc)
            return
        topic = self._topic_filter()
        logger.info("Subscribing to MQTT topic %s", topic)
        client.subscribe(topic)

    def _on_message(self, _client: mqtt.Client, _userdata: Any, msg: mqtt.MQTTMessage) -> None:
        """Decode *msg* and schedule it onto the asyncio queue (runs on paho's thread)."""
        loop = self.loop
        if not loop:
            return
        try:
            payload = msg.payload.decode("utf-8")
        except UnicodeDecodeError:
            payload = msg.payload.decode("utf-8", errors="ignore")
        message = MQTTMessage(topic=msg.topic, payload=payload)
        try:
            loop.call_soon_threadsafe(self._enqueue_on_loop, message)
        except RuntimeError:
            # The event loop is already closed (shutdown); nothing will consume it.
            logger.debug("Event loop closed; dropping MQTT message from topic %s", msg.topic)

    def _enqueue_on_loop(self, message: MQTTMessage) -> None:
        """Put *message* on the queue; runs on the event loop thread.

        QueueFull must be handled here: call_soon_threadsafe only schedules
        the call, so the exception would never reach _on_message.
        """
        try:
            self.queue.put_nowait(message)
        except asyncio.QueueFull:
            logger.error("MQTT queue is full; dropping message from topic %s", message.topic)


async def _route_mqtt_message(
    message: MQTTMessage,
    manager: RoutingManager,
    config_supplier: Callable[[], MQTTConfig],
) -> None:
    """Parse one MQTT message and hand it to the matching twin, if any."""
    cfg: MQTTConfig = config_supplier()
    prefix = cfg.topic_prefix if cfg else ""
    station_id, direction = parse_mqtt_topic(message.topic, prefix)
    if not station_id or direction != "client_to_server":
        logger.debug("Ignoring MQTT message on topic %s", message.topic)
        return
    frame_type, action = parse_ocpp_payload(message.payload)
    if frame_type is None:
        logger.debug(
            "Skipping non-OCPP payload received for station %s on topic %s",
            station_id or "<unknown>",
            message.topic,
        )
        return
    forwarded = await manager.forward(station_id, frame_type, action, message.payload)
    if forwarded:
        logger.debug(
            "Forwarded OCPP %s for %s to backend", action or frame_type, station_id
        )
    else:
        logger.debug(
            "No forwarding rule matched for %s (%s)", station_id, action or frame_type
        )


async def mqtt_consumer(
    queue: "asyncio.Queue[MQTTMessage]",
    manager: RoutingManager,
    config_supplier: Callable[[], MQTTConfig],
) -> None:
    """Route MQTT messages from *queue* to chargepoint twins, forever.

    Only ``[<prefix>/]wallbox/<station>/client_to_server`` topics whose
    payload parses as an OCPP frame are considered. ``manager.forward`` then
    applies the per-chargepoint rule. *config_supplier* is called for every
    message, so prefix changes take effect immediately.

    This single task serves every chargepoint. An unexpected error while
    handling one message is therefore logged and skipped rather than ending
    the loop. Cancellation still stops it.
    """
    while True:
        message = await queue.get()
        try:
            await _route_mqtt_message(message, manager, config_supplier)
        except Exception:
            logger.exception("Failed to process MQTT message on topic %s", message.topic)


async def refresh_loop(manager: RoutingManager, mqtt_bridge: MQTTBridge) -> None:
    """Reload MQTT settings and routing rules from MySQL every REFRESH_INTERVAL_SECONDS.

    pymysql is synchronous, so the database reads run in a worker thread via
    ``asyncio.to_thread``. Run inline, they would block the event loop, and
    with it every twin and the MQTT consumer, for the whole connect + query
    (or for the full connect timeout while MySQL is unreachable).

    MQTT settings are only applied when ``op_config`` returned rows. Any
    error is logged and the loop continues.
    """
    while True:
        try:
            runtime_cfg = await asyncio.to_thread(load_runtime_config)
            if runtime_cfg:
                mqtt_cfg = extract_mqtt_config(runtime_cfg)
                await mqtt_bridge.ensure_running(mqtt_cfg)
            rules = await asyncio.to_thread(load_routing_rules)
            await manager.apply_rules(rules)
        except Exception:
            logger.exception("Failed to refresh routing configuration")
        await asyncio.sleep(REFRESH_INTERVAL_SECONDS)


async def main() -> None:
    """Start twins, the MQTT bridge and the background tasks; run until SIGINT/SIGTERM.

    On shutdown:
    1. cancel the MQTT consumer and the config refresher, and wait for both;
    2. stop every twin;
    3. stop the MQTT client.
    Cleanup also runs if main() itself is cancelled.
    """
    runtime_cfg = load_runtime_config()
    mqtt_cfg = extract_mqtt_config(runtime_cfg) if runtime_cfg else MQTTConfig()

    manager = RoutingManager()
    await manager.apply_rules(load_routing_rules())

    mqtt_queue: asyncio.Queue[MQTTMessage] = asyncio.Queue(maxsize=MQTT_QUEUE_MAXSIZE)
    mqtt_bridge = MQTTBridge(mqtt_queue)
    await mqtt_bridge.ensure_running(mqtt_cfg)

    consumer_task = asyncio.create_task(
        mqtt_consumer(mqtt_queue, manager, lambda: mqtt_bridge.config),
        name="mqtt-consumer",
    )
    refresher_task = asyncio.create_task(
        refresh_loop(manager, mqtt_bridge), name="config-refresher"
    )
    background = (consumer_task, refresher_task)

    stop_event = asyncio.Event()

    def _signal_handler() -> None:
        logger.info("Termination requested")
        stop_event.set()

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _signal_handler)
        except NotImplementedError:
            # Not supported by every event loop (e.g. on Windows). Ctrl+C is
            # then left to asyncio.run()'s default KeyboardInterrupt handling.
            logger.debug("Signal handler for %s not supported on this platform", sig)

    try:
        await stop_event.wait()
    finally:
        for task in background:
            task.cancel()
        # gather() waits for *both* tasks even when one of them raises.
        # CancelledError is a BaseException, so only real failures are logged.
        results = await asyncio.gather(*background, return_exceptions=True)
        for task, result in zip(background, results):
            if isinstance(result, Exception):
                logger.error("Background task %s failed", task.get_name(), exc_info=result)
        await manager.stop_all()
        await mqtt_bridge.ensure_running(MQTTConfig())


if __name__ == "__main__":
    # Run the service; main() returns after SIGINT/SIGTERM and its shutdown sequence.
    asyncio.run(main())
