#!/usr/bin/env python3
"""Collect charging sessions from monthly OCPP server messages.

Usage::
    python services/getSessionsOfMonthByServer.py [YYMM]

If *YYMM* is omitted, the script processes messages for the current month.

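Reads StartTransaction/StopTransaction messages from op_server_messages_YYMM,
pairs them into charging sessions and stores those in
op_server_charging_sessions. Only sessions whose start and stop messages both
appear in that month's table are recorded; sessions already stored are skipped.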
"""

import json
import sys
from datetime import datetime, timezone
import logging

import pymysql

# Path of the JSON configuration file. It is relative, so it is resolved
# against the current working directory; main() reads its "mysql" section.
CONFIG_FILE = "config.json"

# Module-level logger; main() configures output via logging.basicConfig().
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------

def load_config(path: str = CONFIG_FILE) -> dict:
    """Parse the UTF-8 JSON file at *path* and return its top-level object.

    Errors from opening the file or decoding the JSON propagate to the caller.
    """
    with open(path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed


def init_db(cfg: dict):
    """Open a pymysql connection from the ``mysql`` configuration section.

    Required keys: ``host``, ``user``, ``password``, ``db``.
    Optional keys: ``charset`` (default ``"utf8"``) and ``port``
    (default 3306, the standard MySQL port).

    The connection runs in autocommit mode, so every statement the caller
    executes is committed immediately.

    Raises KeyError if a required key is missing; connection errors
    propagate from pymysql.
    """
    return pymysql.connect(
        host=cfg["host"],
        # optional so servers on a non-default port can be reached
        port=int(cfg.get("port", 3306)),
        user=cfg["user"],
        password=cfg["password"],
        db=cfg["db"],
        charset=cfg.get("charset", "utf8"),
        autocommit=True,
    )


def parse_time(ts: str | None):
    if not ts:
        return None
    try:
        # replace trailing Z with +00:00 so fromisoformat can parse it
        dt = datetime.fromisoformat(ts.replace("Z", "+00:00"))
        return dt.astimezone(timezone.utc)
    except ValueError:
        return None


def create_session_table(cur) -> None:
    """Make sure op_server_charging_sessions exists with the current schema.

    Creates the table when absent. Tables created by older versions of this
    script may lack the energyChargedWh column; it is added after
    meter_stop in that case.
    """
    schema_sql = """
        CREATE TABLE IF NOT EXISTS op_server_charging_sessions (
            id INT AUTO_INCREMENT PRIMARY KEY,
            chargepoint_id VARCHAR(50) NOT NULL,
            connector_id INT NOT NULL,
            transaction_id INT NOT NULL,
            id_tag VARCHAR(100) NOT NULL,
            session_start TIMESTAMP NULL,
            session_end TIMESTAMP NULL,
            meter_start INT,
            meter_stop INT,
            energyChargedWh INT,
            reason VARCHAR(50),
            UNIQUE KEY uniq_cp_txn (chargepoint_id, transaction_id)
        )
        """
    cur.execute(schema_sql)

    # Migration step for installations whose table predates energyChargedWh.
    probe_sql = "SHOW COLUMNS FROM op_server_charging_sessions LIKE 'energyChargedWh'"
    cur.execute(probe_sql)
    if cur.fetchone():
        return
    migrate_sql = "ALTER TABLE op_server_charging_sessions ADD COLUMN energyChargedWh INT AFTER meter_stop"
    cur.execute(migrate_sql)


def _escape_like(value: str) -> str:
    """Escape MySQL LIKE wildcards (percent, underscore) and backslashes."""
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def _quote_identifier(name: str) -> str:
    """Return *name* as a backtick-quoted MySQL identifier (backticks doubled)."""
    return "`" + name.replace("`", "``") + "`"


def ensure_message_table(cur, table: str) -> bool:
    """Check that *table* exists and bring it up to the expected schema.

    Missing columns are added:

    - ``chargepoint_id`` (``VARCHAR(255) NULL``)
    - ``source_url`` (``VARCHAR(255) NULL``)
    - ``message`` (``JSON``)

    Wherever one of ``chargepoint_id``/``source_url`` is NULL or empty, it
    is then filled from the other. Note that this ALTERs and UPDATEs the
    table as a side effect.

    Returns False (after logging a warning) if the table does not exist,
    True otherwise.
    """
    # SHOW TABLES LIKE treats "_" and "%" as wildcards; escape them so only
    # the exact table name can match.
    cur.execute("SHOW TABLES LIKE %s", (_escape_like(table),))
    if not cur.fetchone():
        logger.warning("Message table %s does not exist", table)
        return False

    # Identifiers cannot be bound as parameters, so quote them explicitly.
    quoted = _quote_identifier(table)
    cur.execute(f"SHOW COLUMNS FROM {quoted}")
    columns = {row[0] for row in cur.fetchall()}

    # (column name, definition used when the column is missing)
    required_columns = (
        ("chargepoint_id", "VARCHAR(255) NULL"),
        ("source_url", "VARCHAR(255) NULL"),
        ("message", "JSON"),
    )
    for name, definition in required_columns:
        if name not in columns:
            logger.debug("Adding missing column %s to %s", name, table)
            cur.execute(f"ALTER TABLE {quoted} ADD COLUMN {name} {definition}")
            columns.add(name)

    # Both columns exist at this point; mirror whichever one is filled into
    # the other so readers can rely on either.
    logger.debug("Syncing chargepoint_id/source_url columns in %s", table)
    cur.execute(
        f"""
        UPDATE {quoted}
        SET source_url = chargepoint_id
        WHERE (source_url IS NULL OR source_url = '')
          AND chargepoint_id IS NOT NULL
        """
    )
    cur.execute(
        f"""
        UPDATE {quoted}
        SET chargepoint_id = source_url
        WHERE (chargepoint_id IS NULL OR chargepoint_id = '')
          AND source_url IS NOT NULL
        """
    )

    return True


def collect_sessions(cur, month: str):
    """Return the charging sessions found in ``op_server_messages_<month>``.

    Logged frames mentioning StartTransaction, StopTransaction or a
    ``transactionId`` are read in ``id`` order and paired into session
    tuples by ``_sessions_from_rows``. Returns an empty list when the
    monthly table does not exist.

    Each tuple follows the column order used by ``insert_sessions``:
    (chargepoint_id, connector_id, transaction_id, id_tag, session_start,
    session_end, meter_start, meter_stop, energyChargedWh, reason).
    """
    table = f"op_server_messages_{month}"
    if not ensure_message_table(cur, table):
        return []

    logger.debug("Querying messages for %s", month)
    # *month* ends up inside the SQL text (identifiers cannot be bound as
    # parameters), so quote the identifier and double any backtick.
    quoted_table = "`" + table.replace("`", "``") + "`"
    # No parameters are passed to execute(), so the LIKE "%" signs are
    # sent verbatim rather than treated as placeholders.
    cur.execute(
        f'''
        SELECT source_url, chargepoint_id, message
        FROM {quoted_table}
        WHERE message LIKE '%StartTransaction%'
           OR message LIKE '%StopTransaction%'
           OR message LIKE '%"transactionId"%'
        ORDER BY id
        '''
    )
    rows = cur.fetchall()
    logger.debug("Fetched %d messages", len(rows))

    sessions = _sessions_from_rows(rows)
    logger.debug("Collected %d sessions", len(sessions))
    return sessions


def _sessions_from_rows(rows) -> list[tuple]:
    """Pair logged OCPP-J frames into completed charging-session tuples.

    *rows* are ``(source_url, chargepoint_id, message)`` tuples in log
    order, where *message* is a JSON-encoded frame. A session is emitted
    only when all three of these are present:

    - a StartTransaction CALL;
    - its CALLRESULT carrying a ``transactionId``;
    - a StopTransaction CALL for that transaction.

    Transactions started before these rows (e.g. in the previous month's
    table) are logged as unknown and dropped.

    Correlation is scoped per charge point. OCPP-J message ids only need
    to be unique per connection, and the session table is keyed on
    (chargepoint_id, transaction_id). Neither id is therefore used alone.

    Malformed frames are skipped instead of aborting the whole run.
    """
    # (chargepoint_id, uniqueId) -> StartTransaction awaiting its CALLRESULT
    pending_starts: dict[tuple, dict] = {}
    # (chargepoint_id, transactionId) -> confirmed start information
    open_transactions: dict[tuple, dict] = {}
    sessions: list[tuple] = []

    for source_url, chargepoint_col, message in rows:
        try:
            data = json.loads(message)
        except (TypeError, ValueError):
            # NULL/non-text column value, bad encoding or invalid JSON
            continue
        if not isinstance(data, list) or len(data) < 3:
            continue

        msg_type, unique_id = data[0], data[1]
        cp_source = source_url or chargepoint_col
        if not cp_source:
            logger.debug(
                "Skipping message without chargepoint reference (uid %s)", unique_id
            )
            continue
        if not isinstance(unique_id, (str, int)):
            continue  # malformed / unhashable message id
        # keep only the last path segment (presumably the charge point
        # identity in the connection URL)
        chargepoint_id = cp_source.rsplit("/", 1)[-1]

        if msg_type == 2:
            # CALL frame: [2, uniqueId, action, payload]
            if len(data) < 4 or not isinstance(data[3], dict):
                continue
            action, payload = data[2], data[3]

            if action == "StartTransaction":
                logger.debug(
                    "StartTransaction from %s (uid %s)", chargepoint_id, unique_id
                )
                pending_starts[(chargepoint_id, unique_id)] = {
                    "chargepoint_id": chargepoint_id,
                    "connector_id": payload.get("connectorId"),
                    "id_tag": payload.get("idTag"),
                    "meter_start": payload.get("meterStart"),
                    "start": parse_time(payload.get("timestamp")),
                }

            elif action == "StopTransaction":
                tx_id = payload.get("transactionId")
                info = (
                    open_transactions.get((chargepoint_id, tx_id))
                    if isinstance(tx_id, (str, int))
                    else None
                )
                if info is None:
                    logger.debug("StopTransaction for unknown tx %s", tx_id)
                    continue
                logger.debug("StopTransaction for tx %s", tx_id)
                meter_start = info["meter_start"]
                meter_stop = payload.get("meterStop")
                # only subtract real numbers; anything else yields NULL energy
                energy = (
                    meter_stop - meter_start
                    if isinstance(meter_stop, (int, float))
                    and isinstance(meter_start, (int, float))
                    else None
                )
                sessions.append(
                    (
                        info["chargepoint_id"],
                        info["connector_id"],
                        tx_id,
                        info["id_tag"],
                        info["start"],
                        parse_time(payload.get("timestamp")),
                        meter_start,
                        meter_stop,
                        energy,
                        payload.get("reason"),
                    )
                )

        elif msg_type == 3:
            # CALLRESULT frame: [3, uniqueId, payload]
            key = (chargepoint_id, unique_id)
            payload = data[2]
            if key not in pending_starts or not isinstance(payload, dict):
                continue
            tx_id = payload.get("transactionId")
            # A result without a usable transactionId cannot be the
            # StartTransaction response we wait for; keep the start pending.
            if not isinstance(tx_id, (str, int)):
                continue
            info = pending_starts.pop(key)
            info["transaction_id"] = tx_id
            open_transactions[(chargepoint_id, tx_id)] = info
            logger.debug("StartTransaction confirmed for tx %s", tx_id)

    return sessions


def insert_sessions(cur, sessions) -> int:
    """Bulk-insert session tuples into op_server_charging_sessions.

    The statement uses INSERT IGNORE, so rows that collide with an existing
    unique key, i.e. (chargepoint_id, transaction_id), are skipped instead of
    raising. Returns the affected-row count reported by the cursor, or 0
    without touching the cursor when *sessions* is empty.
    """
    if not sessions:
        return 0

    # Same order as the tuples produced by collect_sessions.
    columns = (
        "chargepoint_id",
        "connector_id",
        "transaction_id",
        "id_tag",
        "session_start",
        "session_end",
        "meter_start",
        "meter_stop",
        "energyChargedWh",
        "reason",
    )
    column_list = ", ".join(columns)
    placeholders = ",".join("%s" for _ in columns)
    statement = (
        "INSERT IGNORE INTO op_server_charging_sessions "
        f"({column_list}) "
        f"VALUES ({placeholders})"
    )

    logger.debug("Inserting %d sessions into database", len(sessions))
    cur.executemany(statement, sessions)
    affected = cur.rowcount
    logger.debug("Inserted %d new rows", affected)
    return affected


# ---------------------------------------------------------------------------
# main
# ---------------------------------------------------------------------------

def _is_valid_month(month: str) -> bool:
    """Return True if *month* is a four-digit ASCII ``YYMM`` string (month 01-12)."""
    return (
        len(month) == 4
        and month.isascii()
        and month.isdigit()
        and 1 <= int(month[2:]) <= 12
    )


def main():
    """Command-line entry point: collect and store one month of sessions.

    Accepts an optional ``YYMM`` argument (default: the current month in
    UTC). The argument is validated before use because it becomes part of
    the SQL table name ``op_server_messages_YYMM``.

    Usage and validation errors go to stderr and exit with status 1.
    """
    usage = "Usage: getSessionsOfMonthByServer.py [YYMM]"
    if len(sys.argv) > 2:
        print(usage, file=sys.stderr)
        sys.exit(1)

    if len(sys.argv) == 2:
        month = sys.argv[1]
        if not _is_valid_month(month):
            print(
                f"Invalid month {month!r}: expected YYMM, e.g. 2401",
                file=sys.stderr,
            )
            print(usage, file=sys.stderr)
            sys.exit(1)
    else:
        month = datetime.now(timezone.utc).strftime("%y%m")

    logging.basicConfig(level=logging.DEBUG, format="%(levelname)s:%(message)s")
    logger.info("Collecting sessions for %s", month)

    config = load_config()
    db_cfg = config.get("mysql", {})
    logger.debug("Connecting to database %s", db_cfg.get("db"))
    conn = init_db(db_cfg)
    try:
        with conn.cursor() as cur:
            create_session_table(cur)
            logger.debug("Session table checked/created")
            sessions = collect_sessions(cur, month)
            inserted = insert_sessions(cur, sessions)
        logger.info("Inserted %d sessions for %s", inserted, month)
    finally:
        # always release the connection, even if collection/insert failed
        conn.close()


if __name__ == "__main__":
    main()
