"""Simple webhook receiver that logs all incoming HTTP requests.

Run with:
    python hubject_webhook_listener.py

The server listens on port 9500 by default and prints the method, path, headers,
query parameters, and body for every request. If ``hubject_pnc.callback_server_port``
is set in ``config.json`` that value will be used instead of the default.
"""

from __future__ import annotations

import json
import logging
import sys
import threading
from datetime import datetime, timezone
from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer
from pathlib import Path
from typing import Any, Dict
from urllib.parse import parse_qs, urlparse

import pymysql
import requests


HOST = "0.0.0.0"
CONFIG_FILE = "config.json"
TOKEN_REFRESH_INTERVAL = 24 * 60 * 60  # 24 hours, in seconds

# Optional configuration file, resolved relative to the current working
# directory. A missing file means "use the defaults below"; a file that exists
# but is not valid JSON still fails loudly at import time.
try:
    _config = json.loads(Path(CONFIG_FILE).read_text(encoding="utf-8"))
except FileNotFoundError:
    _config = {}

# MySQL connection settings; keys from the "mysql" config section override these
# defaults. ``or {}`` also covers an explicit ``"mysql": null`` entry, for which
# ``dict.get`` would return None and ``update(None)`` would raise TypeError.
_mysql_cfg: Dict[str, Any] = {
    "host": "localhost",
    "user": "root",
    "password": "",
    "db": "op",
    "charset": "utf8mb4",
}
_mysql_cfg.update(_config.get("mysql") or {})

# Hubject Plug&Charge settings: OAuth client credentials, token URL, audience and
# the port this callback server listens on. ``or {}`` guards against a null
# section, which would otherwise break the ``.get`` call on the next line.
HUBJECT_PNC_CFG: Dict[str, Any] = _config.get("hubject_pnc") or {}
PORT = int(HUBJECT_PNC_CFG.get("callback_server_port", 9500))

# DDL for the table holding fetched Hubject OAuth tokens. In the code visible
# here rows are only ever inserted (one per successful token fetch), never
# updated or deleted.
# NOTE(review): the INSERT supplies ``ts`` explicitly as a naive UTC datetime,
# while the column DEFAULT is the server's CURRENT_TIMESTAMP; MySQL interprets
# TIMESTAMP input in the session time zone — confirm the session runs in UTC.
HUBJECT_ACCESS_TOKEN_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS hubject_access_token (
    access_token TEXT NOT NULL,
    expires_in INT,
    token_type VARCHAR(50),
    scope TEXT,
    ts TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) CHARACTER SET utf8mb4
"""

# Process-wide flag set by ensure_token_table() once the CREATE TABLE statement
# has executed without error, so the DDL is not re-sent on every token refresh.
_table_ready = False


def _db_connection():
    """Open a fresh autocommit MySQL connection whose cursors return dict rows.

    Connection settings are taken from ``_mysql_cfg``. The connection is not
    closed here; callers are responsible for closing it.
    """
    dict_cursor = pymysql.cursors.DictCursor
    return pymysql.connect(autocommit=True, cursorclass=dict_cursor, **_mysql_cfg)


def ensure_token_table(conn: pymysql.connections.Connection) -> None:
    """Create the ``hubject_access_token`` table the first time it is needed.

    After the CREATE TABLE statement has run without raising, a module-level
    flag is set and later calls return without touching the database. If the
    statement raises, the flag stays unset and the next call tries again.
    """
    global _table_ready
    if not _table_ready:
        with conn.cursor() as cursor:
            cursor.execute(HUBJECT_ACCESS_TOKEN_TABLE_SQL)
        _table_ready = True


def _request_access_token() -> Dict[str, Any] | None:
    """Request a Hubject OAuth access token with the client-credentials grant.

    Reads ``client_id``, ``client_secret``, ``TOKEN_URL`` and ``AUDIENCE`` from
    ``HUBJECT_PNC_CFG`` and POSTs them as JSON to ``<TOKEN_URL>/oauth/token``.

    Returns:
        The decoded JSON response if it is an object with a non-empty
        ``access_token``; otherwise ``None`` after logging the reason.
    """
    client_id = HUBJECT_PNC_CFG.get("client_id")
    client_secret = HUBJECT_PNC_CFG.get("client_secret")
    token_url = HUBJECT_PNC_CFG.get("TOKEN_URL")
    audience = HUBJECT_PNC_CFG.get("AUDIENCE")

    if not all([client_id, client_secret, token_url, audience]):
        logging.error("Hubject PnC configuration incomplete; cannot request token")
        return None

    url = f"{token_url.rstrip('/')}/oauth/token"
    payload = {
        "client_id": client_id,
        "client_secret": client_secret,
        "audience": audience,
        "grant_type": "client_credentials",
    }

    try:
        response = requests.post(url, json=payload, timeout=30)
        response.raise_for_status()
        data = response.json()
    except Exception as exc:  # noqa: BLE001 - broad to log any failure
        logging.error("Failed to obtain Hubject access token: %s", exc)
        return None

    # Validate before using the payload: a JSON array/scalar would make the
    # .get() calls raise AttributeError out of this module's token path, and a
    # missing token would be written as NULL into the NOT NULL access_token column.
    if not isinstance(data, dict) or not data.get("access_token"):
        logging.error("Hubject token response did not contain an access_token")
        return None
    return data


def _store_access_token(data: Dict[str, Any]) -> None:
    """Insert the token response *data* as a new ``hubject_access_token`` row.

    ``pymysql.MySQLError``s (connect, DDL or INSERT) are logged rather than
    raised, and an opened connection is always closed.
    """
    try:
        conn = _db_connection()
    except pymysql.MySQLError as exc:
        logging.error("Failed to connect to MySQL: %s", exc)
        return

    try:
        ensure_token_table(conn)
        with conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO hubject_access_token (access_token, expires_in, token_type, scope, ts)
                VALUES (%s, %s, %s, %s, %s)
                """,
                (
                    data.get("access_token"),
                    data.get("expires_in"),
                    data.get("token_type"),
                    data.get("scope"),
                    # Naive UTC, the same value datetime.utcnow() returned;
                    # utcnow() is deprecated since Python 3.12.
                    datetime.now(timezone.utc).replace(tzinfo=None),
                ),
            )
    except pymysql.MySQLError as exc:
        logging.error("Failed to store Hubject access token: %s", exc)
    finally:
        conn.close()


def fetch_and_store_access_token() -> None:
    """Fetch a fresh Hubject access token and persist it to MySQL.

    Expected failures (incomplete configuration, HTTP/JSON errors, a response
    without ``access_token``, MySQL errors) are logged and the function returns
    normally, so it can be called from the periodic refresher.
    """
    data = _request_access_token()
    if data is None:
        return
    logging.info("Obtained Hubject access token expiring in %s seconds", data.get("expires_in"))
    _store_access_token(data)


def _refresh_loop(stop_event: threading.Event) -> None:
    """Refresh the Hubject token every ``TOKEN_REFRESH_INTERVAL`` seconds.

    ``stop_event.wait`` returns True as soon as the event is set, which ends the
    loop; when it times out instead, one refresh is performed.

    NOTE(review): the interval is fixed and independent of the token's
    ``expires_in`` — confirm tokens never expire sooner than the interval.
    """
    while not stop_event.wait(TOKEN_REFRESH_INTERVAL):
        try:
            fetch_and_store_access_token()
        except Exception:  # noqa: BLE001 - keep the refresher thread alive
            # An uncaught exception would end this thread and silently stop all
            # future refreshes; log it and retry on the next cycle instead.
            logging.exception("Unexpected error while refreshing Hubject access token")


class LoggingRequestHandler(BaseHTTPRequestHandler):
    """Log every incoming HTTP request and reply with a short plain-text body.

    GET answers with a liveness message, POST and PUT echo the request body
    back (``OK`` when it is empty) and DELETE answers ``OK``. Requests whose
    ``Content-Length`` header is malformed or negative get a 400, and requests
    announcing more than ``max_body_size`` bytes get a 413; in both cases the
    body is not read.
    """

    # Upper bound on the request body read into memory. Content-Length is
    # client-controlled and the server listens on all interfaces, so reading it
    # unchecked would let a single request allocate arbitrary amounts of memory.
    max_body_size = 10 * 1024 * 1024  # 10 MiB

    def _read_body(self) -> bytes | None:
        """Return the request body, or ``None`` after an error response was sent.

        A missing ``Content-Length`` header means an empty body.
        """
        raw_length = self.headers.get("Content-Length")
        if raw_length is None:
            return b""
        try:
            length = int(raw_length)
        except ValueError:
            length = -1  # handled exactly like a negative length below
        if length < 0:
            # rfile.read(-1) would block until the client closes the connection.
            self._reject(400, "Invalid Content-Length header")
            return None
        if length > self.max_body_size:
            self._reject(413, "Request body too large")
            return None
        return self.rfile.read(length) if length else b""

    def _reject(self, code: int, reason: str) -> None:
        """Log why the request is refused and send an HTTP error response."""
        logging.warning(
            "Rejecting %s %s from %s: %s",
            self.command,
            self.path,
            self.client_address[0],
            reason,
        )
        self.send_error(code, reason)

    def _log_request(self) -> bytes | None:
        """Log the request as indented JSON and return its body.

        Returns ``None`` when the request was rejected (an error response has
        already been sent and nothing more should be written).
        """
        body = self._read_body()
        if body is None:
            return None

        parsed_url = urlparse(self.path)
        message = {
            "client": self.client_address[0],
            "method": self.command,
            "path": parsed_url.path,
            "query": parse_qs(parsed_url.query),
            "headers": dict(self.headers.items()),
            "body_raw": body.decode("utf-8", errors="replace"),
        }

        logging.info("Received request: %s", json.dumps(message, indent=2))
        return body

    def _send_response(self, body: bytes) -> None:
        """Send a 200 response carrying *body* as ``text/plain``."""
        self.send_response(200)
        self.send_header("Content-Type", "text/plain")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        if self._log_request() is not None:
            self._send_response(b"Webhook listener is running.\n")

    def do_POST(self):
        body = self._log_request()
        if body is not None:
            # Echo the payload back; an empty body gets a plain "OK".
            self._send_response(body or b"OK\n")

    # PUT is handled exactly like POST.
    do_PUT = do_POST

    def do_DELETE(self):
        if self._log_request() is not None:
            self._send_response(b"OK\n")

    def log_message(self, format, *args):  # noqa: A002 - name fixed by the base class
        """Route the base class's access/error log lines through :mod:`logging`."""
        logging.info("%s - %s", self.address_string(), format % args)


def run_server():
    """Configure logging, start the token refresher and serve until interrupted.

    The listening socket is bound first, so a port conflict fails before any
    Hubject token request or MySQL write is made. An initial token is then
    fetched synchronously, a daemon thread refreshes it every
    ``TOKEN_REFRESH_INTERVAL`` seconds, and requests are served until Ctrl+C.
    The socket is closed and the refresher signalled to stop on every exit path.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # One thread per connection: with the single-threaded HTTPServer, a client
    # that connects but never completes its request (the handler sets no socket
    # timeout) would block every other webhook delivery.
    server = ThreadingHTTPServer((HOST, PORT), LoggingRequestHandler)
    stop_event = threading.Event()

    try:
        fetch_and_store_access_token()

        refresher = threading.Thread(
            target=_refresh_loop, args=(stop_event,), name="hubject-token-refresh", daemon=True
        )
        refresher.start()

        logging.info("Starting webhook listener on %s:%s", HOST, PORT)
        server.serve_forever()
    except KeyboardInterrupt:
        logging.info("Shutting down server...")
    finally:
        stop_event.set()
        server.server_close()


# Start the listener only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_server()
