import logging
from typing import Optional, Tuple
from urllib.parse import urlparse

import grpc
from open_webui.config import (
    QDRANT_API_KEY,
    QDRANT_GRPC_PORT,
    QDRANT_ON_DISK,
    QDRANT_PREFER_GRPC,
    QDRANT_URI,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
    GetResult,
    SearchResult,
    VectorDBBase,
    VectorItem,
)
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models

# Sentinel "limit" passed to Qdrant when the caller wants every matching point;
# the queries in this module always hand Qdrant a concrete integer limit.
NO_LIMIT = 999_999_999

# Module logger, using the log level configured for the RAG subsystem.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class QdrantClient(VectorDBBase):
    """
    Qdrant-backed vector store that maps many logical collections onto a small,
    fixed set of shared multi-tenant collections.

    Every logical collection name used by callers is translated (see
    ``_get_collection_and_tenant_id``) into one physical collection plus a
    ``tenant_id`` payload value. All reads, writes and deletes filter on that
    ``tenant_id`` so logical collections stay isolated from each other.

    When ``QDRANT_URI`` is not configured, ``self.client`` is ``None`` and every
    public method returns ``None`` (``False`` for ``has_collection``) without
    contacting Qdrant.
    """

    def __init__(self):
        self.collection_prefix = "open-webui"
        self.QDRANT_URI = QDRANT_URI
        self.QDRANT_API_KEY = QDRANT_API_KEY
        self.QDRANT_ON_DISK = QDRANT_ON_DISK
        self.PREFER_GRPC = QDRANT_PREFER_GRPC
        self.GRPC_PORT = QDRANT_GRPC_PORT

        # Physical multi-tenant collections. Defined before the early return
        # below so these attributes exist even when Qdrant is not configured.
        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"

        if not self.QDRANT_URI:
            self.client = None
            return

        # Unified handling for either scheme: pull host/port out of the URI.
        parsed = urlparse(self.QDRANT_URI)
        host = parsed.hostname or self.QDRANT_URI
        http_port = parsed.port or 6333  # default REST port

        if self.PREFER_GRPC:
            # Connect by host with explicit REST and gRPC ports.
            self.client = Qclient(
                host=host,
                port=http_port,
                grpc_port=self.GRPC_PORT,
                prefer_grpc=self.PREFER_GRPC,
                api_key=self.QDRANT_API_KEY,
            )
        else:
            self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _tenant_condition(tenant_id: str):
        """Build the payload condition that restricts a request to one tenant."""
        return models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )

    @staticmethod
    def _fit_vector(vector, size: int):
        """
        Truncate or zero-pad ``vector`` so it has exactly ``size`` elements.

        A vector that already has ``size`` elements is returned unchanged.
        """
        length = len(vector)
        if length > size:
            return vector[:size]
        if length < size:
            return list(vector) + [0] * (size - length)
        return vector

    @staticmethod
    def _grpc_status(exception):
        """
        Return the ``grpc.StatusCode`` carried by a gRPC exception, or ``None``.

        A plain ``grpc.RpcError`` is not required to expose ``code()``, so the
        attribute is looked up defensively instead of being called blindly.
        """
        code_fn = getattr(exception, "code", None)
        if not callable(code_fn):
            return None
        return code_fn()

    def _result_to_get_result(self, points) -> GetResult:
        """
        Convert Qdrant points into a ``GetResult`` containing one result row.

        Each point's payload is expected to hold ``text`` and ``metadata`` keys,
        as written by ``_create_points``; a missing key raises ``KeyError``.
        """
        ids = []
        documents = []
        metadatas = []

        for point in points:
            payload = point.payload
            ids.append(point.id)
            documents.append(payload["text"])
            metadatas.append(payload["metadata"])

        # GetResult fields are lists of rows; this produces a single row.
        return GetResult(
            **{
                "ids": [ids],
                "documents": [documents],
                "metadatas": [metadatas],
            }
        )

    def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
        """
        Map a logical collection name to its physical collection and tenant ID.

        The tenant ID is always the logical name itself; only the physical
        collection depends on the name's prefix/shape.

        Returns:
            tuple: (physical_collection_name, tenant_id)
        """
        tenant_id = collection_name

        # User memory collections
        if collection_name.startswith("user-memory-"):
            return self.MEMORY_COLLECTION, tenant_id

        # File collections
        if collection_name.startswith("file-"):
            return self.FILE_COLLECTION, tenant_id

        # Web search collections
        if collection_name.startswith("web-search-"):
            return self.WEB_SEARCH_COLLECTION, tenant_id

        # Hash-based collections (YouTube and web URLs): 63 lowercase hex chars
        if len(collection_name) == 63 and all(
            c in "0123456789abcdef" for c in collection_name
        ):
            return self.HASH_BASED_COLLECTION, tenant_id

        # Everything else is treated as a knowledge collection.
        return self.KNOWLEDGE_COLLECTION, tenant_id

    def _extract_error_message(self, exception):
        """
        Extract a status code and error message from an HTTP or gRPC exception.

        Returns:
            tuple: (status_code, error_message). ``status_code`` is the HTTP
            status for ``UnexpectedResponse``, the numeric gRPC status for
            ``grpc.RpcError`` (``None`` if unavailable), and ``None`` for any
            other exception type.
        """
        # HTTP exception raised by the REST client
        if isinstance(exception, UnexpectedResponse):
            try:
                error_data = exception.structured()
                error_msg = error_data.get("status", {}).get("error", "")
                return exception.status_code, error_msg
            except Exception as inner_e:
                log.error(f"Failed to parse HTTP error: {inner_e}")
                return exception.status_code, str(exception)

        # gRPC exception
        if isinstance(exception, grpc.RpcError):
            code = self._grpc_status(exception)
            # StatusCode values are (int, str) tuples; keep the numeric part.
            status_code = code.value[0] if code is not None else None

            error_msg = str(exception)
            if "details =" in error_msg:
                # The "details = ..." line carries the server's actual message.
                try:
                    details_line = [
                        line.strip()
                        for line in error_msg.split("\n")
                        if "details =" in line
                    ][0]
                    error_msg = details_line.split("details =")[1].strip(' "')
                except (IndexError, AttributeError):
                    # Fall back to the full message if parsing fails
                    pass

            return status_code, error_msg

        # Any other type of exception
        return None, str(exception)

    def _is_collection_not_found_error(self, exception) -> bool:
        """
        Return True if ``exception`` means the collection does not exist.

        Recognises HTTP 404 responses mentioning a missing collection and gRPC
        ``NOT_FOUND`` statuses.
        """
        status_code, error_msg = self._extract_error_message(exception)

        # HTTP error (404)
        if (
            status_code == 404
            and "Collection" in error_msg
            and "doesn't exist" in error_msg
        ):
            return True

        # gRPC error (NOT_FOUND status)
        return (
            isinstance(exception, grpc.RpcError)
            and self._grpc_status(exception) == grpc.StatusCode.NOT_FOUND
        )

    def _is_dimension_mismatch_error(self, exception) -> bool:
        """
        Return True if ``exception`` reports a vector-dimension mismatch.

        Matches on message text, which is shared by HTTP and gRPC errors.
        """
        _, error_msg = self._extract_error_message(exception)

        return (
            "Vector dimension error" in error_msg
            or "dimensions mismatch" in error_msg
            or "invalid vector size" in error_msg
        )

    def _create_multi_tenant_collection_if_not_exists(
        self, mt_collection_name: str, dimension: int = 384
    ):
        """
        Create a multi-tenant collection plus its ``tenant_id`` index, unless it exists.

        The default dimension of 384 corresponds to
        'sentence-transformers/all-MiniLM-L6-v2'. When a collection is created
        on demand by insert/upsert, the dimension of the incoming vectors is
        passed instead.

        Raises:
            UnexpectedResponse | grpc.RpcError: for any Qdrant error other than
            "collection already exists".
        """
        try:
            # Create directly; an "already exists" error is handled below.
            self.client.create_collection(
                collection_name=mt_collection_name,
                vectors_config=models.VectorParams(
                    size=dimension,
                    distance=models.Distance.COSINE,
                    on_disk=self.QDRANT_ON_DISK,
                ),
                hnsw_config=models.HnswConfigDiff(
                    payload_m=16,  # build per-tenant HNSW graphs
                    m=0,  # disable the global graph; searches are always tenant-filtered
                    on_disk=self.QDRANT_ON_DISK,
                ),
            )

            # Tenant ID payload index, flagged as the tenant key.
            self.client.create_payload_index(
                collection_name=mt_collection_name,
                field_name="tenant_id",
                field_schema=models.KeywordIndexParams(
                    type=models.KeywordIndexType.KEYWORD,
                    is_tenant=True,
                    on_disk=self.QDRANT_ON_DISK,
                ),
                wait=True,
            )

            log.info(
                f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
            )
        except (UnexpectedResponse, grpc.RpcError) as e:
            status_code, error_msg = self._extract_error_message(e)

            # HTTP status code 409 or gRPC ALREADY_EXISTS
            already_exists = (
                isinstance(e, UnexpectedResponse) and status_code == 409
            ) or (
                isinstance(e, grpc.RpcError)
                and self._grpc_status(e) == grpc.StatusCode.ALREADY_EXISTS
            )
            if already_exists and "already exists" in error_msg:
                log.debug(f"Collection {mt_collection_name} already exists")
                return
            # Not an "already exists" error: propagate the original exception.
            raise

    def _create_points(self, items: list[VectorItem], tenant_id: str):
        """
        Build ``PointStruct``s from vector items, tagging each with ``tenant_id``.
        """
        return [
            PointStruct(
                id=item["id"],
                vector=item["vector"],
                payload={
                    "text": item["text"],
                    "metadata": item["metadata"],
                    "tenant_id": tenant_id,
                },
            )
            for item in items
        ]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def has_collection(self, collection_name: str) -> bool:
        """
        Return True if the logical collection has at least one point.

        Any error (including a missing physical collection) yields False.
        """
        if not self.client:
            return False

        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        try:
            # Query directly - most of the time the collection exists.
            response = self.client.query_points(
                collection_name=mt_collection,
                query_filter=models.Filter(must=[self._tenant_condition(tenant_id)]),
                limit=1,
            )
            return len(response.points) > 0
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(f"Collection {mt_collection} doesn't exist")
                return False
            # For other API errors, log and report "absent".
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error: {error_msg}")
            return False
        except Exception as e:
            log.debug(f"Error checking collection {mt_collection}: {e}")
            return False

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """
        Delete points of one logical collection, by ID or by metadata filter.

        ``ids`` match against ``metadata.id`` (any of them); otherwise every
        ``filter`` key/value must match ``metadata.<key>``. With neither, all of
        the tenant's points are deleted. Returns the Qdrant update result, or
        None if the physical collection does not exist.
        """
        if not self.client:
            return None

        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        must_conditions = [self._tenant_condition(tenant_id)]
        should_conditions = []

        if ids:
            # Delete points whose metadata.id matches any of the given IDs.
            should_conditions = [
                models.FieldCondition(
                    key="metadata.id",
                    match=models.MatchValue(value=id_value),
                )
                for id_value in ids
            ]
        elif filter:
            must_conditions.extend(
                models.FieldCondition(
                    key=f"metadata.{key}",
                    match=models.MatchValue(value=value),
                )
                for key, value in filter.items()
            )

        try:
            return self.client.delete(
                collection_name=mt_collection,
                points_selector=models.FilterSelector(
                    filter=models.Filter(
                        must=must_conditions,
                        # Omit the clause entirely when there are no ID conditions.
                        should=should_conditions or None,
                    )
                ),
            )
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, nothing to delete"
                )
                return None
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error: {error_msg}")
            raise

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """
        Return the ``limit`` nearest points of one logical collection.

        Only ``vectors[0]`` is used as the query. It is truncated or zero-padded
        to the collection's vector size. Scores are mapped from Qdrant's cosine
        range [-1, 1] to [0, 1]. Returns None when ``vectors`` is empty, the
        collection does not exist, or a non-Qdrant error occurs.
        """
        if not self.client or not vectors:
            return None

        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        try:
            # Match the query vector to the collection's configured dimension.
            collection_dim = self.client.get_collection(
                mt_collection
            ).config.params.vectors.size
            query_vector = self._fit_vector(vectors[0], collection_dim)

            # Filter by tenant inside the search itself so Qdrant can use the
            # per-tenant index instead of pre-fetching every tenant point.
            query_response = self.client.query_points(
                collection_name=mt_collection,
                query=query_vector,
                query_filter=models.Filter(must=[self._tenant_condition(tenant_id)]),
                limit=limit,
            )

            points = query_response.points
            get_result = self._result_to_get_result(points)
            return SearchResult(
                ids=get_result.ids,
                documents=get_result.documents,
                metadatas=get_result.metadatas,
                # qdrant distance is [-1, 1], normalize to [0, 1]
                distances=[[(point.score + 1.0) / 2.0 for point in points]],
            )
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, search returns None"
                )
                return None
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error during search: {error_msg}")
            raise
        except Exception as e:
            # Non-Qdrant exceptions are logged and reported as "no result".
            log.exception(f"Error searching collection '{collection_name}': {e}")
            return None

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
        """
        Return points of one logical collection whose metadata matches ``filter``.

        Every ``filter`` key/value must match ``metadata.<key>``; a ``None`` or
        empty filter returns all of the tenant's points (up to ``limit``, or
        ``NO_LIMIT`` when omitted). Returns None when the collection does not
        exist or a non-Qdrant error occurs.
        """
        if not self.client:
            return None

        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        if limit is None:
            limit = NO_LIMIT

        # Tenant condition first, then one condition per metadata filter entry.
        conditions = [self._tenant_condition(tenant_id)]
        conditions.extend(
            models.FieldCondition(
                key=f"metadata.{key}", match=models.MatchValue(value=value)
            )
            for key, value in (filter or {}).items()
        )

        try:
            points = self.client.query_points(
                collection_name=mt_collection,
                query_filter=models.Filter(must=conditions),
                limit=limit,
            )
            return self._result_to_get_result(points.points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, query returns None"
                )
                return None
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error during query: {error_msg}")
            raise
        except Exception as e:
            # Non-Qdrant exceptions are logged and reported as "no result".
            log.exception(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """
        Return all points of one logical collection (up to ``NO_LIMIT``).

        Returns None when the collection does not exist or a non-Qdrant error
        occurs.
        """
        if not self.client:
            return None

        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        try:
            points = self.client.query_points(
                collection_name=mt_collection,
                query_filter=models.Filter(must=[self._tenant_condition(tenant_id)]),
                limit=NO_LIMIT,
            )
            return self._result_to_get_result(points.points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
                return None
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error during get: {error_msg}")
            raise
        except Exception as e:
            # Non-Qdrant exceptions are logged and reported as "no result".
            log.exception(f"Error getting collection '{collection_name}': {e}")
            return None

    def _handle_operation_with_error_retry(
        self, operation_name, mt_collection, points, dimension
    ):
        """
        Run an insert/upsert, recovering once from common Qdrant errors.

        - Missing collection: it is created with ``dimension`` and the operation
          is retried.
        - Dimension mismatch: every point is truncated/zero-padded to the
          collection's vector size and the operation is retried.

        Args:
            operation_name: 'insert' (uses ``upload_points``) or anything else
                (uses ``upsert``)
            mt_collection: The multi-tenant collection name
            points: The ``PointStruct``s to write
            dimension: The dimension of the incoming vectors

        Returns:
            The upsert result, or None for insert.

        Raises:
            UnexpectedResponse | grpc.RpcError: for unhandled Qdrant errors, or
            if the single retry also fails.
        """

        def run(batch):
            # Single dispatch point for the write so retries stay consistent.
            if operation_name == "insert":
                self.client.upload_points(mt_collection, batch)
                return None
            return self.client.upsert(mt_collection, batch)

        try:
            return run(points)
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.info(
                    f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
                )
                self._create_multi_tenant_collection_if_not_exists(
                    mt_collection_name=mt_collection, dimension=dimension
                )
                # Freshly created with the incoming dimension: retry as-is.
                return run(points)

            if self._is_dimension_mismatch_error(e):
                # The collection exists, so read its configured vector size.
                existing_size = self.client.get_collection(
                    mt_collection
                ).config.params.vectors.size

                log.info(
                    f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
                )
                if existing_size < dimension:
                    log.info(
                        f"Truncating vectors from {dimension} to {existing_size} dimensions"
                    )
                elif existing_size > dimension:
                    log.info(
                        f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
                    )

                # Fit each point individually so mixed-length batches are handled.
                adjusted = [
                    PointStruct(
                        id=point.id,
                        vector=self._fit_vector(point.vector, existing_size),
                        payload=point.payload,
                    )
                    for point in points
                ]
                return run(adjusted)

            # Not a known error we can handle, log and re-raise
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unhandled Qdrant error: {error_msg}")
            raise

    def _prepare_write(self, collection_name: str, items: list[VectorItem]):
        """
        Resolve the physical collection and build tenant-tagged points.

        ``items`` must be non-empty; the first item's vector length is used as
        the write dimension.

        Returns:
            tuple: (mt_collection, points, dimension)
        """
        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
        dimension = len(items[0]["vector"])
        points = self._create_points(items, tenant_id)
        return mt_collection, points, dimension

    def insert(self, collection_name: str, items: list[VectorItem]):
        """
        Insert items into a logical collection, creating the physical collection
        on demand. Returns None.
        """
        if not self.client or not items:
            return None

        mt_collection, points, dimension = self._prepare_write(collection_name, items)
        return self._handle_operation_with_error_retry(
            "insert", mt_collection, points, dimension
        )

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """
        Upsert items into a logical collection, creating the physical collection
        on demand. Returns the Qdrant update result (None if there is nothing to
        write or no client).
        """
        if not self.client or not items:
            return None

        mt_collection, points, dimension = self._prepare_write(collection_name, items)
        return self._handle_operation_with_error_retry(
            "upsert", mt_collection, points, dimension
        )

    def reset(self):
        """
        Delete every Qdrant collection whose name starts with the collection prefix.
        """
        if not self.client:
            return None

        for collection in self.client.get_collections().collections:
            if collection.name.startswith(self.collection_prefix):
                self.client.delete_collection(collection_name=collection.name)

    def delete_collection(self, collection_name: str):
        """
        Delete all points of a logical collection.

        If the shared physical collection is left with no points, it is dropped
        as well. Returns the Qdrant update result, or None when the physical
        collection does not exist.
        """
        if not self.client:
            return None

        mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

        try:
            update_result = self.client.delete(
                collection_name=mt_collection,
                points_selector=models.FilterSelector(
                    filter=models.Filter(must=[self._tenant_condition(tenant_id)])
                ),
            )
            remaining = self.client.get_collection(mt_collection).points_count
        except (UnexpectedResponse, grpc.RpcError) as e:
            if self._is_collection_not_found_error(e):
                log.debug(
                    f"Collection {mt_collection} doesn't exist, nothing to delete"
                )
                return None
            raise

        # Drop the shared collection once no tenant has points left in it.
        if remaining == 0:
            self.client.delete_collection(mt_collection)

        return update_result