from typing import Optional, List, Dict, Any, Union
import logging
from pinecone import Pinecone, ServerlessSpec

from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    SearchResult,
    GetResult,
)
from open_webui.config import (
    PINECONE_API_KEY,
    PINECONE_ENVIRONMENT,
    PINECONE_INDEX_NAME,
    PINECONE_DIMENSION,
    PINECONE_METRIC,
    PINECONE_CLOUD,
)
from open_webui.env import SRC_LOG_LEVELS

# Value substituted for ``top_k`` when a caller passes no limit (None or <= 0),
# and used directly when fetching a whole collection.
NO_LIMIT = 10000  # Reasonable limit to avoid overwhelming the system
# Number of vectors (or IDs) sent per upsert / delete-by-ID request.
BATCH_SIZE = 100  # Recommended batch size for Pinecone operations

# Module-level logger; its level is taken from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class PineconeClient(VectorDBBase):
    def __init__(self):
        """Create a client from the ``PINECONE_*`` settings and prepare its index.

        The configuration is validated before anything else, so an incomplete
        setup fails with ``ValueError`` prior to any Pinecone call. After the
        settings are copied onto the instance, the Pinecone client is built and
        ``_initialize_index`` is invoked to create or connect to the index.
        """
        # Fail fast when a required setting is absent.
        self._validate_config()

        # Prefix prepended to every collection name stored in vector metadata.
        self.collection_prefix = "open-webui"

        # Index-related settings.
        self.index_name = PINECONE_INDEX_NAME
        self.dimension = PINECONE_DIMENSION
        self.metric = PINECONE_METRIC

        # Deployment-related settings.
        self.cloud = PINECONE_CLOUD
        self.environment = PINECONE_ENVIRONMENT

        # Credentials and the client object built from them.
        self.api_key = PINECONE_API_KEY
        self.client = Pinecone(api_key=self.api_key)

        # Make sure the index exists and bind ``self.index`` to it.
        self._initialize_index()

    def _validate_config(self) -> None:
        """Check that every required Pinecone setting has a truthy value.

        Raises:
            ValueError: If one or more settings are empty/unset; the message
                lists the names of all missing settings.
        """
        # (setting name, current value) pairs, in the order they are reported.
        required_settings = (
            ("PINECONE_API_KEY", PINECONE_API_KEY),
            ("PINECONE_ENVIRONMENT", PINECONE_ENVIRONMENT),
            ("PINECONE_INDEX_NAME", PINECONE_INDEX_NAME),
            ("PINECONE_DIMENSION", PINECONE_DIMENSION),
            ("PINECONE_CLOUD", PINECONE_CLOUD),
        )
        missing_vars = [name for name, value in required_settings if not value]

        if not missing_vars:
            return

        raise ValueError(
            f"Required configuration missing: {', '.join(missing_vars)}"
        )

    def _initialize_index(self) -> None:
        """Ensure the configured index exists and bind ``self.index`` to it.

        The list of existing index names is fetched from the client. If
        ``self.index_name`` is not among them, a serverless index is created
        using ``self.dimension``, ``self.metric``, ``self.cloud`` and
        ``self.environment`` (passed as the region). In both cases a handle to
        the index is then stored on ``self.index``.

        Raises:
            RuntimeError: If listing, creating, or connecting to the index
                fails. The underlying exception is attached as ``__cause__``
                so the original traceback is preserved for debugging.
        """
        try:
            # Check if index exists
            existing_index_names = self.client.list_indexes().names()
            if self.index_name not in existing_index_names:
                log.info(f"Creating Pinecone index '{self.index_name}'...")
                self.client.create_index(
                    name=self.index_name,
                    dimension=self.dimension,
                    metric=self.metric,
                    spec=ServerlessSpec(cloud=self.cloud, region=self.environment),
                )
                log.info(f"Successfully created Pinecone index '{self.index_name}'")
            else:
                log.info(f"Using existing Pinecone index '{self.index_name}'")

            # Connect to the index
            self.index = self.client.Index(self.index_name)

        except Exception as e:
            log.error(f"Failed to initialize Pinecone index: {e}")
            # Chain explicitly so the root cause is reported as the direct
            # cause rather than as incidental context.
            raise RuntimeError(f"Failed to initialize Pinecone index: {e}") from e

    def _create_points(
        self, items: List[VectorItem], collection_name_with_prefix: str
    ) -> List[Dict[str, Any]]:
        """Convert VectorItem objects to Pinecone point format."""
        points = []
        for item in items:
            # Start with any existing metadata or an empty dict
            metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}

            # Add text to metadata if available
            if "text" in item:
                metadata["text"] = item["text"]

            # Always add collection_name to metadata for filtering
            metadata["collection_name"] = collection_name_with_prefix

            point = {
                "id": item["id"],
                "values": item["vector"],
                "metadata": metadata,
            }
            points.append(point)
        return points

    def _get_collection_name_with_prefix(self, collection_name: str) -> str:
        """Get the collection name with prefix."""
        return f"{self.collection_prefix}_{collection_name}"

    def _normalize_distance(self, score: float) -> float:
        """Normalize distance score based on the metric used."""
        if self.metric.lower() == "cosine":
            # Cosine similarity ranges from -1 to 1, normalize to 0 to 1
            return (score + 1.0) / 2.0
        elif self.metric.lower() in ["euclidean", "dotproduct"]:
            # These are already suitable for ranking (smaller is better for Euclidean)
            return score
        else:
            # For other metrics, use as is
            return score

    def _result_to_get_result(self, matches: list) -> GetResult:
        """Pack query matches into a ``GetResult``.

        For each match, the ID, the metadata (an empty dict when the match has
        no ``metadata`` entry) and the metadata's ``text`` value (empty string
        when absent) are collected. Each of the three lists is wrapped in an
        outer single-element list before being handed to ``GetResult``.
        """
        # Walk the matches exactly once, keeping metadata and ID together.
        records = [(entry.get("metadata", {}), entry["id"]) for entry in matches]

        id_list = [vector_id for _, vector_id in records]
        text_list = [meta.get("text", "") for meta, _ in records]
        metadata_list = [meta for meta, _ in records]

        return GetResult(
            ids=[id_list],
            documents=[text_list],
            metadatas=[metadata_list],
        )

    def has_collection(self, collection_name: str) -> bool:
        """Report whether at least one vector is tagged with this collection.

        A single-result query, filtered on the prefixed collection name, is
        issued with an all-zero placeholder vector. Returns ``True`` when the
        response contains any match. Any exception is logged with its
        traceback and reported as ``False``.
        """
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        collection_filter = {"collection_name": collection_name_with_prefix}

        try:
            # The vector content is irrelevant here; only the filter matters.
            placeholder_vector = [0.0] * self.dimension
            result = self.index.query(
                vector=placeholder_vector,
                top_k=1,
                filter=collection_filter,
                include_metadata=False,
            )
            match_count = len(result.matches)
            return match_count > 0
        except Exception as e:
            log.exception(
                f"Error checking collection '{collection_name_with_prefix}': {e}"
            )
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Remove every vector whose metadata names this collection.

        The delete is issued as a metadata-filtered request on the prefixed
        collection name. Failures are logged as warnings and re-raised.
        """
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        collection_filter = {"collection_name": collection_name_with_prefix}

        try:
            self.index.delete(filter=collection_filter)
            log.info(
                f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
            )
        except Exception as e:
            log.warning(
                f"Failed to delete collection '{collection_name_with_prefix}': {e}"
            )
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Write ``items`` into a collection in batches of ``BATCH_SIZE``.

        Items are converted to point dicts tagged with the prefixed collection
        name and sent through the index's ``upsert`` call. An empty ``items``
        only logs a warning. If a batch fails, the error is logged and
        re-raised; batches already sent are not rolled back.
        """
        if not items:
            log.warning("No items to insert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Slice the points lazily into BATCH_SIZE-sized requests.
        batches = (
            points[start : start + BATCH_SIZE]
            for start in range(0, len(points), BATCH_SIZE)
        )
        for batch in batches:
            try:
                self.index.upsert(vectors=batch)
                log.debug(
                    f"Inserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
                )
            except Exception as e:
                log.error(
                    f"Error inserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully inserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert or update ``items`` in a collection in batches of ``BATCH_SIZE``.

        Items are converted to point dicts tagged with the prefixed collection
        name and sent through the index's ``upsert`` call. An empty ``items``
        only logs a warning. If a batch fails, the error is logged and
        re-raised; batches already sent are not rolled back.
        """
        if not items:
            log.warning("No items to upsert")
            return

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        points = self._create_points(items, collection_name_with_prefix)

        # Slice the points lazily into BATCH_SIZE-sized requests.
        batches = (
            points[start : start + BATCH_SIZE]
            for start in range(0, len(points), BATCH_SIZE)
        )
        for batch in batches:
            try:
                self.index.upsert(vectors=batch)
                log.debug(
                    f"Upserted batch of {len(batch)} vectors into '{collection_name_with_prefix}'"
                )
            except Exception as e:
                log.error(
                    f"Error upserting batch into '{collection_name_with_prefix}': {e}"
                )
                raise

        log.info(
            f"Successfully upserted {len(items)} vectors into '{collection_name_with_prefix}'"
        )

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """Run a similarity search inside one collection.

        Only the first entry of ``vectors`` is used as the query vector. A
        ``limit`` of ``None`` or a non-positive number is replaced with
        ``NO_LIMIT``. Scores are passed through ``_normalize_distance`` before
        being returned as distances.

        Returns:
            A ``SearchResult`` (with empty inner lists when nothing matches),
            or ``None`` when no query vector was supplied or the query raised.
        """
        if not vectors or not vectors[0]:
            log.warning("No vectors provided for search")
            return None

        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        limit = NO_LIMIT if limit is None or limit <= 0 else limit

        try:
            # Search using the first vector (assuming this is the intended behavior)
            response = self.index.query(
                vector=vectors[0],
                top_k=limit,
                include_metadata=True,
                filter={"collection_name": collection_name_with_prefix},
            )

            # Nothing matched: hand back an explicitly empty result.
            if not response.matches:
                return SearchResult(
                    ids=[[]],
                    documents=[[]],
                    metadatas=[[]],
                    distances=[[]],
                )

            # IDs, documents and metadata come from the shared converter.
            converted = self._result_to_get_result(response.matches)

            # One normalized score per match, in match order.
            normalized_scores = [
                self._normalize_distance(hit.score) for hit in response.matches
            ]

            return SearchResult(
                ids=converted.ids,
                documents=converted.documents,
                metadatas=converted.metadatas,
                distances=[normalized_scores],
            )
        except Exception as e:
            log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
            return None

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch vectors from a collection by metadata filter.

        The caller's ``filter`` is merged on top of a ``collection_name``
        condition, and the query is issued with an all-zero vector since the
        index query call is given a vector in every case. A ``limit`` of
        ``None`` or a non-positive number is replaced with ``NO_LIMIT``.

        Returns:
            A ``GetResult`` built from the matches, or ``None`` if the query
            raised (the error is logged).
        """
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        limit = NO_LIMIT if limit is None or limit <= 0 else limit

        try:
            # Placeholder vector; the metadata filter does the selection.
            placeholder_vector = [0.0] * self.dimension

            # Start from the collection condition, then layer the user filter.
            combined_filter = {"collection_name": collection_name_with_prefix}
            if filter:
                combined_filter.update(filter)

            response = self.index.query(
                vector=placeholder_vector,
                filter=combined_filter,
                top_k=limit,
                include_metadata=True,
            )
            return self._result_to_get_result(response.matches)

        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {e}")
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Fetch the vectors of a collection, up to ``NO_LIMIT`` matches.

        Implemented as a query with an all-zero vector filtered on the
        prefixed collection name.

        Returns:
            A ``GetResult`` built from the matches, or ``None`` if the query
            raised (the error is logged).
        """
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )
        collection_filter = {"collection_name": collection_name_with_prefix}

        try:
            # Placeholder vector; the metadata filter does the selection.
            placeholder_vector = [0.0] * self.dimension
            response = self.index.query(
                vector=placeholder_vector,
                top_k=NO_LIMIT,
                include_metadata=True,
                filter=collection_filter,
            )
            return self._result_to_get_result(response.matches)

        except Exception as e:
            log.error(f"Error getting collection '{collection_name}': {e}")
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """Delete vectors either by ID or by metadata filter.

        When ``ids`` is non-empty it takes precedence: the IDs are deleted in
        batches of ``BATCH_SIZE`` and ``filter`` is ignored. Otherwise a
        non-empty ``filter`` is merged on top of a ``collection_name``
        condition and used for a filtered delete. With neither, only a warning
        is logged. Any exception is logged and re-raised.
        """
        collection_name_with_prefix = self._get_collection_name_with_prefix(
            collection_name
        )

        try:
            if ids:
                # Note: deletion by ID is not scoped to the collection, so IDs
                # must be unique across collections.
                id_batches = (
                    ids[start : start + BATCH_SIZE]
                    for start in range(0, len(ids), BATCH_SIZE)
                )
                for batch_ids in id_batches:
                    self.index.delete(ids=batch_ids)
                    log.debug(
                        f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'"
                    )
                log.info(
                    f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'"
                )

            elif filter:
                # Start from the collection condition, then layer the user filter.
                combined_filter = {"collection_name": collection_name_with_prefix}
                combined_filter.update(filter)
                self.index.delete(filter=combined_filter)
                log.info(
                    f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
                )

            else:
                log.warning("No ids or filter provided for delete operation")

        except Exception as e:
            log.error(f"Error deleting from collection '{collection_name}': {e}")
            raise

    def reset(self) -> None:
        """Wipe the index by issuing a delete-all request.

        This is not limited to one collection. Failures are logged and
        re-raised.
        """
        try:
            self.index.delete(delete_all=True)
        except Exception as e:
            log.error(f"Failed to reset Pinecone index: {e}")
            raise
        else:
            log.info("All vectors successfully deleted from the index.")