Merge pull request #14147 from PVBLIC-F/dev

perf Update pinecone.py
This commit is contained in:
Tim Jaeryang Baek 2025-05-22 12:19:14 +04:00 committed by GitHub
commit 0eda03bd3c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,13 +1,12 @@
from typing import Optional, List, Dict, Any, Union
import logging
import time # for measuring elapsed time
from pinecone import ServerlessSpec
from pinecone import Pinecone, ServerlessSpec
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
from pinecone.grpc import PineconeGRPC # use gRPC client for faster upserts
from open_webui.retrieval.vector.main import (
VectorDBBase,
@ -47,10 +46,8 @@ class PineconeClient(VectorDBBase):
self.metric = PINECONE_METRIC
self.cloud = PINECONE_CLOUD
# Initialize Pinecone gRPC client for improved performance
self.client = PineconeGRPC(
api_key=self.api_key, environment=self.environment, cloud=self.cloud
)
# Initialize Pinecone client for improved performance
self.client = Pinecone(api_key=self.api_key)
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@ -147,8 +144,8 @@ class PineconeClient(VectorDBBase):
metadatas = []
for match in matches:
metadata = match.get("metadata", {})
ids.append(match["id"])
metadata = getattr(match, "metadata", {}) or {}
ids.append(match.id if hasattr(match, "id") else match["id"])
documents.append(metadata.get("text", ""))
metadatas.append(metadata)
@ -174,7 +171,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
include_metadata=False,
)
return len(response.matches) > 0
matches = getattr(response, "matches", []) or []
return len(matches) > 0
except Exception as e:
log.exception(
f"Error checking collection '{collection_name_with_prefix}': {e}"
@ -321,32 +319,6 @@ class PineconeClient(VectorDBBase):
f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'"
)
def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Perform a streaming upsert over gRPC for performance testing."""
if not items:
log.warning("No items to upsert via streaming")
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
points = self._create_points(items, collection_name_with_prefix)
# Open a streaming upsert channel
stream = self.index.streaming_upsert()
try:
for point in points:
# send each point over the stream
stream.send(point)
# close the stream to finalize
stream.close()
log.info(
f"Successfully streamed upsert of {len(points)} vectors into '{collection_name_with_prefix}'"
)
except Exception as e:
log.error(f"Error during streaming upsert: {e}")
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
) -> Optional[SearchResult]:
@ -374,7 +346,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
)
if not query_response.matches:
matches = getattr(query_response, "matches", []) or []
if not matches:
# Return empty result if no matches
return SearchResult(
ids=[[]],
@ -384,13 +357,13 @@ class PineconeClient(VectorDBBase):
)
# Convert to GetResult format
get_result = self._result_to_get_result(query_response.matches)
get_result = self._result_to_get_result(matches)
# Calculate normalized distances based on metric
distances = [
[
self._normalize_distance(match.score)
for match in query_response.matches
self._normalize_distance(getattr(match, "score", 0.0))
for match in matches
]
]
@ -432,7 +405,8 @@ class PineconeClient(VectorDBBase):
include_metadata=True,
)
return self._result_to_get_result(query_response.matches)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {e}")
@ -456,7 +430,8 @@ class PineconeClient(VectorDBBase):
filter={"collection_name": collection_name_with_prefix},
)
return self._result_to_get_result(query_response.matches)
matches = getattr(query_response, "matches", []) or []
return self._result_to_get_result(matches)
except Exception as e:
log.error(f"Error getting collection '{collection_name}': {e}")
@ -516,12 +491,12 @@ class PineconeClient(VectorDBBase):
raise
def close(self):
"""Shut down the gRPC channel and thread pool."""
"""Shut down resources."""
try:
self.client.close()
log.info("Pinecone gRPC channel closed.")
# The new Pinecone client doesn't need explicit closing
pass
except Exception as e:
log.warning(f"Failed to close Pinecone gRPC channel: {e}")
log.warning(f"Failed to clean up Pinecone resources: {e}")
self._executor.shutdown(wait=True)
def __enter__(self):