"""Repositorios Mongo y almacenamiento local usados por el flujo batch."""

from __future__ import annotations

from pathlib import Path
from typing import Any

from bson import ObjectId
from pymongo import MongoClient
from pymongo.errors import OperationFailure

from app.config import config


def _serialize_document(doc: dict[str, Any] | None) -> dict[str, Any] | None:
    if not doc:
        return None
    serialized = dict(doc)
    if "_id" in serialized:
        serialized["_id"] = str(serialized["_id"])
    return serialized


def _safe_create_index(collection, keys: list[tuple[str, int]], **kwargs: Any) -> None:
    """Create an index on *collection*, tolerating conflicts with an existing one.

    Keeps application start-up from failing when an index over the same keys
    already exists with different options (e.g. ``unique`` in environments
    with partially applied migrations). Server error codes 85 and 86 are
    swallowed; any other ``OperationFailure`` propagates unchanged.

    Args:
        collection: pymongo collection to index.
        keys: ``(field, direction)`` pairs forwarded to ``create_index``.
        **kwargs: Extra index options forwarded to ``create_index``.
    """
    # 85 / 86 are MongoDB's IndexOptionsConflict / IndexKeySpecsConflict codes.
    conflict_codes = frozenset({85, 86})
    try:
        collection.create_index(keys, **kwargs)
    except OperationFailure as error:
        if getattr(error, "code", None) not in conflict_codes:
            raise


class MongoBatchRepository:
    """Persistencia principal del lote en MongoDB."""

    def __init__(self) -> None:
        client = MongoClient(config.MONGO_URI)
        db = client[config.DB_NAME]
        self.collection = db["processing_batches"]
        _safe_create_index(self.collection, [("created_at", -1)])
        _safe_create_index(self.collection, [("usuario", 1), ("updated_at", -1)])

    def create_batch(self, payload: dict) -> str:
        inserted = self.collection.insert_one(payload)
        return str(inserted.inserted_id)

    def update_batch(self, batch_id: str, payload: dict) -> None:
        if not ObjectId.is_valid(batch_id):
            return
        self.collection.update_one({"_id": ObjectId(batch_id)}, {"$set": payload})

    def get_batch(self, batch_id: str) -> dict | None:
        if not ObjectId.is_valid(batch_id):
            return None
        return _serialize_document(self.collection.find_one({"_id": ObjectId(batch_id)}))

    def list_user_batches(self, username: str, *, limit: int = 20) -> list[dict]:
        normalized_limit = max(1, min(int(limit or 20), 100))
        docs = list(
            self.collection.find({"usuario": username}).sort([("updated_at", -1)]).limit(normalized_limit)
        )
        return [_serialize_document(doc) for doc in docs if doc]


class MongoBatchFileRepository:
    """Persistencia por archivo del lote en MongoDB."""

    def __init__(self) -> None:
        client = MongoClient(config.MONGO_URI)
        db = client[config.DB_NAME]
        self.collection = db["processing_batch_files"]
        _safe_create_index(self.collection, [("batch_id", 1), ("status", 1)])
        _safe_create_index(self.collection, [("batch_id", 1), ("case_key", 1)])
        _safe_create_index(self.collection, [("batch_id", 1), ("created_at", 1)])
        _safe_create_index(self.collection, [("batch_id", 1), ("clinical_status", 1)])
        _safe_create_index(self.collection, [("analysis_document_id", 1)])

    def create_file(self, payload: dict) -> str:
        inserted = self.collection.insert_one(payload)
        return str(inserted.inserted_id)

    def update_file(self, file_id: str, payload: dict) -> None:
        if not ObjectId.is_valid(file_id):
            return
        self.collection.update_one({"_id": ObjectId(file_id)}, {"$set": payload})

    def get_file(self, file_id: str) -> dict | None:
        if not ObjectId.is_valid(file_id):
            return None
        return _serialize_document(self.collection.find_one({"_id": ObjectId(file_id)}))

    def list_files(self, batch_id: str, *, status: str | None = None) -> list[dict]:
        query: dict[str, Any] = {"batch_id": batch_id}
        if status:
            query["status"] = status
        docs = list(
            self.collection.find(query).sort([("created_at", 1), ("original_name", 1)])
        )
        return [_serialize_document(doc) for doc in docs if doc]


class MongoBatchCaseRepository:
    """Persistencia de casos consolidados a partir de la asociación automática."""

    def __init__(self) -> None:
        client = MongoClient(config.MONGO_URI)
        db = client[config.DB_NAME]
        self.collection = db["processing_batch_cases"]
        _safe_create_index(self.collection, [("batch_id", 1), ("case_key", 1)])
        _safe_create_index(self.collection, [("batch_id", 1), ("patient_name", 1)])
        _safe_create_index(self.collection, [("usuario", 1), ("case_key", 1)])

    def replace_cases(self, batch_id: str, cases: list[dict]) -> None:
        self.collection.delete_many({"batch_id": batch_id})
        if not cases:
            return
        self.collection.insert_many(cases)

    def list_cases(self, batch_id: str) -> list[dict]:
        docs = list(
            self.collection.find({"batch_id": batch_id}).sort(
                [("patient_name", 1), ("case_number", 1), ("case_key", 1)]
            )
        )
        return [_serialize_document(doc) for doc in docs if doc]

    def get_case(self, batch_id: str, case_key: str) -> dict | None:
        return _serialize_document(
            self.collection.find_one({"batch_id": batch_id, "case_key": case_key})
        )

    def get_user_case(self, username: str, case_key: str) -> dict | None:
        return _serialize_document(
            self.collection.find_one(
                {"usuario": username, "case_key": case_key},
                sort=[("updated_at", -1)],
            )
        )

    def update_case(self, batch_id: str, case_key: str, payload: dict) -> None:
        self.collection.update_one(
            {"batch_id": batch_id, "case_key": case_key},
            {"$set": payload},
        )


class LocalBatchArchiveStore:
    """Store the original ZIP locally so the worker can process it later.

    Archives are written to ``<base_dir>/_batch_uploads/<batch_id>/<name>``,
    where ``<name>`` is the final component of the supplied filename.
    """

    # Used when the supplied filename has no usable final path component.
    _FALLBACK_ARCHIVE_NAME = "archive.zip"

    def __init__(self, base_dir: Path) -> None:
        self.base_dir = Path(base_dir) / "_batch_uploads"
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def save_archive(self, batch_id: str, filename: str, data: bytes) -> str:
        """Write *data* into the batch directory and return the file path.

        *filename* typically comes from an upload and is therefore untrusted:
        only its final component is kept, so names such as ``../../x.zip`` or
        ``/etc/x.zip`` cannot escape the batch directory.

        Args:
            batch_id: Batch identifier; must be a single path component.
            filename: Client-supplied archive name.
            data: Raw archive bytes.

        Returns:
            The path of the written file, as a string.

        Raises:
            ValueError: If *batch_id* is empty, ``.``/``..`` or contains a path
                separator.
        """
        if batch_id in {"", ".", ".."} or "/" in batch_id or "\\" in batch_id:
            raise ValueError(f"invalid batch_id for archive storage: {batch_id!r}")
        batch_dir = self.base_dir / batch_id
        batch_dir.mkdir(parents=True, exist_ok=True)
        archive_path = batch_dir / self._safe_archive_name(filename)
        archive_path.write_bytes(data)
        return str(archive_path)

    @classmethod
    def _safe_archive_name(cls, filename: str) -> str:
        """Reduce *filename* to its last path component (``/`` and ``\\`` both split)."""
        # Normalise Windows separators so the result is the same on every OS.
        name = Path(filename.replace("\\", "/")).name
        # Path("..").name is "..", and "" / "." have an empty name: neither is a file.
        if name in {"", ".", ".."}:
            return cls._FALLBACK_ARCHIVE_NAME
        return name
