import json
import logging
import math
import os
import threading
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timedelta, timezone
from typing import Any, Deque, Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlencode

import pymysql
import requests
from flask import Flask, Response, g, jsonify, request

from services.token_service import TokenService

# Optional JSON configuration, read relative to the current working directory.
# A missing file yields an empty config so the defaults below apply; other
# errors (malformed JSON, permission problems) are not caught and abort import.
CONFIG_FILE = "config.json"
try:
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        _config = json.load(f)
except FileNotFoundError:
    _config = {}

# Root log level comes from config["log_levels"]["rest_api"] (default "INFO").
# It is resolved as an attribute of the logging module, with INFO as the
# fallback for unknown names.
# NOTE(review): a non-string value would fail on .upper(). A name that is a
# logging attribute but not a level would reach basicConfig unchecked.
# Presumably the config file is trusted; confirm.
log_cfg = _config.get("log_levels", {})
level_name = log_cfg.get("rest_api", "INFO").upper()
LOG_LEVEL = getattr(logging, level_name, logging.INFO)
logging.basicConfig(
    level=LOG_LEVEL,
    format="%(asctime)s %(levelname)s:%(name)s:%(message)s",
)

# Dedicated "audit" logger writing to debug/audit.log (relative to the working
# directory). Setup is skipped when the logger already has handlers, e.g. on
# re-import, so records are not written twice.
audit_logger = logging.getLogger("audit")
if not audit_logger.handlers:
    # logging.FileHandler raises FileNotFoundError when the parent directory
    # does not exist, which would abort module import; create it first.
    os.makedirs("debug", exist_ok=True)
    audit_handler = logging.FileHandler("debug/audit.log")
    audit_handler.setLevel(logging.INFO)
    audit_formatter = logging.Formatter(
        "%(asctime)s %(levelname)s:%(name)s:%(message)s"
    )
    audit_handler.setFormatter(audit_formatter)
    audit_logger.addHandler(audit_handler)
audit_logger.setLevel(logging.INFO)

# MySQL connection settings. The hard-coded defaults are overridden key by key
# from config["mysql"].
# NOTE(review): the defaults connect as root with an empty password, which is
# presumably only meant for local development.
_mysql_cfg = {
    "host": "localhost",
    "user": "root",
    "password": "",
    "db": "op",
    "charset": "utf8mb4",
}
_mysql_cfg.update(_config.get("mysql", {}))

# Base URL of the OCPP server that /WallboxManagement/setConfiguration forwards
# to. An empty string means "not configured", and that endpoint returns 500.
OCPP_ENDPOINT = _config.get("ocpp_endpoint", "")

# DDL for the broker's per-session state table, created lazily by
# ensure_session_state_table(). The (station_id, session_uid) unique key is what
# _upsert_session_entry's ON DUPLICATE KEY UPDATE merges on.
SESSION_STATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS op_broker_sessions (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    session_uid VARCHAR(100) NOT NULL,
    station_id VARCHAR(255) NOT NULL,
    connector_id VARCHAR(100),
    evse_id VARCHAR(100),
    transaction_id VARCHAR(100),
    id_tag VARCHAR(255),
    status VARCHAR(30) NOT NULL,
    start_timestamp DATETIME NULL,
    end_timestamp DATETIME NULL,
    meter_start_wh DOUBLE NULL,
    meter_stop_wh DOUBLE NULL,
    meter_last_wh DOUBLE NULL,
    last_status VARCHAR(50),
    last_status_timestamp DATETIME NULL,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    UNIQUE KEY uniq_station_session (station_id, session_uid),
    KEY idx_status (status),
    KEY idx_station (station_id),
    KEY idx_updated_at (updated_at)
) CHARACTER SET utf8mb4
"""

# DDL for the latest status per (station_id, connector_id, evse_id), created
# lazily by ensure_evse_status_table(). This module only reads the table
# (_fetch_evse_status).
# NOTE(review): the writer presumably lives elsewhere; confirm.
EVSE_STATUS_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS op_broker_evse_status (
    id BIGINT AUTO_INCREMENT PRIMARY KEY,
    station_id VARCHAR(255) NOT NULL,
    connector_id VARCHAR(100) NOT NULL,
    evse_id VARCHAR(100) NOT NULL,
    status VARCHAR(50),
    error_code VARCHAR(100),
    status_timestamp DATETIME NULL,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    UNIQUE KEY uniq_evse_status (station_id, connector_id, evse_id),
    KEY idx_updated_at (updated_at)
) CHARACTER SET utf8mb4
"""

# Environment variable names and values for partner security.
# PIPELET_RATE_LIMIT_PER_MIN is the default request budget per window.
# PIPELET_RATE_LIMIT_WINDOW is the window length in seconds.
# Both are parsed with int() at import time, so a non-integer value raises
# ValueError and aborts startup.
_PARTNER_TOKEN_ENV = "PIPELET_PARTNER_TOKENS"
_PARTNER_TOKEN_FILE_ENV = "PIPELET_PARTNER_TOKENS_FILE"
_RATE_LIMIT_ENV = "PIPELET_RATE_LIMITS"
_RATE_LIMIT_DEFAULT_PER_MIN = int(os.environ.get("PIPELET_RATE_LIMIT_PER_MIN", "120"))
_RATE_LIMIT_WINDOW = int(os.environ.get("PIPELET_RATE_LIMIT_WINDOW", "60"))
_IP_ALLOWLIST_ENV = "PIPELET_IP_ALLOWLIST"


def _load_partner_token_map() -> Dict[str, Dict[str, str]]:
    """Load the bearer-token -> partner mapping used to authenticate requests.

    Source selection:
    - If ``PIPELET_PARTNER_TOKENS_FILE`` names an existing file, entries come
      from that JSON file.
    - Otherwise they come from the JSON in ``PIPELET_PARTNER_TOKENS``.

    Payload shape: a single object or a list of objects. Each object carries:
    - ``token`` or ``access_token`` (required);
    - ``partner_id`` or ``partnerId`` (optional, defaults to ``"unknown"``);
    - ``role``: ``CPO`` or ``EMSP``, case-insensitive.

    Error handling:
    - Entries with missing or non-string token/role values are skipped.
    - A payload that is neither an object nor a list is ignored.
    - Read/parse failures are logged to the audit log and yield no tokens.

    Returns:
        Mapping of token string to ``{"partner_id": ..., "role": ...}``.
    """

    def _as_text(value: Any) -> str:
        # Non-string JSON values (numbers, null, objects) count as absent
        # instead of crashing the whole load on .strip()/.upper().
        return value.strip() if isinstance(value, str) else ""

    def _normalize_entry(entry: Dict[str, Any]) -> Optional[Tuple[str, Dict[str, str]]]:
        token = _as_text(entry.get("token") or entry.get("access_token"))
        partner_id = _as_text(entry.get("partner_id") or entry.get("partnerId"))
        role = _as_text(entry.get("role")).upper()
        if not token or role not in {"CPO", "EMSP"}:
            return None
        return token, {"partner_id": partner_id or "unknown", "role": role}

    tokens: Dict[str, Dict[str, str]] = {}
    raw = os.environ.get(_PARTNER_TOKEN_ENV)
    path = os.environ.get(_PARTNER_TOKEN_FILE_ENV)
    payload: Any = []
    if path and os.path.exists(path):
        # An existing token file takes precedence over the env variable.
        try:
            with open(path, "r", encoding="utf-8") as f:
                payload = json.load(f) or []
        except (OSError, ValueError) as exc:
            audit_logger.error("Failed to load partner tokens from file: %s", exc)
            payload = []
    elif raw:
        try:
            payload = json.loads(raw)
        except ValueError as exc:
            audit_logger.error("Failed to parse partner tokens env: %s", exc)
            payload = []

    if isinstance(payload, dict):
        payload = [payload]
    elif not isinstance(payload, list):
        audit_logger.error(
            "Partner token payload must be an object or a list, got %s",
            type(payload).__name__,
        )
        payload = []

    for entry in payload:
        if not isinstance(entry, dict):
            continue
        normalized = _normalize_entry(entry)
        if normalized:
            token, info = normalized
            tokens[token] = info
    return tokens


def _load_rate_limits() -> Dict[str, int]:
    """Parse per-module request limits from ``PIPELET_RATE_LIMITS``.

    The variable holds a JSON object mapping module name to requests per
    window, e.g. ``{"tokens": 30}``. Each value is coerced with ``int()`` and
    clamped to at least 1, and values that cannot be coerced are dropped.

    An unset, empty, unparsable or non-object value yields an empty mapping.
    """
    text = os.environ.get(_RATE_LIMIT_ENV, "")
    decoded: Any = None
    if text:
        try:
            decoded = json.loads(text)
        except Exception:
            decoded = None

    result: Dict[str, int] = {}
    if not isinstance(decoded, dict):
        return result
    for name, raw_limit in decoded.items():
        try:
            per_window = int(raw_limit)
        except Exception:
            continue
        result[str(name)] = max(1, per_window)
    return result


def _load_ip_allowlist() -> List[str]:
    """Return the comma-separated addresses listed in ``PIPELET_IP_ALLOWLIST``.

    Each entry is whitespace-trimmed and empty entries are dropped. An unset or
    empty variable yields ``[]``, which the request hook treats as "allow any IP".
    """
    allowed: List[str] = []
    for piece in os.environ.get(_IP_ALLOWLIST_ENV, "").split(","):
        address = piece.strip()
        if address:
            allowed.append(address)
    return allowed


class RateLimiter:
    """Sliding-window request limiter keyed by ``(module, partner_id)``.

    Each key may make at most ``limit`` requests within the trailing
    ``window_seconds``. The limit is the per-module override from
    ``module_limits`` when present, else ``default_limit``.

    Buckets whose newest request has left the window are swept at most once
    per window, so memory stays bounded by recently active keys.

    A lock serialises access because Flask may serve requests from several
    threads, and the sweep mutates the bucket dictionary.
    """

    def __init__(self, window_seconds: int, default_limit: int, module_limits: Dict[str, int]):
        self.window_seconds = max(1, window_seconds)
        self.default_limit = max(1, default_limit)
        self.module_limits = module_limits
        self._requests: Dict[Tuple[str, str], Deque[float]] = defaultdict(deque)
        self._lock = threading.Lock()
        self._last_sweep = time.monotonic()

    def _limit_for_module(self, module: str) -> int:
        """Return the request budget per window for *module*."""
        return self.module_limits.get(module, self.default_limit)

    def _sweep(self, window_start: float) -> None:
        """Drop buckets with no request at or after *window_start* (lock held)."""
        stale = [
            key
            for key, bucket in self._requests.items()
            if not bucket or bucket[-1] < window_start
        ]
        for key in stale:
            del self._requests[key]

    def allow(self, module: str, partner_id: str) -> bool:
        """Record a request and return ``True`` if it fits the budget.

        A rejected request is not recorded, so it does not consume budget.
        """
        now = time.monotonic()
        window_start = now - self.window_seconds
        with self._lock:
            if now - self._last_sweep >= self.window_seconds:
                self._sweep(window_start)
                self._last_sweep = now
            bucket = self._requests[(module, partner_id)]
            while bucket and bucket[0] < window_start:
                bucket.popleft()
            if len(bucket) >= self._limit_for_module(module):
                return False
            bucket.append(now)
            return True


# Security settings are loaded once, at import time. Later environment changes
# only take effect after a process restart.
partner_tokens = _load_partner_token_map()
rate_limits = _load_rate_limits()
ip_allowlist = _load_ip_allowlist()
rate_limiter = RateLimiter(_RATE_LIMIT_WINDOW, _RATE_LIMIT_DEFAULT_PER_MIN, rate_limits)

# Maps a path prefix to the roles allowed to call it. _required_roles_for_path()
# uses the first entry whose prefix the request path starts with (plain
# str.startswith, no segment boundary). Paths matching no entry only need a
# valid token.
ROLE_REQUIREMENTS: List[Tuple[str, Tuple[str, ...]]] = [
    ("/WallboxManagement", ("CPO",)),
    ("/Charge", ("CPO",)),
    ("/tokens", ("EMSP",)),
]


def _audit_event(event: str, **details: Any) -> None:
    """Write one audit record for *event* to the ``audit`` logger.

    The record is the JSON object ``{"event": event, **details}``. If the
    details cannot be JSON-encoded, a plain ``"<event> | <details>"`` line is
    logged instead, so the event is never dropped.
    """
    record = {"event": event, **details}
    try:
        serialized = json.dumps(record, ensure_ascii=False)
        audit_logger.info(serialized)
    except Exception:
        audit_logger.info("%s | %s", event, details)


def _audit_config_change(action: str, **details: Any) -> None:
    """Audit a configuration-changing *action*, tagged with the acting partner.

    The actor is the partner record that the ``before_request`` hook stored on
    ``flask.g``, or an empty dict when none was stored.
    """
    partner = getattr(g, "partner_info", {}) or {}
    _audit_event("config_change", action=action, actor=partner, **details)


def _module_for_request(req_path: str) -> str:
    cleaned = req_path.strip("/")
    if not cleaned:
        return "root"
    return cleaned.split("/", 1)[0]


def _required_roles_for_path(req_path: str) -> Optional[Tuple[str, ...]]:
    """Return the roles permitted for *req_path*, or ``None`` when unrestricted.

    The first ``ROLE_REQUIREMENTS`` entry whose prefix *req_path* starts with
    wins. Matching is a raw string prefix, so "/ChargeX" would also match
    "/Charge".
    """
    return next(
        (roles for prefix, roles in ROLE_REQUIREMENTS if req_path.startswith(prefix)),
        None,
    )


def _extract_bearer_token() -> Optional[str]:
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        return None
    return auth_header.split(" ", 1)[1].strip()
# WSGI application object; the routes and the before_request hook below register on it.
app = Flask(__name__)


@app.route("/robots.txt")
def robots_txt():
    """Ask crawlers not to index anything (exempt from the auth hook)."""
    body = "User-agent: *\nDisallow: /\n"
    return Response(body, mimetype="text/plain")


@app.before_request
def enforce_partner_security():
    """Authenticate, authorize and rate-limit every request except ``/robots.txt``.

    Checks run in this order:
    1. IP allowlist -> 403
    2. Bearer token -> 401
    3. Path role requirements -> 403
    4. Per-(module, partner) rate limit -> 429

    Every denial is audited with a consistent set of fields (client IP and
    path included). On success the partner record and module name are stored
    on ``flask.g`` and ``None`` is returned so Flask continues to the view.

    NOTE(review): ``request.remote_addr`` is the direct peer; behind a reverse
    proxy that is the proxy's address. Confirm the deployment topology.
    """
    if request.path == "/robots.txt":
        return None

    remote_ip = request.remote_addr or "unknown"
    if ip_allowlist and remote_ip not in ip_allowlist:
        _audit_event(
            "auth_denied",
            reason="ip_not_allowed",
            ip=remote_ip,
            path=request.path,
        )
        return jsonify({"error": "IP not allowed"}), 403

    token = _extract_bearer_token()
    partner = partner_tokens.get(token or "")
    if not partner:
        _audit_event(
            "auth_denied",
            reason="invalid_token",
            ip=remote_ip,
            path=request.path,
        )
        return jsonify({"error": "unauthorized"}), 401

    required_roles = _required_roles_for_path(request.path)
    if required_roles and partner.get("role") not in required_roles:
        _audit_event(
            "auth_denied",
            reason="role_forbidden",
            ip=remote_ip,
            role=partner.get("role"),
            path=request.path,
            partner=partner.get("partner_id"),
        )
        return jsonify({"error": "forbidden"}), 403

    module = _module_for_request(request.path)
    partner_id = partner.get("partner_id") or "unknown"
    if not rate_limiter.allow(module, partner_id):
        _audit_event(
            "throttled",
            module=module,
            partner=partner_id,
            ip=remote_ip,
            path=request.path,
        )
        return jsonify({"error": "too many requests"}), 429

    g.partner_info = partner
    g.request_module = module
    return None


def get_db_conn():
    """Open a new PyMySQL connection from ``_mysql_cfg``.

    Rows are returned as dicts (``DictCursor``). An optional ``port`` key in
    the ``mysql`` config section is honoured; the default is MySQL's standard
    3306. The caller owns the connection and must close it.
    """
    return pymysql.connect(
        host=_mysql_cfg["host"],
        port=int(_mysql_cfg.get("port", 3306)),
        user=_mysql_cfg["user"],
        password=_mysql_cfg["password"],
        db=_mysql_cfg["db"],
        charset=_mysql_cfg.get("charset", "utf8mb4"),
        cursorclass=pymysql.cursors.DictCursor,
    )


# Shared eMSP token service, handed get_db_conn as its connection factory.
# It also exposes the token cache used by the /tokens endpoints (.cache, cache_stats()).
token_service = TokenService(get_db_conn)


def ensure_session_state_table(conn) -> None:
    """Create ``op_broker_sessions`` on *conn* if it is missing, then commit."""
    with conn.cursor() as ddl_cursor:
        ddl_cursor.execute(SESSION_STATE_TABLE_SQL)
    conn.commit()


def ensure_evse_status_table(conn) -> None:
    """Create ``op_broker_evse_status`` on *conn* if it is missing, then commit."""
    with conn.cursor() as ddl_cursor:
        ddl_cursor.execute(EVSE_STATUS_TABLE_SQL)
    conn.commit()


def ensure_emsp_token_table(conn) -> None:
    """Let ``token_service`` create its eMSP token table on *conn* if needed."""
    service = token_service
    service.ensure_table(conn)


def _parse_iso8601(value: Optional[str]) -> Optional[datetime]:
    if not value:
        return None
    try:
        return datetime.fromisoformat(value.replace("Z", "+00:00"))
    except Exception:
        return None


def _serialize_dt(value: Optional[datetime]) -> Optional[str]:
    if value is None:
        return None
    if value.tzinfo:
        return value.astimezone(timezone.utc).isoformat()
    return value.isoformat()


def _bool_param(value: Optional[str]) -> bool:
    if value is None:
        return False
    lowered = value.lower()
    return lowered in {"1", "true", "yes", "y"}


def _build_pagination(limit: int, offset: int, total: int, count: int) -> Dict[str, Any]:
    """Build the pagination envelope for a list response.

    Always contains ``limit``, ``offset``, ``count`` and ``total``.
    - ``next`` is added when rows remain after this page.
    - ``prev`` is added when *offset* is positive.
    Both links reuse the current request's query arguments (first value of
    each), with ``offset``/``limit`` replaced.
    """
    query = request.args.to_dict(flat=True)
    url = request.base_url

    def page_link(new_offset: int) -> str:
        params = {**query, "offset": new_offset, "limit": limit}
        return f"{url}?{urlencode(params)}"

    envelope: Dict[str, Any] = {
        "limit": limit,
        "offset": offset,
        "count": count,
        "total": total,
    }
    has_more = offset + limit < total
    if has_more:
        envelope["next"] = page_link(offset + limit)
    if offset > 0:
        envelope["prev"] = page_link(max(offset - limit, 0))
    return envelope


def _bool_value(value: Any) -> bool:
    if isinstance(value, bool):
        return value
    if value is None:
        return False
    if isinstance(value, (int, float)):
        return bool(value)
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "y"}
    return False


def _extract_tokens_from_payload(payload: Any) -> List[Dict[str, Any]]:
    """Normalise a token upload body into a list of TokenService records.

    Accepted shapes:
    - a list of token objects;
    - ``{"tokens": [...]}``;
    - ``{"tokens": {...}}`` (a single object);
    - a bare token object.

    Non-object entries are dropped, as are entries for which
    ``TokenService.normalize_payload`` returns a falsy value.
    """
    if isinstance(payload, list):
        candidates: Any = payload
    elif isinstance(payload, dict) and "tokens" in payload:
        candidates = payload.get("tokens") or []
    elif isinstance(payload, dict):
        candidates = [payload]
    else:
        candidates = []

    if isinstance(candidates, dict):
        # Iterating a dict would yield its keys, silently storing nothing.
        candidates = [candidates]
    elif not isinstance(candidates, list):
        # Scalars (e.g. a number) are not iterable token collections.
        candidates = []

    normalized_entries: List[Dict[str, Any]] = []
    for entry in candidates:
        if not isinstance(entry, dict):
            continue
        normalized = TokenService.normalize_payload(entry)
        if normalized:
            normalized_entries.append(normalized)
    return normalized_entries


def _coerce_float(value: Any) -> Optional[float]:
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _normalize_session_payload(entry: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Map a camelCase session object to ``op_broker_sessions`` column values.

    Required fields (``None`` is returned if either is missing):
    - the session key: the first of ``sessionUid``, ``session_id`` and
      ``transactionId`` that is truthy or an integer 0;
    - ``stationId``.

    Field handling:
    - Timestamps go through ``_parse_iso8601``.
    - Meter readings go through ``_coerce_float``.
    - Invalid timestamps or readings become ``None``.
    - ``status`` is stringified, upper-cased and defaults to ``"ACTIVE"``.
    """
    session_uid: Any = None
    for candidate in (
        entry.get("sessionUid"),
        entry.get("session_id"),
        entry.get("transactionId"),
    ):
        # Integer ids (as in OCPP 1.6) may legitimately be 0; do not treat
        # that as missing.
        if candidate or (isinstance(candidate, int) and not isinstance(candidate, bool)):
            session_uid = candidate
            break
    station_id = entry.get("stationId")
    if session_uid is None or not station_id:
        return None

    raw_status = entry.get("status")
    # str() guards against non-string JSON values; a bare .upper() on them
    # would raise and fail the whole batch.
    status = str(raw_status).upper() if raw_status else "ACTIVE"

    return {
        "session_uid": str(session_uid),
        "station_id": station_id,
        "connector_id": entry.get("connectorId"),
        "evse_id": entry.get("evseId"),
        "transaction_id": entry.get("transactionId"),
        "id_tag": entry.get("idTag"),
        "status": status,
        "start_timestamp": _parse_iso8601(entry.get("startTimestamp")),
        "end_timestamp": _parse_iso8601(entry.get("endTimestamp")),
        "meter_start_wh": _coerce_float(entry.get("meterStartWh")),
        "meter_stop_wh": _coerce_float(entry.get("meterStopWh")),
        "meter_last_wh": _coerce_float(entry.get("meterLastWh")),
        "last_status": entry.get("lastStatus"),
        "last_status_timestamp": _parse_iso8601(entry.get("lastStatusTimestamp")),
    }


def _serialize_session_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Convert an ``op_broker_sessions`` row into the camelCase API shape.

    Datetime columns are rendered via ``_serialize_dt``; every other column is
    copied unchanged. Columns missing from *row* become ``None``.
    """
    layout = (
        ("sessionUid", "session_uid", False),
        ("stationId", "station_id", False),
        ("connectorId", "connector_id", False),
        ("evseId", "evse_id", False),
        ("transactionId", "transaction_id", False),
        ("idTag", "id_tag", False),
        ("status", "status", False),
        ("startTimestamp", "start_timestamp", True),
        ("endTimestamp", "end_timestamp", True),
        ("meterStartWh", "meter_start_wh", False),
        ("meterStopWh", "meter_stop_wh", False),
        ("meterLastWh", "meter_last_wh", False),
        ("lastStatus", "last_status", False),
        ("lastStatusTimestamp", "last_status_timestamp", True),
        ("updatedAt", "updated_at", True),
    )
    return {
        api_key: _serialize_dt(row.get(column)) if is_datetime else row.get(column)
        for api_key, column, is_datetime in layout
    }


def _serialize_status_row(row: Dict[str, Any]) -> Dict[str, Any]:
    """Convert an ``op_broker_evse_status`` row into the camelCase API shape.

    Timestamp columns are rendered via ``_serialize_dt``; missing columns
    become ``None``.
    """
    serialized: Dict[str, Any] = {}
    serialized["stationId"] = row.get("station_id")
    serialized["connectorId"] = row.get("connector_id")
    serialized["evseId"] = row.get("evse_id")
    serialized["status"] = row.get("status")
    serialized["errorCode"] = row.get("error_code")
    serialized["statusTimestamp"] = _serialize_dt(row.get("status_timestamp"))
    serialized["updatedAt"] = _serialize_dt(row.get("updated_at"))
    return serialized


def _upsert_session_entry(conn, entry: Dict[str, Any]) -> None:
    """Insert or merge one session object into ``op_broker_sessions``.

    The entry is normalised with ``_normalize_session_payload``; entries that
    fail normalisation are skipped silently.

    On a duplicate (station_id, session_uid) the existing row is merged, and
    each column is overwritten only by a non-NULL new value. ``status`` is
    overwritten only when the entry itself carries a status. The ``"ACTIVE"``
    default applied during normalisation is meant for new rows, so a partial
    update (e.g. a meter reading) cannot revert an ended session to ACTIVE.

    Does not commit; the caller owns the transaction.
    """
    payload = _normalize_session_payload(entry)
    if not payload:
        return

    # None when the caller omitted the status, so COALESCE keeps the stored one.
    explicit_status = payload.get("status") if entry.get("status") else None

    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO op_broker_sessions (
                session_uid, station_id, connector_id, evse_id, transaction_id,
                id_tag, status, start_timestamp, end_timestamp, meter_start_wh,
                meter_stop_wh, meter_last_wh, last_status, last_status_timestamp
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                connector_id = COALESCE(VALUES(connector_id), connector_id),
                evse_id = COALESCE(VALUES(evse_id), evse_id),
                transaction_id = COALESCE(VALUES(transaction_id), transaction_id),
                id_tag = COALESCE(VALUES(id_tag), id_tag),
                status = COALESCE(%s, status),
                start_timestamp = COALESCE(VALUES(start_timestamp), start_timestamp),
                end_timestamp = COALESCE(VALUES(end_timestamp), end_timestamp),
                meter_start_wh = COALESCE(VALUES(meter_start_wh), meter_start_wh),
                meter_stop_wh = COALESCE(VALUES(meter_stop_wh), meter_stop_wh),
                meter_last_wh = COALESCE(VALUES(meter_last_wh), meter_last_wh),
                last_status = COALESCE(VALUES(last_status), last_status),
                last_status_timestamp = COALESCE(VALUES(last_status_timestamp), last_status_timestamp),
                updated_at = CURRENT_TIMESTAMP
            """,
            (
                payload["session_uid"],
                payload["station_id"],
                payload.get("connector_id"),
                payload.get("evse_id"),
                payload.get("transaction_id"),
                payload.get("id_tag"),
                payload.get("status"),
                payload.get("start_timestamp"),
                payload.get("end_timestamp"),
                payload.get("meter_start_wh"),
                payload.get("meter_stop_wh"),
                payload.get("meter_last_wh"),
                payload.get("last_status"),
                payload.get("last_status_timestamp"),
                # Bound to the "status = COALESCE(%s, status)" placeholder.
                explicit_status,
            ),
        )


def _fetch_evse_status(conn, station_id: Optional[str]) -> List[Dict[str, Any]]:
    """Return serialized EVSE status rows, newest ``updated_at`` first.

    Ensures the status table exists before querying. When *station_id* is
    truthy, only that station's rows are returned.
    """
    ensure_evse_status_table(conn)
    columns = "station_id, connector_id, evse_id, status, error_code, status_timestamp, updated_at"
    if station_id:
        where, params = " WHERE station_id=%s", [station_id]
    else:
        where, params = "", []
    query = f"SELECT {columns} FROM op_broker_evse_status{where} ORDER BY updated_at DESC"
    with conn.cursor() as cur:
        cur.execute(query, params)
        fetched = cur.fetchall()
    return list(map(_serialize_status_row, fetched))


def get_config_value(key: str) -> Optional[str]:
    """Return ``op_config.config_value`` for *key*, or ``None``.

    Best-effort: any database error, including a failure to connect, is
    logged at WARNING level and reported as ``None``, so callers can fall back
    to a default. The connection is always closed.
    """
    conn = None
    try:
        conn = get_db_conn()
        with conn.cursor() as cur:
            cur.execute(
                "SELECT config_value FROM op_config WHERE config_key=%s", (key,)
            )
            row = cur.fetchone()
        return row["config_value"] if row else None
    except Exception:
        logging.getLogger(__name__).warning(
            "Failed to read config value %r", key, exc_info=True
        )
        return None
    finally:
        # conn stays None when get_db_conn() itself failed.
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


def api_port() -> int:
    """Return the REST API port from the ``rest_api_port`` config key.

    Missing, empty or non-integer values fall back to 9750.
    """
    default_port = 9750
    configured = get_config_value("rest_api_port")
    if not configured:
        return default_port
    try:
        return int(configured)
    except Exception:
        return default_port


def _owner_filter(owner: Optional[str]) -> str:
    if owner:
        return owner
    return None


@app.get("/api/connected")
def get_connected_devices():
    """List topics (device ids) that logged a message within the last hour.

    Messages live in monthly tables ``op_messages_YYMM``, named by the UTC
    month. The one-hour window can straddle a month boundary, so the previous
    month's table is also queried when ``now - 1h`` falls into it; a missing
    previous table is ignored.

    An optional ``owner`` query argument restricts results to one
    ``ocpp_endpoint``.

    Returns a JSON list of distinct topics, or 500 with the error message.
    """
    owner = _owner_filter(request.args.get("owner"))
    now = datetime.now(timezone.utc)
    tables: List[str] = []
    for moment in (now, now - timedelta(hours=1)):
        table_name = moment.strftime("op_messages_%y%m")
        if table_name not in tables:
            tables.append(table_name)

    conn = None
    try:
        conn = get_db_conn()
        topics: List[str] = []
        seen = set()
        with conn.cursor() as cur:
            for index, table in enumerate(tables):
                sql = f"SELECT topic, MAX(timestamp) as ts FROM `{table}`"
                params: List[Any] = []
                if owner:
                    sql += " WHERE ocpp_endpoint=%s"
                    params.append(owner)
                # NOTE(review): NOW() uses the DB session time zone while the
                # table name comes from UTC; this assumes the DB runs in UTC.
                sql += " GROUP BY topic HAVING ts >= NOW() - INTERVAL 1 HOUR"
                try:
                    cur.execute(sql, params)
                except pymysql.err.ProgrammingError:
                    if index == 0:
                        raise
                    # The previous month's table may not exist; skip it.
                    continue
                for row in cur.fetchall():
                    topic = row["topic"]
                    if topic not in seen:
                        seen.add(topic)
                        topics.append(topic)
        return jsonify(topics)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@app.get("/api/disconnected")
def get_disconnected_devices():
    """Placeholder endpoint; always answers 501 Not Implemented."""
    body = {"message": "not implemented"}
    return jsonify(body), 501


@app.get("/api/meter/<device>")
def get_meter_value(device: str):
    """Placeholder endpoint; always answers 501 Not Implemented.

    *device* is accepted only because the route captures it; it is unused.
    """
    body = {"message": "not implemented"}
    return jsonify(body), 501


def _parse_message(data: str) -> Optional[Dict[str, Any]]:
    try:
        msg = json.loads(data)
        if isinstance(msg, list) and len(msg) >= 3:
            return {
                "type": msg[0],
                "uid": msg[1],
                "action": msg[2],
                "payload": msg[3] if len(msg) > 3 else {},
            }
        return None
    except Exception:
        return None


@app.route("/tokens", methods=["GET"])
def list_emsp_tokens():
    """List stored eMSP tokens with optional filters and pagination.

    Query arguments:
    - ``limit``: clamped to 1-500, default 100;
    - ``offset``: at least 0, default 0;
    - ``token``: search term;
    - ``status``.

    Unparsable numbers fall back to their defaults. The response also carries
    the token cache statistics.
    """
    args = request.args

    def int_arg(name: str, fallback: int) -> int:
        try:
            return int(args.get(name, fallback))
        except Exception:
            return fallback

    limit = min(max(int_arg("limit", 100), 1), 500)
    offset = max(int_arg("offset", 0), 0)

    tokens, total = token_service.list_tokens(
        search=args.get("token"),
        status=args.get("status"),
        offset=offset,
        limit=limit,
    )
    rows = [TokenService.serialize_rest(item) for item in tokens]
    body = {
        "tokens": rows,
        "cache": token_service.cache_stats(),
        "pagination": _build_pagination(limit, offset, total, len(rows)),
    }
    return jsonify(body)


@app.route("/tokens", methods=["PUT"])
def upsert_emsp_tokens():
    """Store uploaded eMSP tokens and audit the change.

    Any body shape ``_extract_tokens_from_payload`` understands is accepted.
    - Nothing valid to store: 200 with ``{"stored": 0}``.
    - Otherwise: 202 with the stored count and current cache statistics.
    """
    entries = _extract_tokens_from_payload(request.get_json(silent=True) or {})
    if entries:
        stored = token_service.upsert_tokens(entries)
        _audit_config_change(
            "upsert_tokens",
            stored=stored,
            tokens=[item.get("uid") for item in entries],
        )
        return jsonify({"stored": stored, "cache": token_service.cache_stats()}), 202
    return jsonify({"stored": 0}), 200


@app.route("/tokens/sync", methods=["POST"])
def sync_tokens():
    """Synchronise eMSP tokens through ``token_service.sync_tokens`` and audit it.

    The body may be any shape ``_extract_tokens_from_payload`` accepts. When it
    is an object, ``clearCache`` (or ``invalidateCache``) requests a cache
    flush; booleans, numbers and strings such as ``"true"``/``"1"`` are all
    understood. A list body carries no flag and is synced without flushing.
    """
    payload = request.get_json(silent=True) or {}
    normalized_entries = _extract_tokens_from_payload(payload)
    clear_cache = False
    if isinstance(payload, dict):
        # Guarded because a list body has no .get(); it used to fail with a 500.
        clear_cache = _bool_value(
            payload.get("clearCache") or payload.get("invalidateCache")
        )
    result = token_service.sync_tokens(normalized_entries, clear_cache=clear_cache)
    _audit_config_change(
        "sync_tokens",
        stored=result.get("stored", 0),
        cache=result.get("cache"),
    )
    return jsonify(result), 202


@app.route("/tokens/cache", methods=["GET", "DELETE"])
def token_cache_status():
    """Report token cache statistics; a ``DELETE`` invalidates the cache first."""
    if request.method == "DELETE":
        token_service.cache.invalidate()
    return jsonify({"cache": token_service.cache_stats()}), 200


@app.route("/sessions", methods=["GET"])
def get_sessions():
    """List broker session state with optional filters and pagination.

    Query arguments:
    - ``stationId``;
    - ``status``: case-insensitive;
    - ``from``/``to``: ISO-8601 bounds on ``updated_at``, ignored when
      unparsable;
    - ``limit``: 1-500, default 100;
    - ``offset``: at least 0;
    - ``includeStatus``: default true; adds the EVSE status rows.

    Returns 500 with the error message on database failure.
    """
    args = request.args
    station_id = args.get("stationId")
    status_filter = args.get("status")
    from_ts = _parse_iso8601(args.get("from"))
    to_ts = _parse_iso8601(args.get("to"))

    def int_arg(name: str, fallback: int) -> int:
        try:
            return int(args.get(name, fallback))
        except Exception:
            return fallback

    limit = min(max(int_arg("limit", 100), 1), 500)
    offset = max(int_arg("offset", 0), 0)
    include_status = _bool_param(args.get("includeStatus", "true"))

    conditions = ["1=1"]
    params: List[Any] = []
    optional_filters = (
        ("status=%s", status_filter.upper() if status_filter else None),
        ("station_id=%s", station_id),
        ("updated_at >= %s", from_ts),
        ("updated_at <= %s", to_ts),
    )
    for clause, value in optional_filters:
        if value:
            conditions.append(clause)
            params.append(value)
    where_sql = " AND ".join(conditions)

    conn = None
    try:
        conn = get_db_conn()
        ensure_session_state_table(conn)
        with conn.cursor() as cur:
            cur.execute(
                f"SELECT COUNT(*) AS total_count FROM op_broker_sessions WHERE {where_sql}",
                params,
            )
            total = cur.fetchone().get("total_count", 0)
            cur.execute(
                "SELECT session_uid, station_id, connector_id, evse_id, transaction_id, id_tag, "
                "status, start_timestamp, end_timestamp, meter_start_wh, meter_stop_wh, "
                "meter_last_wh, last_status, last_status_timestamp, updated_at "
                f"FROM op_broker_sessions WHERE {where_sql} "
                "ORDER BY updated_at DESC LIMIT %s OFFSET %s",
                params + [limit, offset],
            )
            rows = cur.fetchall()

        sessions = [_serialize_session_row(r) for r in rows]
        response: Dict[str, Any] = {
            "sessions": sessions,
            "pagination": _build_pagination(limit, offset, total, len(sessions)),
        }
        if include_status:
            response["evseStatus"] = _fetch_evse_status(conn, station_id)
        return jsonify(response)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@app.route("/sessions", methods=["POST", "PUT"])
def upsert_sessions():
    """Insert or merge broker session state from the request body.

    Accepted shapes: a list of session objects, ``{"sessions": [...]}``,
    ``{"sessions": {...}}`` or a single session object. All entries are
    written in one transaction.

    Responses:
    - Empty body: 200 with ``stored: 0``.
    - Otherwise: 202, where ``stored`` counts only entries that were objects
      with a session id and station id, i.e. rows actually written.
    - Database failure: 500 with the error message.
    """
    payload = request.get_json(silent=True) or {}
    if isinstance(payload, list):
        sessions: Any = payload
    elif isinstance(payload, dict) and "sessions" in payload:
        sessions = payload.get("sessions") or []
    elif isinstance(payload, dict):
        sessions = [payload]
    else:
        sessions = []
    if isinstance(sessions, dict):
        sessions = [sessions]
    elif not isinstance(sessions, list):
        sessions = []

    if not sessions:
        return jsonify({"stored": 0}), 200

    conn = None
    try:
        conn = get_db_conn()
        ensure_session_state_table(conn)
        stored = 0
        for entry in sessions:
            # _upsert_session_entry silently skips entries that fail
            # normalisation, so count only the ones that will be written.
            if isinstance(entry, dict) and _normalize_session_payload(entry):
                _upsert_session_entry(conn, entry)
                stored += 1
        conn.commit()
        return jsonify({"stored": stored}), 202
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@app.get("/api/getsession")
def get_session_of_month():
    """Reconstruct finished charging sessions for one device and month.

    Query arguments:
    - ``device``: the topic;
    - ``month``: 1-12;
    - ``year``: ASCII digits, at least two; the last two select the table;
    - ``owner``: optional ``ocpp_endpoint`` filter.

    Messages from ``op_messages_YYMM`` are replayed in id order:
    - A StartTransaction CALL is matched to its CALLRESULT by message uid to
      learn the transaction id.
    - The latest DataTransfer payload is attached to the next session.
    - Each StopTransaction CALL produces one session record.

    Returns 400 for missing or invalid arguments and 500 (with the error
    message) if the query fails.
    """
    device = request.args.get("device")
    month = request.args.get("month")
    year = request.args.get("year")
    owner = _owner_filter(request.args.get("owner"))

    if not (device and month and year):
        return jsonify({"error": "device, month and year required"}), 400

    # month/year are interpolated into a table identifier, which cannot be
    # parameterised, so validate them strictly.
    try:
        month_number = int(month)
    except ValueError:
        month_number = 0
    year_ok = year.isascii() and year.isdigit() and len(year) >= 2
    if not 1 <= month_number <= 12 or not year_ok:
        return (
            jsonify({"error": "month must be 1-12 and year a number with at least two digits"}),
            400,
        )

    table = f"op_messages_{year[-2:]}{month_number:02d}"
    conn = None
    try:
        conn = get_db_conn()
        with conn.cursor() as cur:
            sql = f"SELECT direction, message, timestamp FROM `{table}` WHERE topic=%s"
            params: List[Any] = [device]
            if owner:
                sql += " AND ocpp_endpoint=%s"
                params.append(owner)
            sql += " ORDER BY id"
            cur.execute(sql, params)
            rows = cur.fetchall()
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass

    start_by_uid: Dict[Any, Dict[str, Any]] = {}
    start_by_txid: Dict[Any, Dict[str, Any]] = {}
    last_data: Optional[Any] = None
    sessions: List[Dict[str, Any]] = []

    for row in rows:
        parsed = _parse_message(row["message"])
        if not parsed:
            continue
        mtype = parsed["type"]
        uid = parsed["uid"]
        action = parsed.get("action")
        payload = parsed.get("payload")
        if mtype == 3 and not payload and isinstance(action, dict):
            # OCPP-J CALLRESULT frames are [3, uid, payload]. A parser that
            # assumes the CALL layout reports that payload as the "action".
            payload = action
        if not isinstance(payload, dict):
            # A malformed frame must not abort the whole report.
            payload = {}

        if mtype == 2 and action == "StartTransaction":
            start_by_uid[uid] = payload
        elif mtype == 3 and uid in start_by_uid:
            txid = payload.get("transactionId")
            if txid is not None:
                start_by_txid[txid] = start_by_uid.pop(uid)
        elif mtype == 2 and action == "DataTransfer":
            last_data = payload.get("data") or payload
        elif mtype == 2 and action == "StopTransaction":
            txid = payload.get("transactionId")
            start = start_by_txid.get(txid, {})
            session = {
                "id": str(uuid.uuid4()),
                "domain": "ocpp",
                "deviceId": device,
                "connector_id": start.get("connectorId"),
                "register": "Energy.Active.Import.Register",
                "value": payload.get("meterStop"),
                "unit": "Wh",
                "timestamp": payload.get("timestamp"),
                "session": {
                    "Transaction.Id": txid,
                    "SessionID": str(txid),
                    "ID.Tag": start.get("idTag"),
                    "Energy.Meter.Start": start.get("meterStart"),
                    "Energy.Meter.Stop": payload.get("meterStop"),
                    "Reason.Session": payload.get("reason"),
                    "unit": "Wh",
                    "Data.Transfer.Value": last_data,
                    "Transaction.Data": json.dumps({}, ensure_ascii=False),
                    "Session.Start": start.get("timestamp"),
                    "Session.End": payload.get("timestamp"),
                },
            }
            sessions.append(session)
            last_data = None
    return jsonify(sessions)


@app.get("/Charge/GetChargeRecordsByRfId")
def get_charge_records():
    """Return ``op_charging_sessions`` rows for an RFID tag.

    Query arguments:
    - ``rfId`` (or ``rfid``): required;
    - ``month``: optional four-digit ``YYMM``, restricting ``session_start``
      to that month.

    Returns 400 for missing or invalid arguments and 500 (with the error
    message) on database failure.
    """
    rfid = request.args.get("rfId") or request.args.get("rfid")
    month = request.args.get("month")
    if not rfid:
        return jsonify({"error": "rfId required"}), 400
    if month and (len(month) != 4 or not month.isdigit()):
        return jsonify({"error": "month must be in YYMM format"}), 400
    conn = None
    try:
        conn = get_db_conn()
        with conn.cursor() as cur:
            sql = (
                "SELECT station_id, connector_id, transaction_id, id_tag, "
                "session_start, session_end, meter_start, meter_stop, "
                "energyChargedWh, reason FROM op_charging_sessions "
                "WHERE id_tag=%s"
            )
            params: List[Any] = [rfid]
            if month:
                # PyMySQL interpolates parameters with Python %-formatting, so
                # literal % signs in the SQL must be doubled. A bare '%y%m'
                # raises "unsupported format character".
                sql += " AND DATE_FORMAT(session_start, '%%y%%m')=%s"
                params.append(month)
            cur.execute(sql, params)
            rows = cur.fetchall()
        return jsonify(rows)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@app.post("/WallboxManagement/createRedirect")
def create_redirect():
    """Create an ``op_redirects`` row mapping ``source_url`` to ``ws_url``.

    The JSON body must be an object with ``source_url``, ``ws_url`` and
    ``activity``; otherwise 400.
    - ``source_url`` already present: nothing is written; 200 with
      ``created: false``.
    - Success: the change is audited; 201.
    - Database errors: 500 with the message.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array or scalar would crash on .get() with an unhandled 500;
        # treat it like an empty body so the 400 below applies.
        data = {}
    source_url = data.get("source_url")
    ws_url = data.get("ws_url")
    activity = data.get("activity")
    if not source_url or not ws_url or not activity:
        return (
            jsonify({"error": "source_url, ws_url and activity required"}),
            400,
        )
    conn = None
    try:
        conn = get_db_conn()
        with conn.cursor() as cur:
            # NOTE(review): SELECT-then-INSERT is racy; only a unique key on
            # source_url truly prevents duplicates. Confirm the schema has one.
            cur.execute(
                "SELECT id FROM op_redirects WHERE source_url=%s",
                (source_url,),
            )
            if cur.fetchone():
                return (
                    jsonify(
                        {
                            "created": False,
                            "message": "source_url already exists",
                        }
                    ),
                    200,
                )
            cur.execute(
                "INSERT INTO op_redirects (source_url, ws_url, activity) VALUES (%s, %s, %s)",
                (source_url, ws_url, activity),
            )
            conn.commit()
        _audit_config_change(
            "create_redirect",
            source_url=source_url,
            activity=activity,
        )
        return jsonify({"created": True}), 201
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@app.get("/WallboxManagement/getChargePointConfiguration")
def get_chargepoint_configuration():
    """Return the newest stored configuration for ``chargepoint_id``.

    ``configuration_json`` is decoded when it holds valid JSON text and is
    returned verbatim otherwise.
    - 400 without ``chargepoint_id``.
    - 404 when no row exists.
    - 500 (with the error message) on database failure.
    """
    cp_id = request.args.get("chargepoint_id")
    if not cp_id:
        return jsonify({"error": "chargepoint_id required"}), 400
    conn = None
    try:
        conn = get_db_conn()
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT configuration_json FROM op_server_cp_config
                WHERE chargepoint_id=%s
                ORDER BY id DESC LIMIT 1
                """,
                (cp_id,),
            )
            latest = cur.fetchone()
        if not latest:
            return jsonify({"error": "not found"}), 404
        configuration = latest.get("configuration_json")
        if isinstance(configuration, str):
            try:
                configuration = json.loads(configuration)
            except Exception:
                pass
        return jsonify(configuration)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass


@app.post("/WallboxManagement/setConfiguration")
def set_chargepoint_configuration():
    """Forward one configuration key/value for a station to the OCPP server.

    Expects a JSON object body with ``chargepointID``, ``setconfiguration``
    (the key) and ``value`` (any non-null value). These are POSTed as
    ``{"station_id", "key", "value"}`` to
    ``{ocpp_endpoint}/api/setConfiguration``.

    Returns:
        The upstream JSON (or ``{"status": <code>}`` when the upstream body
        is not JSON) with the upstream status code;
        400 on missing fields;
        500 when ``ocpp_endpoint`` is unset, the request fails, or upstream
        answers with a 4xx/5xx status (``raise_for_status``).
    """
    log = logging.getLogger(__name__)
    data = request.get_json(silent=True)
    # A valid but non-object JSON body (e.g. a list) would crash on ``.get``
    # with an unhandled AttributeError; treat it as empty so the caller gets
    # the usual 400 below.
    if not isinstance(data, dict):
        data = {}
    cp_id = data.get("chargepointID")
    key = data.get("setconfiguration")
    value = data.get("value")
    if not (cp_id and key and value is not None):
        return (
            jsonify({"error": "chargepointID, setconfiguration and value required"}),
            400,
        )
    if not OCPP_ENDPOINT:
        return jsonify({"error": "ocpp_endpoint not configured"}), 500
    url = f"{OCPP_ENDPOINT.rstrip('/')}/api/setConfiguration"
    try:
        resp = requests.post(
            url,
            json={"station_id": cp_id, "key": key, "value": value},
            timeout=10,
        )
        resp.raise_for_status()
        try:
            result = resp.json()
        except ValueError:
            # Upstream succeeded but returned a non-JSON body.
            result = {"status": resp.status_code}
        return jsonify(result), resp.status_code
    except Exception as e:
        log.warning("setConfiguration forward to %s failed: %s", url, e)
        return jsonify({"error": str(e)}), 500


@app.post("/api/datatransfer")
def datatransfer():
    """Relay a ``DataTransfer`` command to a station through the OCPP server.

    Accepts a JSON object body. camelCase keys take precedence over their
    aliases:

    * ``stationId`` / ``station_id`` / ``topic`` — target station (required)
    * ``vendorId`` / ``vendor_id`` (required)
    * ``messageId`` / ``message_id`` (required)
    * ``data`` (required, any non-null JSON value)
    * ``connectorId`` / ``connector_id`` — optional integer (``0`` is kept)

    The command is POSTed to ``{ocpp_endpoint}/api/pnc/ocpp-command`` as a
    ``custom_action`` with ``ocpp_action`` ``"DataTransfer"``.

    Returns:
        The upstream JSON (or ``{"status": <code>}`` when it is not JSON)
        with the upstream status code; 400 on invalid input; 500 when
        ``ocpp_endpoint`` is unset, the request fails, or upstream answers
        with a 4xx/5xx status.
    """
    log = logging.getLogger(__name__)
    payload = request.get_json(silent=True)
    # A valid but non-object JSON body (e.g. a list) would crash on ``.get``.
    if not isinstance(payload, dict):
        payload = {}
    station_id = (
        payload.get("stationId")
        or payload.get("station_id")
        or payload.get("topic")
    )
    vendor_id = payload.get("vendorId") or payload.get("vendor_id")
    message_id = payload.get("messageId") or payload.get("message_id")
    data_field = payload.get("data")
    # Explicit None check instead of ``or``: an ``or`` chain would silently
    # drop a legitimate ``connectorId: 0``.
    connector_id_raw = payload.get("connectorId")
    if connector_id_raw is None:
        connector_id_raw = payload.get("connector_id")

    if not station_id:
        return jsonify({"error": "stationId required"}), 400
    if not vendor_id:
        return jsonify({"error": "vendorId required"}), 400
    if message_id is None:
        return jsonify({"error": "messageId required"}), 400
    if data_field is None:
        return jsonify({"error": "data required"}), 400

    connector_id: Optional[int] = None
    if connector_id_raw is not None:
        # bool is an int subclass and int() truncates floats. Without this
        # guard, ``true`` would become connector 1 and ``2.7`` connector 2.
        if isinstance(connector_id_raw, bool) or (
            isinstance(connector_id_raw, float)
            and not connector_id_raw.is_integer()
        ):
            return jsonify({"error": "connectorId must be an integer"}), 400
        try:
            connector_id = int(connector_id_raw)
        except (TypeError, ValueError):
            return jsonify({"error": "connectorId must be an integer"}), 400

    if not OCPP_ENDPOINT:
        return jsonify({"error": "ocpp_endpoint not configured"}), 500

    dt_payload: Dict[str, Any] = {
        "vendorId": vendor_id,
        "messageId": message_id,
        "data": data_field,
    }
    if connector_id is not None:
        dt_payload["connectorId"] = connector_id

    url = f"{OCPP_ENDPOINT.rstrip('/')}/api/pnc/ocpp-command"
    body = {
        "station_id": station_id,
        "action": "custom_action",
        "ocpp_action": "DataTransfer",
        "payload": dt_payload,
    }

    try:
        resp = requests.post(url, json=body, timeout=15)
        resp.raise_for_status()
        try:
            result = resp.json()
        except ValueError:
            # Upstream succeeded but returned a non-JSON body.
            result = {"status": resp.status_code}
        return jsonify(result), resp.status_code
    except Exception as e:
        log.warning("DataTransfer forward for %s failed: %s", station_id, e)
        return jsonify({"error": str(e)}), 500


@app.post("/WallboxManagement/AddWallbox")
def add_wallbox():
    """Placeholder endpoint for adding a wallbox.

    Always answers HTTP 501 with ``{"message": "not implemented"}``.
    """
    not_implemented = {"message": "not implemented"}
    return jsonify(not_implemented), 501


@app.get("/WallboxManagement/GetAllWallboxes")
def get_all_wallboxes():
    """Placeholder endpoint for listing wallboxes.

    Always answers HTTP 501 with ``{"message": "not implemented"}``.
    """
    not_implemented = {"message": "not implemented"}
    return jsonify(not_implemented), 501


@app.patch("/WallboxManagement/UpdateWallbox")
def update_wallbox():
    """Placeholder endpoint for updating a wallbox.

    Always answers HTTP 501 with ``{"message": "not implemented"}``.
    """
    not_implemented = {"message": "not implemented"}
    return jsonify(not_implemented), 501


if __name__ == "__main__":
    # Script entry point: serve the API on all interfaces using the port
    # returned by api_port(). NOTE(review): app.run() is Flask's built-in
    # development server; confirm a production WSGI server is used elsewhere.
    listen_port = api_port()
    app.run(host="0.0.0.0", port=listen_port)
