import hmac
import json
import logging
import os
import smtplib
import time
from dataclasses import dataclass, field, replace
from datetime import datetime, timezone
from email.message import EmailMessage
from typing import Any, Iterable, Mapping, Optional, Sequence
from zoneinfo import ZoneInfo

import requests
from flask import Request

# Name of the environment variable that can point at an alternative config file.
CONFIG_PATH_ENV: str = "PIPELET_CONFIG"
# Config file read when that environment variable is not set.
DEFAULT_CONFIG_FILE: str = "config.json"


@dataclass(frozen=True)
class BackendProfile:
    """Connection settings for one OCPI backend.

    Usually built from a config mapping via :meth:`from_mapping`. ``modules``
    and ``hosts`` hold stripped, lower-cased names.
    """

    backend_id: str
    url: Optional[str] = None
    token: Optional[str] = None
    country_code: Optional[str] = None
    party_id: Optional[str] = None
    # Modules this backend serves; an empty set means "every module".
    modules: set[str] = field(default_factory=set)
    # Host names that route to this backend.
    hosts: set[str] = field(default_factory=set)
    base_url: Optional[str] = None
    default: bool = False

    @staticmethod
    def _parse_name_list(raw: Any) -> set[str]:
        """Normalise *raw* into a set of stripped, lower-cased names.

        Accepts a comma- or semicolon-separated string, or any iterable of
        values (each converted with ``str``). Blank entries are dropped;
        anything else (e.g. ``None``) yields an empty set.
        """
        if isinstance(raw, str):
            items: Iterable[Any] = raw.replace(";", ",").split(",")
        elif isinstance(raw, Iterable):
            items = raw
        else:
            return set()
        return {str(item).strip().lower() for item in items if str(item).strip()}

    @staticmethod
    def _parse_modules(raw: Any) -> set[str]:
        """Parse a module list (same rules as ``_parse_name_list``)."""
        return BackendProfile._parse_name_list(raw)

    @classmethod
    def from_mapping(cls, cfg: Mapping[str, Any]) -> "BackendProfile":
        """Build a profile from one backend config entry.

        Accepted key aliases: ``id``/``name`` (default ``"default"``),
        ``country_code``/``country``, ``party_id``/``party`` and
        ``modules``/``module_flags``. ``hosts`` may be a list or a
        comma/semicolon-separated string.
        """
        return cls(
            backend_id=str(cfg.get("id") or cfg.get("name") or "default"),
            url=cfg.get("url"),
            token=cfg.get("token"),
            country_code=cfg.get("country_code") or cfg.get("country"),
            party_id=cfg.get("party_id") or cfg.get("party"),
            modules=cls._parse_modules(cfg.get("modules") or cfg.get("module_flags")),
            # Parsed like modules: a bare string is no longer split into single
            # characters, null is tolerated, and entries are stripped so they
            # compare cleanly against lower-cased host names.
            hosts=cls._parse_name_list(cfg.get("hosts")),
            base_url=cfg.get("base_url"),
            default=bool(cfg.get("default", False)),
        )

    def supports_module(self, module: Optional[str]) -> bool:
        """Return True if this backend serves *module*.

        An empty/None *module* or an empty ``modules`` set means "no
        restriction". The comparison is case-insensitive.
        """
        if not module:
            return True
        if not self.modules:
            return True
        return module.strip().lower() in self.modules


class BackendRegistry:
    """Pick the backend profile that should serve an incoming request.

    :meth:`select_backend` tries, in order: the explicit id argument, the
    ``X-OCPI-Backend`` header, the ``X-Backend-ID`` header, the request host,
    and finally the default profile.
    """

    def __init__(self, profiles: Sequence[BackendProfile], fallback_token: Optional[str] = None):
        """Index *profiles* by id and by host.

        Profiles without their own token inherit *fallback_token* (if given).
        The default profile is the first one flagged ``default``, otherwise
        the first profile.

        Raises:
            ValueError: if *profiles* is empty.
        """
        if not profiles:
            raise ValueError("At least one backend profile is required")

        self.profiles: list[BackendProfile] = [
            profile
            if profile.token or not fallback_token
            else replace(profile, token=fallback_token)
            for profile in profiles
        ]
        # On duplicate ids the later profile wins; for hosts the first
        # profile that claims a host keeps it.
        self._by_id = {profile.backend_id: profile for profile in self.profiles}
        self._host_map: dict[str, BackendProfile] = {}
        for profile in self.profiles:
            for host in profile.hosts:
                self._host_map.setdefault(host.lower(), profile)

        self._default = next(
            (profile for profile in self.profiles if profile.default),
            self.profiles[0],
        )

    @classmethod
    def from_config(cls, ocpi_cfg: Mapping[str, Any], *, fallback_token: Optional[str] = None) -> "BackendRegistry":
        """Build a registry from the OCPI config section.

        Uses the ``backends`` list when it yields at least one mapping entry.
        Otherwise a single ``"default"`` profile is built from the top-level
        ``base_url``/``token``/``country_code``/``party_id``/``modules`` keys.
        A non-mapping *ocpi_cfg* is treated as empty instead of raising.
        """
        # The original only guarded the "backends" lookup; the fallback branch
        # below called .get() on a non-mapping and crashed.
        cfg: Mapping[str, Any] = ocpi_cfg if isinstance(ocpi_cfg, Mapping) else {}
        raw_backends = cfg.get("backends")
        profiles: list[BackendProfile] = []
        if isinstance(raw_backends, Sequence) and not isinstance(raw_backends, str):
            profiles = [
                BackendProfile.from_mapping(entry)
                for entry in raw_backends
                if isinstance(entry, Mapping)
            ]
        if not profiles:
            profiles.append(
                BackendProfile(
                    backend_id="default",
                    url=cfg.get("base_url"),
                    # NOTE(review): here fallback_token beats the configured
                    # token, whereas entries in "backends" keep their own
                    # token — confirm this asymmetry is intended.
                    token=fallback_token or cfg.get("token"),
                    country_code=cfg.get("country_code"),
                    party_id=cfg.get("party_id"),
                    modules=BackendProfile._parse_modules(cfg.get("modules")),
                )
            )
        return cls(profiles, fallback_token=fallback_token)

    def select_backend(
        self,
        request: Request,
        explicit_backend_id: Optional[str] = None,
        *,
        module: Optional[str] = None,
    ) -> Optional[BackendProfile]:
        """Return the profile for *request*, or None if it lacks *module*.

        An unknown id falls through to host matching (port stripped,
        lower-cased) and then to the default profile.
        """
        backend_id = explicit_backend_id or request.headers.get("X-OCPI-Backend")
        if not backend_id:
            backend_id = request.headers.get("X-Backend-ID")

        candidate = self._by_id.get(str(backend_id)) if backend_id else None
        if candidate is None:
            host = (request.host or "").split(":")[0].lower()
            candidate = self._host_map.get(host)
        if candidate is None:
            candidate = self._default

        if module and candidate and not candidate.supports_module(module):
            return None
        return candidate


def load_config(path: Optional[str] = None) -> dict[str, Any]:
    """Load the JSON configuration file.

    Args:
        path: File to read. When omitted, the path comes from the
            environment variable named by ``CONFIG_PATH_ENV``, falling back
            to ``DEFAULT_CONFIG_FILE``.

    Returns:
        The top-level JSON object. Returns ``{}`` (with a warning) if the file
        is missing, is not valid UTF-8 JSON, or its top level is not an object.
    """
    logger = logging.getLogger(__name__)
    config_path = path if path is not None else os.environ.get(CONFIG_PATH_ENV, DEFAULT_CONFIG_FILE)
    try:
        with open(config_path, "r", encoding="utf-8") as cfg:
            data = json.load(cfg)
    except FileNotFoundError:
        logger.warning("Config file %s not found, using defaults", config_path)
        return {}
    except (json.JSONDecodeError, UnicodeDecodeError):
        logger.warning("Config file %s is not valid JSON, using defaults", config_path)
        return {}
    # Callers treat the result as a mapping; a list/scalar at the top level
    # would break them later with a confusing error.
    if not isinstance(data, dict):
        logger.warning(
            "Config file %s does not contain a JSON object, using defaults", config_path
        )
        return {}
    return data


class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object.

    Always emitted keys: ``timestamp`` (UTC ISO-8601 with a trailing ``Z``,
    taken from the record's creation time), ``level``, ``logger`` and
    ``message``. Attributes attached via ``extra=`` are copied in as extra
    keys. ``exception`` and ``stack`` are added when the record carries them.
    """

    # Built-in LogRecord attributes that must not be echoed as "extra" keys.
    # "taskName" exists since Python 3.12; "message"/"asctime" are set on the
    # record when another formatter has already handled it.
    _STANDARD_FIELDS = frozenset(
        {
            "name",
            "msg",
            "args",
            "levelname",
            "levelno",
            "pathname",
            "filename",
            "module",
            "exc_info",
            "exc_text",
            "stack_info",
            "lineno",
            "funcName",
            "created",
            "msecs",
            "relativeCreated",
            "thread",
            "threadName",
            "processName",
            "process",
            "taskName",
            "message",
            "asctime",
        }
    )

    def format(self, record: logging.LogRecord) -> str:  # pragma: no cover - formatting only
        # Stamp with when the event happened, not when it was formatted
        # (buffered/queued handlers can format much later); utcnow() is also
        # deprecated since 3.12. Dropping tzinfo keeps the "...Z" shape.
        created = datetime.fromtimestamp(record.created, tz=timezone.utc)
        payload: dict[str, Any] = {
            "timestamp": created.replace(tzinfo=None).isoformat() + "Z",
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }

        for key, value in record.__dict__.items():
            if key not in self._STANDARD_FIELDS:
                payload[key] = value

        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        if record.stack_info:
            payload["stack"] = self.formatStack(record.stack_info)
        # default=str: an unserializable extra must not make the whole record
        # fail to format (logging would drop it and print a traceback).
        return json.dumps(payload, ensure_ascii=False, default=str)


def _resolve_log_level(raw: Any, default: int = logging.INFO) -> int:
    """Translate a configured level (name like ``"debug"`` or an int) to an int level.

    Unknown names, non-level ``logging`` attributes (e.g. ``BASIC_FORMAT``),
    booleans, ``None`` and other types fall back to *default*.
    """
    if isinstance(raw, bool):
        return default
    if isinstance(raw, int):
        return raw
    if isinstance(raw, str):
        candidate = getattr(logging, raw.strip().upper(), None)
        # Only accept real numeric levels; getattr can also return strings
        # or functions, which setLevel() would reject.
        if isinstance(candidate, int) and not isinstance(candidate, bool):
            return candidate
    return default


def setup_logging(config: Mapping[str, Any], *, logger_name: str) -> None:
    """Give *logger_name* exactly one stderr handler, configured from *config*.

    ``config["log_levels"][logger_name]`` sets the level (default INFO).
    ``config["logging"]["json"]`` selects the JSON formatter (default True)
    or a plain text format. The logger stops propagating to its ancestors.
    """
    cfg: Mapping[str, Any] = config if isinstance(config, Mapping) else {}
    levels = cfg.get("log_levels")
    raw_level = levels.get(logger_name) if isinstance(levels, Mapping) else None
    level = _resolve_log_level(raw_level)

    logging_cfg = cfg.get("logging")
    use_json = bool(logging_cfg.get("json", True)) if isinstance(logging_cfg, Mapping) else True

    formatter: logging.Formatter
    if use_json:
        formatter = JsonFormatter()
    else:
        formatter = logging.Formatter(
            "%(asctime)s %(levelname)s:%(name)s:%(message)s"
        )

    logger = logging.getLogger(logger_name)
    # Close handlers from a previous call instead of merely dropping them,
    # so repeated setup does not leak handler resources.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)
    logger.propagate = False


@dataclass
class FailureNotifier:
    """Deliver failure alerts over a generic webhook, Slack and/or e-mail.

    Two entry points feed the same delivery channels:

    * :meth:`record_failure` counts consecutive failures and alerts once the
      count reaches ``threshold``. Alerts are then rate limited by
      ``cooldown_seconds``; :meth:`record_success` resets the counter.
    * :meth:`notify_event` sends a one-off alert, rate limited per ``key``.

    Nothing is sent during the configured quiet hours. Delivery errors are
    logged at DEBUG level and never raised to the caller.
    """

    webhook_url: Optional[str] = None
    slack_webhook_url: Optional[str] = None
    email_recipients: Sequence[str] = field(default_factory=list)
    smtp_host: Optional[str] = None
    smtp_port: int = 25
    smtp_username: Optional[str] = None
    smtp_password: Optional[str] = None
    # Consecutive failures needed before record_failure alerts (treated as >= 1).
    threshold: int = 3
    cooldown_seconds: int = 300
    # Quiet window bounds as "HH" or "HH:MM" strings, read in quiet_timezone.
    quiet_hours_start: Optional[str] = None
    quiet_hours_end: Optional[str] = None
    quiet_timezone: str = "UTC"
    logger: logging.Logger = field(default_factory=lambda: logging.getLogger(__name__))
    # Runtime state below.
    consecutive_failures: int = 0
    last_notified: float = 0.0
    # time.time() of the last alert sent, per cooldown key.
    last_sent: dict[str, float] = field(default_factory=dict)

    @classmethod
    def from_config(
        cls, config: Optional[Mapping[str, Any]], logger: logging.Logger
    ) -> "FailureNotifier":
        """Build a notifier from the alerting config section.

        A non-mapping *config* yields a notifier with no channels (disabled).
        ``email_recipients`` may be a list or a comma/semicolon-separated
        string. Null numeric settings fall back to their defaults.
        """
        if not isinstance(config, Mapping):
            return cls(logger=logger, threshold=0)
        smtp_cfg = config.get("smtp") if isinstance(config.get("smtp"), Mapping) else {}
        quiet_cfg = config.get("quiet_hours") if isinstance(config.get("quiet_hours"), Mapping) else {}
        return cls(
            webhook_url=config.get("webhook_url"),
            slack_webhook_url=config.get("slack_webhook_url")
            or config.get("slack_webhook"),
            email_recipients=cls._parse_recipients(config.get("email_recipients")),
            smtp_host=smtp_cfg.get("host"),
            smtp_port=cls._as_int(smtp_cfg.get("port"), 25),
            smtp_username=smtp_cfg.get("username"),
            smtp_password=smtp_cfg.get("password"),
            threshold=cls._as_int(config.get("threshold"), 3),
            cooldown_seconds=cls._as_int(config.get("cooldown_seconds"), 300),
            quiet_hours_start=str(quiet_cfg.get("start") or "").strip() or None,
            quiet_hours_end=str(quiet_cfg.get("end") or "").strip() or None,
            quiet_timezone=str(quiet_cfg.get("timezone") or "UTC").strip() or "UTC",
            logger=logger,
        )

    @staticmethod
    def _as_int(value: Any, default: int) -> int:
        """Return ``int(value)``, or *default* when *value* is ``None``."""
        return default if value is None else int(value)

    @staticmethod
    def _parse_recipients(raw: Any) -> tuple[str, ...]:
        """Normalise recipients given as an iterable or a comma/semicolon string.

        A bare string used to be iterated character by character; ``None``
        used to raise. Blank entries are dropped; order is preserved.
        """
        if isinstance(raw, str):
            raw = raw.replace(";", ",").split(",")
        elif not isinstance(raw, Iterable):
            return ()
        return tuple(str(addr).strip() for addr in raw if str(addr).strip())

    @staticmethod
    def _utc_timestamp() -> str:
        """Current UTC time as ISO-8601 with a trailing "Z" (utcnow() is deprecated)."""
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    @property
    def enabled(self) -> bool:
        """True when at least one delivery channel is configured."""
        return bool(self.webhook_url or self.slack_webhook_url or self.email_recipients)

    def record_success(self) -> None:
        """Reset the consecutive-failure counter."""
        self.consecutive_failures = 0

    def _in_quiet_hours(self) -> bool:
        """Return True if "now" (in quiet_timezone) falls in the quiet window.

        Both bounds are inclusive; a start later than the end wraps past
        midnight. Missing or unparsable bounds disable quiet hours, and an
        unknown timezone falls back to UTC.
        """
        if not self.quiet_hours_start or not self.quiet_hours_end:
            return False

        def _parse_minutes(value: str) -> Optional[int]:
            try:
                parts = value.split(":")
                hour = int(parts[0])
                minute = int(parts[1]) if len(parts) > 1 else 0
                if hour < 0 or hour > 23 or minute < 0 or minute > 59:
                    return None
                return hour * 60 + minute
            except (TypeError, ValueError):
                return None

        start_minutes = _parse_minutes(self.quiet_hours_start)
        end_minutes = _parse_minutes(self.quiet_hours_end)
        if start_minutes is None or end_minutes is None:
            return False

        try:
            tz = ZoneInfo(self.quiet_timezone)
            now = datetime.now(tz)
        except Exception:
            now = datetime.now(timezone.utc)
        current_minutes = now.hour * 60 + now.minute

        if start_minutes <= end_minutes:
            return start_minutes <= current_minutes <= end_minutes
        return current_minutes >= start_minutes or current_minutes <= end_minutes

    def _should_cooldown(self, key: str, *, cooldown_override: Optional[int] = None) -> bool:
        """Return True if *key* alerted within the cooldown; otherwise stamp it and return False."""
        now = time.time()
        cooldown = cooldown_override if cooldown_override is not None else self.cooldown_seconds
        last = self.last_sent.get(key)
        if last and now - last < cooldown:
            return True
        self.last_sent[key] = now
        return False

    def notify_event(
        self,
        key: str,
        summary: str,
        details: Mapping[str, Any],
        *,
        cooldown_seconds: Optional[int] = None,
    ) -> None:
        """Send a one-off alert, rate limited per *key*."""
        if not self.enabled or self._in_quiet_hours():
            return
        if self._should_cooldown(key, cooldown_override=cooldown_seconds):
            return
        self._dispatch(
            {
                "summary": summary,
                "details": dict(details),
                "timestamp": self._utc_timestamp(),
            }
        )

    def record_failure(self, summary: str, details: Mapping[str, Any]) -> None:
        """Count a failure and alert once ``threshold`` consecutive failures are reached."""
        self.consecutive_failures += 1
        if not self.enabled:
            return
        now = time.time()
        if self.consecutive_failures < max(1, self.threshold):
            return
        if self._in_quiet_hours():
            return
        if now - self.last_notified < self.cooldown_seconds:
            return
        self.last_notified = now
        self.last_sent["consecutive_failures"] = now
        self._dispatch(
            {
                "summary": summary,
                "details": dict(details),
                "consecutive_failures": self.consecutive_failures,
                "timestamp": self._utc_timestamp(),
            }
        )

    def _dispatch(self, payload: Mapping[str, Any]) -> None:
        """Send *payload* through every configured channel."""
        # Round-trip through JSON (stringifying unknown types) so one
        # unserializable detail value cannot silently break every channel.
        safe_payload = json.loads(json.dumps(payload, default=str))
        self._notify_webhook(safe_payload)
        self._notify_slack(safe_payload)
        self._notify_email(safe_payload)

    @staticmethod
    def _slack_text(payload: Mapping[str, Any]) -> str:
        """Slack headline: the summary, plus "(xN)" only when a failure count is present."""
        summary = payload.get("summary")
        count = payload.get("consecutive_failures")
        if count is None:
            # notify_event payloads carry no count; avoid rendering "(xNone)".
            return str(summary)
        return f"{summary} (x{count})"

    def _notify_webhook(self, payload: Mapping[str, Any]) -> None:
        if not self.webhook_url:
            return
        try:
            response = requests.post(self.webhook_url, json=payload, timeout=5)
            # Surface HTTP error statuses in the debug log instead of
            # treating them as delivered.
            response.raise_for_status()
        except Exception:
            self.logger.debug("Webhook notification failed", exc_info=True)

    def _notify_slack(self, payload: Mapping[str, Any]) -> None:
        if not self.slack_webhook_url:
            return
        text = self._slack_text(payload)
        try:
            response = requests.post(
                self.slack_webhook_url,
                json={"text": text, "attachments": [{"text": json.dumps(payload, default=str)}]},
                timeout=5,
            )
            response.raise_for_status()
        except Exception:
            self.logger.debug("Slack notification failed", exc_info=True)

    def _notify_email(self, payload: Mapping[str, Any]) -> None:
        if not self.email_recipients or not self.smtp_host:
            return
        try:
            message = EmailMessage()
            message["Subject"] = payload.get("summary", "OCPI alert")
            message["From"] = self.smtp_username or "ocpi-alerts@example.com"
            message["To"] = ",".join(self.email_recipients)
            message.set_content(json.dumps(payload, ensure_ascii=False, indent=2, default=str))
            # NOTE(review): no STARTTLS is attempted, so SMTP credentials are
            # sent in clear text unless the server link is otherwise secured.
            with smtplib.SMTP(self.smtp_host, self.smtp_port, timeout=10) as smtp:
                if self.smtp_username and self.smtp_password:
                    smtp.login(self.smtp_username, self.smtp_password)
                smtp.send_message(message)
        except Exception:
            self.logger.debug("Email notification failed", exc_info=True)


def ocpi_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string ending in ``Z``.

    Shape: ``YYYY-MM-DDTHH:MM:SS[.ffffff]Z``; ``isoformat`` omits the
    fraction when the microseconds are zero.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Build an aware UTC
    # time and drop tzinfo so isoformat() keeps the "...Z" form rather than
    # emitting "+00:00".
    # NOTE(review): OCPI DateTime fields are specified as string(25); the
    # microsecond form is 27 chars — confirm whether peers need lower precision.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    return now.isoformat() + "Z"


def mask_headers(
    req: Request,
    sensitive_headers: Iterable[str] = ("Authorization", "Proxy-Authorization", "Cookie"),
) -> dict[str, str]:
    """Return the request headers as a plain dict with credential values masked.

    Header names are matched case-insensitively against *sensitive_headers*,
    and matching values are replaced by ``"***"``. If a header name occurs
    more than once, the last value wins.
    """
    # HTTP header names are case-insensitive; the original exact-case check
    # let e.g. "authorization" or "Proxy-Authorization" reach the logs.
    hidden = {name.lower() for name in sensitive_headers}
    return {
        name: "***" if name.lower() in hidden else value
        for name, value in req.headers.items()
    }


def log_request(logger: logging.Logger, req: Request) -> None:
    """Emit an INFO "Incoming request" record describing *req*.

    The record carries ``event="request"``, the method, the path, the headers
    (passed through ``mask_headers``) and the JSON body (``None`` when the
    body is absent or unparsable).
    """
    # NOTE(review): the JSON body is logged unmasked — confirm it never
    # carries credentials or personal data.
    request_fields = dict(
        event="request",
        method=req.method,
        path=req.path,
        headers=mask_headers(req),
        body=req.get_json(silent=True),
    )
    logger.info("Incoming request", extra=request_fields)


def log_response(logger: logging.Logger, req: Request, status_code: int, duration: float) -> None:
    """Emit an INFO "Request completed" record for *req*.

    *duration* is multiplied by 1000 and rounded to two decimals for the
    ``duration_ms`` field, so callers are expected to pass seconds.
    """
    elapsed_ms = round(duration * 1000, 2)
    response_fields = dict(
        event="response",
        method=req.method,
        path=req.path,
        status=status_code,
        duration_ms=elapsed_ms,
    )
    logger.info("Request completed", extra=response_fields)


def validate_json_payload(payload: Any, required_fields: Iterable[str]) -> tuple[bool, Optional[str]]:
    """Check that *payload* is a mapping with every required field filled in.

    A field counts as missing when it is absent, ``None``, or a blank string.

    Returns:
        ``(True, None)`` when valid; ``(False, "JSON body is required")`` when
        *payload* is not a mapping; otherwise ``(False, <first missing field>)``.
    """
    if not isinstance(payload, Mapping):
        return False, "JSON body is required"
    # A single field name given as a plain string would otherwise be checked
    # character by character.
    names = (required_fields,) if isinstance(required_fields, str) else required_fields
    # "name" rather than "field" to avoid shadowing dataclasses.field.
    for name in names:
        value = payload.get(name)
        if value is None or (isinstance(value, str) and not value.strip()):
            return False, name
    return True, None


def verify_token_header(req: Request, expected_token: Optional[str]) -> bool:
    """Check an ``Authorization: Token <value>`` header against *expected_token*.

    The scheme name is matched case-insensitively (RFC 7235), and whitespace
    around the token value is ignored. Returns ``False`` when no token is
    expected, the header is missing, another scheme is used, or the values
    differ.
    """
    if not expected_token:
        return False
    scheme, _, credentials = (req.headers.get("Authorization") or "").partition(" ")
    if scheme.lower() != "token":
        return False
    # NOTE(review): OCPI 2.2+ peers may send the token Base64-encoded; this
    # compares the raw header value — confirm which form peers use.
    # compare_digest avoids content-based short-circuiting, so response
    # timing does not reveal how much of the token matched. Bytes are used
    # because compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(
        credentials.strip().encode("utf-8"), expected_token.encode("utf-8")
    )
