diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index c921089b6..9f8abf460 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -1,13 +1,12 @@ from typing import Optional, List, Dict, Any, Union import logging import time # for measuring elapsed time -from pinecone import ServerlessSpec +from pinecone import Pinecone, ServerlessSpec import asyncio # for async upserts import functools # for partial binding in async tasks import concurrent.futures # for parallel batch upserts -from pinecone.grpc import PineconeGRPC # use gRPC client for faster upserts from open_webui.retrieval.vector.main import ( VectorDBBase, @@ -47,10 +46,8 @@ class PineconeClient(VectorDBBase): self.metric = PINECONE_METRIC self.cloud = PINECONE_CLOUD - # Initialize Pinecone gRPC client for improved performance - self.client = PineconeGRPC( - api_key=self.api_key, environment=self.environment, cloud=self.cloud - ) + # Initialize Pinecone client for improved performance + self.client = Pinecone(api_key=self.api_key) # Persistent executor for batch operations self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) @@ -147,8 +144,8 @@ class PineconeClient(VectorDBBase): metadatas = [] for match in matches: - metadata = match.get("metadata", {}) - ids.append(match["id"]) + metadata = getattr(match, "metadata", {}) or {} + ids.append(match.id if hasattr(match, "id") else match["id"]) documents.append(metadata.get("text", "")) metadatas.append(metadata) @@ -174,7 +171,8 @@ class PineconeClient(VectorDBBase): filter={"collection_name": collection_name_with_prefix}, include_metadata=False, ) - return len(response.matches) > 0 + matches = getattr(response, "matches", []) or [] + return len(matches) > 0 except Exception as e: log.exception( f"Error checking collection '{collection_name_with_prefix}': {e}" @@ -321,32 +319,6 @@ class PineconeClient(VectorDBBase): f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'" ) - def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None: - """Perform a streaming upsert over gRPC for performance testing.""" - if not items: - log.warning("No items to upsert via streaming") - return - - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) - points = self._create_points(items, collection_name_with_prefix) - - # Open a streaming upsert channel - stream = self.index.streaming_upsert() - try: - for point in points: - # send each point over the stream - stream.send(point) - # close the stream to finalize - stream.close() - log.info( - f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'" - ) - except Exception as e: - log.error(f"Error during streaming upsert: {e}") - raise - def search( self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int ) -> Optional[SearchResult]: @@ -374,7 +346,8 @@ class PineconeClient(VectorDBBase): filter={"collection_name": collection_name_with_prefix}, ) - if not query_response.matches: + matches = getattr(query_response, "matches", []) or [] + if not matches: # Return empty result if no matches return SearchResult( ids=[[]], @@ -384,13 +357,13 @@ class PineconeClient(VectorDBBase): ) # Convert to GetResult format - get_result = self._result_to_get_result(query_response.matches) + get_result = self._result_to_get_result(matches) # Calculate normalized distances based on metric distances = [ [ - self._normalize_distance(match.score) - for match in query_response.matches + self._normalize_distance(getattr(match, "score", 0.0)) + for match in matches ] ] @@ -432,7 +405,8 @@ class PineconeClient(VectorDBBase): include_metadata=True, ) - return self._result_to_get_result(query_response.matches) + matches = getattr(query_response, "matches", []) or [] + return self._result_to_get_result(matches) except Exception as e: log.error(f"Error querying collection '{collection_name}': {e}") @@ -456,7 +430,8 @@ class PineconeClient(VectorDBBase): filter={"collection_name": collection_name_with_prefix}, ) - return self._result_to_get_result(query_response.matches) + matches = getattr(query_response, "matches", []) or [] + return self._result_to_get_result(matches) except Exception as e: log.error(f"Error getting collection '{collection_name}': {e}") @@ -516,12 +491,12 @@ class PineconeClient(VectorDBBase): raise def close(self): - """Shut down the gRPC channel and thread pool.""" + """Shut down resources.""" try: - self.client.close() - log.info("Pinecone gRPC channel closed.") + # The new Pinecone client doesn't need explicit closing + pass except Exception as e: - log.warning(f"Failed to close Pinecone gRPC channel: {e}") + log.warning(f"Failed to clean up Pinecone resources: {e}") self._executor.shutdown(wait=True) def __enter__(self):