from pymilvus import MilvusClient as Client
from pymilvus import DataType
import json
import logging
from typing import Optional
from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    MILVUS_URI,
    MILVUS_DB,
    MILVUS_TOKEN,
    MILVUS_INDEX_TYPE,
    MILVUS_METRIC_TYPE,
    MILVUS_HNSW_M,
    MILVUS_HNSW_EFCONSTRUCTION,
    MILVUS_IVF_FLAT_NLIST,
)
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
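
# A minimal usage sketch for the client defined below (hypothetical collection name
# and values; items follow the dict shape consumed by insert()/upsert()):
#
#   client = MilvusClient()
#   client.insert(
#       "my_docs",
#       [{"id": "doc-1", "vector": [0.1, 0.2, 0.3], "text": "hello", "metadata": {"file_id": "f1"}}],
#   )
#   hits = client.search("my_docs", vectors=[[0.1, 0.2, 0.3]], limit=5)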


class MilvusClient(VectorDBBase):
    def __init__(self):
        self.collection_prefix = "open_webui"
        if MILVUS_TOKEN is None:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
        else:
            self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB, token=MILVUS_TOKEN)

    def _result_to_get_result(self, result) -> GetResult:
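        # Assumed row shape (a sketch based on how items are read below): client.query()
        # returns flat dicts such as {"id": "...", "data": {"text": "..."}, "metadata": {...}},
        # and `result` is a list of such row lists.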
        ids = []
        documents = []
        metadatas = []
        for match in result:
            _ids = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                _documents.append(item.get("data", {}).get("text"))
                _metadatas.append(item.get("metadata"))
            ids.append(_ids)
            documents.append(_documents)
            metadatas.append(_metadatas)
        return GetResult(
            **{
                "ids": ids,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _result_to_search_result(self, result) -> SearchResult:
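        # Assumed hit shape (a sketch): client.search() returns one hit list per query
        # vector, each hit roughly
        #   {"id": "...", "distance": 0.42, "entity": {"data": {"text": "..."}, "metadata": {...}}},
        # hence the nested "entity" lookups below.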
        ids = []
        distances = []
        documents = []
        metadatas = []
        for match in result:
            _ids = []
            _distances = []
            _documents = []
            _metadatas = []
            for item in match:
                _ids.append(item.get("id"))
                # Normalize the Milvus score from [-1, 1] to [0, 1]. This assumes a
                # metric whose scores lie in [-1, 1] (e.g. COSINE); raw L2 distances are
                # outside that range, so this rescaling is only meaningful for similarity metrics.
                # https://milvus.io/docs/de/metric.md
                _dist = (item.get("distance") + 1.0) / 2.0
                _distances.append(_dist)
                _documents.append(item.get("entity", {}).get("data", {}).get("text"))
                _metadatas.append(item.get("entity", {}).get("metadata"))
            ids.append(_ids)
            distances.append(_distances)
            documents.append(_documents)
            metadatas.append(_metadatas)
        return SearchResult(
            **{
                "ids": ids,
                "distances": distances,
                "documents": documents,
                "metadatas": metadatas,
            }
        )

    def _create_collection(self, collection_name: str, dimension: int):
        schema = self.client.create_schema(
            auto_id=False,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.VARCHAR,
            is_primary=True,
            max_length=65535,
        )
        schema.add_field(
            field_name="vector",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension,
            description="vector",
        )
        schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
        schema.add_field(
            field_name="metadata", datatype=DataType.JSON, description="metadata"
        )

        index_params = self.client.prepare_index_params()

        # Use configurations from config.py
        index_type = MILVUS_INDEX_TYPE.upper()
        metric_type = MILVUS_METRIC_TYPE.upper()

        log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")

        index_creation_params = {}
        if index_type == "HNSW":
            index_creation_params = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
            log.info(f"HNSW params: {index_creation_params}")
        elif index_type == "IVF_FLAT":
            index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
            log.info(f"IVF_FLAT params: {index_creation_params}")
        elif index_type in ["FLAT", "AUTOINDEX"]:
            log.info(f"Using {index_type} index with no specific build-time params.")
        else:
            log.warning(
                f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
                f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. "
                f"The value will be passed to Milvus as-is; index creation may fail."
            )
            # Unsupported index types are passed through unchanged. If Milvus rejects
            # the type, correct the MILVUS_INDEX_TYPE environment variable.

        index_params.add_index(
            field_name="vector",
            index_type=index_type,
            metric_type=metric_type,
            params=index_creation_params,
        )
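        # For illustration (values are hypothetical; the real ones come from the
        # MILVUS_* settings): with index_type="HNSW" and metric_type="COSINE", the
        # index above would be built with params like {"M": 16, "efConstruction": 100}.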

        self.client.create_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            schema=schema,
            index_params=index_params,
        )
        log.info(
            f"Successfully created collection '{self.collection_prefix}_{collection_name}' with index type '{index_type}' and metric '{metric_type}'."
        )

    def has_collection(self, collection_name: str) -> bool:
        # Check if the collection exists based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def delete_collection(self, collection_name: str):
        # Delete the collection based on the collection name.
        collection_name = collection_name.replace("-", "_")
        return self.client.drop_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        )

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
        collection_name = collection_name.replace("-", "_")
        # Search-time parameters (e.g. nprobe for IVF_FLAT) are not exposed as
        # configuration here, so Milvus uses its defaults; a sketch of how they
        # could be passed is shown below.
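        # A hedged sketch (the exact search_params structure depends on the pymilvus
        # version, and nprobe=10 is an arbitrary illustrative value):
        #
        #   search_params = {"params": {"nprobe": 10}}
        #   result = self.client.search(..., search_params=search_params)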
        result = self.client.search(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=vectors,
            limit=limit,
            output_fields=["data", "metadata"],
            # search_params=search_params # Potentially add later if needed
        )
        return self._result_to_search_result(result)

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        # Query items matching the metadata filter, paginating through Milvus batches.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None
        filter_string = " && ".join(
            [
                f'metadata["{key}"] == {json.dumps(value)}'
                for key, value in filter.items()
            ]
        )
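        # For example, a hypothetical filter {"file_id": "abc", "page": 3} becomes the
        # Milvus boolean expression: metadata["file_id"] == "abc" && metadata["page"] == 3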
        max_limit = 16383  # The maximum number of records per request
        all_results = []
        if limit is None:
            # No explicit limit: treat this as "fetch everything" and page through the
            # results in batches of max_limit until Milvus returns fewer rows than requested.
            limit = 16384 * 10  # practical upper bound; the actual data size caps the fetch
            log.info(
                f"Limit not specified for query, fetching up to {limit} results in batches."
            )

        # Initialize offset and remaining to handle pagination
        offset = 0
        remaining = limit

        try:
            log.info(
                f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}"
            )
            # Loop until there are no more items to fetch or the desired limit is reached
            while remaining > 0:
                current_fetch = min(max_limit, remaining)
                log.debug(
                    f"Querying with offset: {offset}, current_fetch: {current_fetch}"
                )

                results = self.client.query(
                    collection_name=f"{self.collection_prefix}_{collection_name}",
                    filter=filter_string,
                    output_fields=[
                        "id",
                        "data",
                        "metadata",
                    ],  # Explicitly list needed fields. Vector not usually needed in query.
                    limit=current_fetch,
                    offset=offset,
                )

                if not results:
                    log.debug("No more results from query.")
                    break

                all_results.extend(results)
                results_count = len(results)
                log.debug(f"Fetched {results_count} results in this batch.")

                remaining -= results_count

                offset += results_count

                # Break the loop if the results returned are less than the requested fetch count (means end of data)
                if results_count < current_fetch:
                    log.debug(
                        "Fetched less than requested, assuming end of results for this query."
                    )
                    break

            log.info(f"Total results from query: {len(all_results)}")
            return self._result_to_get_result([all_results])
        except Exception as e:
            log.exception(
                f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}"
            )
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection. This can be very resource-intensive for large collections.
        collection_name = collection_name.replace("-", "_")
        log.warning(
            f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
        )
        # Using query with a trivial filter to get all items.
        # This will use the paginated query logic.
        return self.query(collection_name=collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: list[VectorItem]):
        # Insert the items into the collection, if the collection does not exist, it will be created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.insert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
        collection_name = collection_name.replace("-", "_")
        if not self.client.has_collection(
            collection_name=f"{self.collection_prefix}_{collection_name}"
        ):
            log.info(
                f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
            )
            if not items:
                log.error(
                    f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
                )
                raise ValueError(
                    "Cannot create Milvus collection for upsert without items to determine vector dimension."
                )
            self._create_collection(
                collection_name=collection_name, dimension=len(items[0]["vector"])
            )

        log.info(
            f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
        )
        return self.client.upsert(
            collection_name=f"{self.collection_prefix}_{collection_name}",
            data=[
                {
                    "id": item["id"],
                    "vector": item["vector"],
                    "data": {"text": item["text"]},
                    "metadata": item["metadata"],
                }
                for item in items
            ],
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        # Delete the items from the collection based on the ids or filter.
        collection_name = collection_name.replace("-", "_")
        if not self.has_collection(collection_name):
            log.warning(
                f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
            )
            return None

        if ids:
            log.info(
                f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                ids=ids,
            )
        elif filter:
            filter_string = " && ".join(
                [
                    f'metadata["{key}"] == {json.dumps(value)}'
                    for key, value in filter.items()
                ]
            )
            log.info(
                f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
            )
            return self.client.delete(
                collection_name=f"{self.collection_prefix}_{collection_name}",
                filter=filter_string,
            )
        else:
            log.warning(
                f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
            )
            return None

    def reset(self):
        # Resets the database. This will delete all collections and item entries that match the prefix.
        log.warning(
            f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
        )
        collection_names = self.client.list_collections()
        deleted_collections = []
        for collection_name_full in collection_names:
            if collection_name_full.startswith(self.collection_prefix):
                try:
                    self.client.drop_collection(collection_name=collection_name_full)
                    deleted_collections.append(collection_name_full)
                    log.info(f"Deleted collection: {collection_name_full}")
                except Exception as e:
                    log.error(f"Error deleting collection {collection_name_full}: {e}")
        log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")