mirror of
https://github.com/open-webui/open-webui
synced 2025-06-16 19:31:52 +00:00
Update pinecone.py
This commit is contained in:
parent
04b9065f08
commit
b38711a581
@ -3,47 +3,6 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
from pinecone import Pinecone, ServerlessSpec
|
from pinecone import Pinecone, ServerlessSpec
|
||||||
|
|
||||||
# Helper for building consistent metadata
|
|
||||||
def build_metadata(
|
|
||||||
*,
|
|
||||||
source: str,
|
|
||||||
type_: str,
|
|
||||||
user_id: str,
|
|
||||||
chat_id: Optional[str] = None,
|
|
||||||
filename: Optional[str] = None,
|
|
||||||
text: Optional[str] = None,
|
|
||||||
topic: Optional[str] = None,
|
|
||||||
model: Optional[str] = None,
|
|
||||||
vector_dim: Optional[int] = None,
|
|
||||||
extra: Optional[Dict[str, Any]] = None,
|
|
||||||
collection_name: Optional[str] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"source": source,
|
|
||||||
"type": type_,
|
|
||||||
"user_id": user_id,
|
|
||||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
|
||||||
}
|
|
||||||
if chat_id:
|
|
||||||
metadata["chat_id"] = chat_id
|
|
||||||
if filename:
|
|
||||||
metadata["filename"] = filename
|
|
||||||
if text:
|
|
||||||
metadata["text"] = text
|
|
||||||
if topic:
|
|
||||||
metadata["topic"] = topic
|
|
||||||
if model:
|
|
||||||
metadata["model"] = model
|
|
||||||
if vector_dim:
|
|
||||||
metadata["vector_dim"] = vector_dim
|
|
||||||
if collection_name:
|
|
||||||
metadata["collection_name"] = collection_name
|
|
||||||
if extra:
|
|
||||||
metadata.update(extra)
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
from open_webui.retrieval.vector.main import (
|
from open_webui.retrieval.vector.main import (
|
||||||
VectorDBBase,
|
VectorDBBase,
|
||||||
VectorItem,
|
VectorItem,
|
||||||
@ -61,7 +20,7 @@ from open_webui.config import (
|
|||||||
from open_webui.env import SRC_LOG_LEVELS
|
from open_webui.env import SRC_LOG_LEVELS
|
||||||
|
|
||||||
NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
|
NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system
|
||||||
BATCH_SIZE = 100 # Recommended batch size for Pinecone operations
|
BATCH_SIZE = 200 # Recommended batch size for Pinecone operations
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
@ -69,8 +28,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|||||||
|
|
||||||
class PineconeClient(VectorDBBase):
|
class PineconeClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
from open_webui.config import PINECONE_NAMESPACE
|
self.collection_prefix = "open-webui"
|
||||||
self.namespace = PINECONE_NAMESPACE
|
|
||||||
|
|
||||||
# Validate required configuration
|
# Validate required configuration
|
||||||
self._validate_config()
|
self._validate_config()
|
||||||
@ -137,32 +95,15 @@ class PineconeClient(VectorDBBase):
|
|||||||
"""Convert VectorItem objects to Pinecone point format."""
|
"""Convert VectorItem objects to Pinecone point format."""
|
||||||
points = []
|
points = []
|
||||||
for item in items:
|
for item in items:
|
||||||
user_id = item.get("metadata", {}).get("created_by", "unknown")
|
# Start with any existing metadata or an empty dict
|
||||||
chat_id = item.get("metadata", {}).get("chat_id")
|
metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}
|
||||||
filename = item.get("metadata", {}).get("name")
|
|
||||||
text = item.get("text")
|
|
||||||
model = item.get("metadata", {}).get("model")
|
|
||||||
topic = item.get("metadata", {}).get("topic")
|
|
||||||
|
|
||||||
# Infer source from filename or fallback
|
# Add text to metadata if available
|
||||||
raw_source = item.get("metadata", {}).get("source", "")
|
if "text" in item:
|
||||||
inferred_source = "knowledge"
|
metadata["text"] = item["text"]
|
||||||
if raw_source == filename or (isinstance(raw_source, str) and raw_source.endswith((".pdf", ".txt", ".docx"))):
|
|
||||||
inferred_source = "chat" if item.get("metadata", {}).get("created_by") else "knowledge"
|
|
||||||
else:
|
|
||||||
inferred_source = raw_source or "knowledge"
|
|
||||||
|
|
||||||
metadata = build_metadata(
|
# Always add collection_name to metadata for filtering
|
||||||
source=inferred_source,
|
metadata["collection_name"] = collection_name_with_prefix
|
||||||
type_="upload",
|
|
||||||
user_id=user_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
filename=filename,
|
|
||||||
text=text,
|
|
||||||
model=model,
|
|
||||||
topic=topic,
|
|
||||||
collection_name=collection_name_with_prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
point = {
|
point = {
|
||||||
"id": item["id"],
|
"id": item["id"],
|
||||||
@ -172,9 +113,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
points.append(point)
|
points.append(point)
|
||||||
return points
|
return points
|
||||||
|
|
||||||
def _get_namespace(self) -> str:
|
def _get_collection_name_with_prefix(self, collection_name: str) -> str:
|
||||||
"""Get the namespace from the environment variable."""
|
"""Get the collection name with prefix."""
|
||||||
return self.namespace
|
return f"{self.collection_prefix}_{collection_name}"
|
||||||
|
|
||||||
def _normalize_distance(self, score: float) -> float:
|
def _normalize_distance(self, score: float) -> float:
|
||||||
"""Normalize distance score based on the metric used."""
|
"""Normalize distance score based on the metric used."""
|
||||||
@ -210,7 +151,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
|
|
||||||
def has_collection(self, collection_name: str) -> bool:
|
def has_collection(self, collection_name: str) -> bool:
|
||||||
"""Check if a collection exists by searching for at least one item."""
|
"""Check if a collection exists by searching for at least one item."""
|
||||||
collection_name_with_prefix = self._get_namespace()
|
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||||
|
collection_name
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Search for at least 1 item with this collection name in metadata
|
# Search for at least 1 item with this collection name in metadata
|
||||||
@ -229,7 +172,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
|
|
||||||
def delete_collection(self, collection_name: str) -> None:
|
def delete_collection(self, collection_name: str) -> None:
|
||||||
"""Delete a collection by removing all vectors with the collection name in metadata."""
|
"""Delete a collection by removing all vectors with the collection name in metadata."""
|
||||||
collection_name_with_prefix = self._get_namespace()
|
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||||
|
collection_name
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
self.index.delete(filter={"collection_name": collection_name_with_prefix})
|
self.index.delete(filter={"collection_name": collection_name_with_prefix})
|
||||||
log.info(
|
log.info(
|
||||||
@ -248,7 +193,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
log.warning("No items to insert")
|
log.warning("No items to insert")
|
||||||
return
|
return
|
||||||
|
|
||||||
collection_name_with_prefix = self._get_namespace()
|
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||||
|
collection_name
|
||||||
|
)
|
||||||
points = self._create_points(items, collection_name_with_prefix)
|
points = self._create_points(items, collection_name_with_prefix)
|
||||||
|
|
||||||
# Insert in batches for better performance and reliability
|
# Insert in batches for better performance and reliability
|
||||||
@ -276,7 +223,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
log.warning("No items to upsert")
|
log.warning("No items to upsert")
|
||||||
return
|
return
|
||||||
|
|
||||||
collection_name_with_prefix = self._get_namespace()
|
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||||
|
collection_name
|
||||||
|
)
|
||||||
points = self._create_points(items, collection_name_with_prefix)
|
points = self._create_points(items, collection_name_with_prefix)
|
||||||
|
|
||||||
# Upsert in batches
|
# Upsert in batches
|
||||||
@ -305,7 +254,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
log.warning("No vectors provided for search")
|
log.warning("No vectors provided for search")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
collection_name_with_prefix = self._get_namespace()
|
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||||
|
collection_name
|
||||||
|
)
|
||||||
|
|
||||||
if limit is None or limit <= 0:
|
if limit is None or limit <= 0:
|
||||||
limit = NO_LIMIT
|
limit = NO_LIMIT
|
||||||
@ -356,7 +307,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||||
) -> Optional[GetResult]:
|
) -> Optional[GetResult]:
|
||||||
"""Query vectors by metadata filter."""
|
"""Query vectors by metadata filter."""
|
||||||
collection_name_with_prefix = self._get_namespace()
|
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||||
|
collection_name
|
||||||
|
)
|
||||||
|
|
||||||
if limit is None or limit <= 0:
|
if limit is None or limit <= 0:
|
||||||
limit = NO_LIMIT
|
limit = NO_LIMIT
|
||||||
@ -386,7 +339,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
|
|
||||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||||
"""Get all vectors in a collection."""
|
"""Get all vectors in a collection."""
|
||||||
collection_name_with_prefix = self._get_namespace()
|
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||||
|
collection_name
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use a zero vector for fetching all entries
|
# Use a zero vector for fetching all entries
|
||||||
@ -414,7 +369,9 @@ class PineconeClient(VectorDBBase):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Delete vectors by IDs or filter."""
|
"""Delete vectors by IDs or filter."""
|
||||||
import time
|
import time
|
||||||
collection_name_with_prefix = self._get_namespace()
|
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||||
|
collection_name
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if ids:
|
if ids:
|
||||||
|
Loading…
Reference in New Issue
Block a user