enh: kb metadata search
This commit is contained in:
@@ -69,7 +69,7 @@ class ChromaClient(VectorDBBase):
|
||||
return self.client.delete_collection(name=collection_name)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
try:
|
||||
@@ -78,6 +78,7 @@ class ChromaClient(VectorDBBase):
|
||||
result = collection.query(
|
||||
query_embeddings=vectors,
|
||||
n_results=limit,
|
||||
where=filter,
|
||||
)
|
||||
|
||||
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
|
||||
|
||||
@@ -153,7 +153,7 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Status: works
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float]], limit: int
|
||||
self, collection_name: str, vectors: list[list[float]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
query = {
|
||||
"size": limit,
|
||||
|
||||
@@ -179,7 +179,7 @@ class MilvusClient(VectorDBBase):
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
|
||||
@@ -157,7 +157,7 @@ class MilvusClient(VectorDBBase):
|
||||
collection.insert(entities)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[float]], limit: int
|
||||
self, collection_name: str, vectors: List[List[float]], filter: Optional[Dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
if not vectors:
|
||||
return None
|
||||
|
||||
@@ -233,7 +233,8 @@ class OpenGaussClient(VectorDBBase):
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: List[List[float]],
|
||||
limit: Optional[int] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
limit: int = 10,
|
||||
) -> Optional[SearchResult]:
|
||||
try:
|
||||
if not vectors:
|
||||
|
||||
@@ -113,7 +113,7 @@ class OpenSearchClient(VectorDBBase):
|
||||
self.client.indices.delete(index=self._get_index_name(collection_name))
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
|
||||
@@ -521,7 +521,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
raise
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
"""
|
||||
Search for similar vectors in the database.
|
||||
|
||||
@@ -427,7 +427,8 @@ class PgvectorClient(VectorDBBase):
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: List[List[float]],
|
||||
limit: Optional[int] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
limit: int = 10,
|
||||
) -> Optional[SearchResult]:
|
||||
try:
|
||||
if not vectors:
|
||||
@@ -475,9 +476,40 @@ class PgvectorClient(VectorDBBase):
|
||||
)
|
||||
|
||||
# Build the lateral subquery for each query vector
|
||||
where_clauses = [DocumentChunk.collection_name == collection_name]
|
||||
|
||||
# Apply metadata filter if provided
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
if isinstance(value, dict) and "$in" in value:
|
||||
# Handle $in operator: {"field": {"$in": [values]}}
|
||||
in_values = value["$in"]
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
where_clauses.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext.in_([str(v) for v in in_values])
|
||||
)
|
||||
else:
|
||||
where_clauses.append(
|
||||
DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values])
|
||||
)
|
||||
else:
|
||||
# Handle simple equality: {"field": "value"}
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
where_clauses.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext == str(value)
|
||||
)
|
||||
else:
|
||||
where_clauses.append(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
|
||||
subq = (
|
||||
select(*result_fields)
|
||||
.where(DocumentChunk.collection_name == collection_name)
|
||||
.where(*where_clauses)
|
||||
.order_by(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
)
|
||||
|
||||
@@ -391,7 +391,7 @@ class PineconeClient(VectorDBBase):
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
"""Search for similar vectors in a collection."""
|
||||
if not vectors or not vectors[0]:
|
||||
|
||||
@@ -145,7 +145,7 @@ class QdrantClient(VectorDBBase):
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
if limit is None:
|
||||
|
||||
@@ -254,7 +254,7 @@ class QdrantClient(VectorDBBase):
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[float | int]], limit: int
|
||||
self, collection_name: str, vectors: List[List[float | int]], filter: Optional[Dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
"""
|
||||
Search for the nearest neighbor items based on the vectors with tenant isolation.
|
||||
|
||||
@@ -295,7 +295,7 @@ class S3VectorClient(VectorDBBase):
|
||||
raise
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
"""
|
||||
Search for similar vectors in a collection using multiple query vectors.
|
||||
|
||||
@@ -159,7 +159,7 @@ class WeaviateClient(VectorDBBase):
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10
|
||||
) -> Optional[SearchResult]:
|
||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
||||
if not self.client.collections.exists(sane_collection_name):
|
||||
|
||||
@@ -53,7 +53,11 @@ class VectorDBBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
|
||||
self,
|
||||
collection_name: str,
|
||||
vectors: List[List[Union[float, int]]],
|
||||
filter: Optional[Dict] = None,
|
||||
limit: int = 10,
|
||||
) -> Optional[SearchResult]:
|
||||
"""Search for similar vectors in a collection."""
|
||||
pass
|
||||
|
||||
@@ -46,6 +46,54 @@ router = APIRouter()
|
||||
|
||||
PAGE_ITEM_COUNT = 30
|
||||
|
||||
############################
|
||||
# Knowledge Base Embedding
|
||||
############################
|
||||
|
||||
KNOWLEDGE_BASES_COLLECTION = "knowledge-bases"
|
||||
|
||||
|
||||
async def embed_knowledge_base_metadata(
    request: Request,
    knowledge_base_id: str,
    name: str,
    description: str,
) -> bool:
    """Embed a knowledge base's name/description and upsert it into the vector store.

    The embedded text is ``name`` on its own, or ``name`` followed by a blank
    line and ``description`` when a (truthy) description is supplied. The
    resulting vector is stored in ``KNOWLEDGE_BASES_COLLECTION`` under the
    knowledge base id, with that id repeated in the item metadata.

    :param request: Incoming request; its ``app.state.EMBEDDING_FUNCTION`` is
        awaited to produce the vector.
    :param knowledge_base_id: Id used both as the vector item id and as the
        ``knowledge_base_id`` metadata value.
    :param name: Knowledge base name.
    :param description: Knowledge base description (may be empty).
    :return: ``True`` when embedding and upsert completed, ``False`` when any
        exception was raised (the failure is logged, not propagated).
    """
    try:
        if description:
            text = f"{name}\n\n{description}"
        else:
            text = name

        vector = await request.app.state.EMBEDDING_FUNCTION(text)

        item = {
            "id": knowledge_base_id,
            "text": text,
            "vector": vector,
            "metadata": {
                "knowledge_base_id": knowledge_base_id,
            },
        }
        VECTOR_DB_CLIENT.upsert(
            collection_name=KNOWLEDGE_BASES_COLLECTION,
            items=[item],
        )
    except Exception as e:
        # Best-effort: callers only get a success flag back.
        log.error(f"Failed to embed knowledge base {knowledge_base_id}: {e}")
        return False

    return True
|
||||
|
||||
|
||||
def remove_knowledge_base_metadata_embedding(knowledge_base_id: str) -> bool:
    """Delete the stored embedding for a knowledge base.

    Issues a delete against ``KNOWLEDGE_BASES_COLLECTION`` for the single id
    given.

    :param knowledge_base_id: Id of the vector item to delete.
    :return: ``True`` when the delete call returned normally, ``False`` when
        it raised (the failure is logged at debug level and swallowed).
    """
    try:
        VECTOR_DB_CLIENT.delete(
            collection_name=KNOWLEDGE_BASES_COLLECTION,
            ids=[knowledge_base_id],
        )
    except Exception as e:
        # Best-effort cleanup: a failed delete must not break the caller.
        log.debug(f"Failed to remove embedding for {knowledge_base_id}: {e}")
        return False
    else:
        return True
|
||||
|
||||
|
||||
class KnowledgeAccessResponse(KnowledgeUserResponse):
|
||||
write_access: Optional[bool] = False
|
||||
@@ -205,6 +253,13 @@ async def create_new_knowledge(
|
||||
knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db)
|
||||
|
||||
if knowledge:
|
||||
# Embed knowledge base for semantic search
|
||||
await embed_knowledge_base_metadata(
|
||||
request,
|
||||
knowledge.id,
|
||||
knowledge.name,
|
||||
knowledge.description,
|
||||
)
|
||||
return knowledge
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -281,6 +336,30 @@ async def reindex_knowledge_files(
|
||||
return True
|
||||
|
||||
|
||||
############################
|
||||
# ReindexKnowledgeBases
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/metadata/reindex", response_model=dict)
async def reindex_knowledge_base_metadata_embeddings(
    request: Request,
    user=Depends(get_admin_user),
    db: Session = Depends(get_session),
):
    """Batch embed all existing knowledge bases. Admin only."""
    # Returns {"total": <number of knowledge bases>, "success": <number whose
    # embedding call reported success>}.
    knowledge_bases = Knowledges.get_knowledge_bases(db=db)
    log.info(f"Reindexing embeddings for {len(knowledge_bases)} knowledge bases")

    # Each knowledge base is awaited in turn (no concurrency), in the order
    # returned by the listing call; every entry is a truthy/falsy success flag.
    outcomes = [
        await embed_knowledge_base_metadata(request, kb.id, kb.name, kb.description)
        for kb in knowledge_bases
    ]
    success_count = sum(1 for succeeded in outcomes if succeeded)

    log.info(f"Embedding reindex complete: {success_count}/{len(knowledge_bases)}")
    return {"total": len(knowledge_bases), "success": success_count}
|
||||
|
||||
|
||||
############################
|
||||
# GetKnowledgeById
|
||||
############################
|
||||
@@ -369,6 +448,13 @@ async def update_knowledge_by_id(
|
||||
|
||||
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db)
|
||||
if knowledge:
|
||||
# Re-embed knowledge base for semantic search
|
||||
await embed_knowledge_base_metadata(
|
||||
request,
|
||||
knowledge.id,
|
||||
knowledge.name,
|
||||
knowledge.description,
|
||||
)
|
||||
return KnowledgeFilesResponse(
|
||||
**knowledge.model_dump(),
|
||||
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
|
||||
@@ -718,6 +804,10 @@ async def delete_knowledge_by_id(
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
pass
|
||||
|
||||
# Remove knowledge base embedding
|
||||
remove_knowledge_base_metadata_embedding(id)
|
||||
|
||||
result = Knowledges.delete_knowledge_by_id(id=id, db=db)
|
||||
return result
|
||||
|
||||
|
||||
@@ -39,6 +39,8 @@ from open_webui.models.groups import Groups
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
MAX_KNOWLEDGE_BASE_SEARCH_ITEMS = 10_000
|
||||
|
||||
# =============================================================================
|
||||
# TIME UTILITIES
|
||||
# =============================================================================
|
||||
@@ -1413,7 +1415,7 @@ async def view_knowledge_file(
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def query_knowledge_bases(
|
||||
async def query_knowledge_files(
|
||||
query: str,
|
||||
knowledge_ids: Optional[list[str]] = None,
|
||||
count: int = 5,
|
||||
@@ -1422,7 +1424,7 @@ async def query_knowledge_bases(
|
||||
__model_knowledge__: list[dict] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search internal knowledge bases using semantic/vector search. This should be your first
|
||||
Search knowledge base files using semantic/vector search. This should be your first
|
||||
choice for finding information before searching the web. Searches across collections (KBs),
|
||||
individual files, and notes that the user has access to.
|
||||
|
||||
@@ -1558,6 +1560,105 @@ async def query_knowledge_bases(
|
||||
chunks = chunks[:count]
|
||||
|
||||
return json.dumps(chunks, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
log.exception(f"query_knowledge_files error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def query_knowledge_bases(
    query: str,
    count: int = 5,
    __request__: Request = None,
    __user__: dict = None,
) -> str:
    """
    Search knowledge bases by semantic similarity to query.
    Finds KBs whose name/description match the meaning of your query.
    Use this to discover relevant knowledge bases before querying their files.

    :param query: Natural language query describing what you're looking for
    :param count: Maximum results (default: 5)
    :return: JSON with matching KBs (id, name, description, similarity)
    """
    # NOTE(review): the docstring above is left untouched because it is
    # presumably surfaced as the tool description - confirm before rewording.
    #
    # Implementation overview:
    #   1. Page through the knowledge bases the user can access (page_size per
    #      page, capped at MAX_KNOWLEDGE_BASE_SEARCH_ITEMS).
    #   2. For every page, run a vector search restricted to that page's ids.
    #   3. Keep only the `count` best-scoring hits across all pages in a heap.
    #   4. Resolve the surviving ids to knowledge base records and return JSON.
    # Any exception is logged and reported as {"error": "..."} JSON.
    if __request__ is None:
        return json.dumps({"error": "Request context not available"})

    if not __user__:
        return json.dumps({"error": "User context not available"})

    try:
        import heapq
        from open_webui.models.knowledge import Knowledges
        from open_webui.routers.knowledge import KNOWLEDGE_BASES_COLLECTION
        from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT

        # A non-positive count can never produce results. Without this guard
        # it would be forwarded as the vector-store `limit`, and the heap
        # bookkeeping below would index an empty heap. Kept inside the `try`
        # so a non-numeric count is still reported as an error JSON.
        if count < 1:
            return json.dumps([], ensure_ascii=False)

        user_id = __user__.get("id")
        user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)]
        query_embedding = await __request__.app.state.EMBEDDING_FUNCTION(query)

        # Min-heap of (score, knowledge_base_id) holding at most `count` items.
        # The values the vector client returns as `distances` are treated here
        # as scores where larger is better: the heap root is the weakest kept
        # hit and is evicted whenever a larger score arrives.
        top_results_heap = []
        seen_ids = set()
        page_offset = 0
        page_size = 100

        while True:
            accessible_knowledge_bases = Knowledges.search_knowledge_bases(
                user_id,
                filter={"user_id": user_id, "group_ids": user_group_ids},
                skip=page_offset,
                limit=page_size,
            )

            if not accessible_knowledge_bases.items:
                break

            accessible_ids = [kb.id for kb in accessible_knowledge_bases.items]

            # Restrict the vector search to ids the user may see, so
            # inaccessible knowledge bases never enter the candidate set.
            search_results = VECTOR_DB_CLIENT.search(
                collection_name=KNOWLEDGE_BASES_COLLECTION,
                vectors=[query_embedding],
                filter={"knowledge_base_id": {"$in": accessible_ids}},
                limit=count,
            )

            if search_results and search_results.ids and search_results.ids[0]:
                result_ids = search_results.ids[0]
                # Fall back to a score of 0 per hit when none are returned.
                result_distances = (
                    search_results.distances[0]
                    if search_results.distances
                    else [0] * len(result_ids)
                )

                for knowledge_base_id, distance in zip(result_ids, result_distances):
                    if knowledge_base_id in seen_ids:
                        continue
                    seen_ids.add(knowledge_base_id)

                    if len(top_results_heap) < count:
                        heapq.heappush(top_results_heap, (distance, knowledge_base_id))
                    elif distance > top_results_heap[0][0]:
                        heapq.heapreplace(top_results_heap, (distance, knowledge_base_id))

            page_offset += page_size
            if len(accessible_knowledge_bases.items) < page_size:
                # Short page: nothing further to fetch.
                break
            if page_offset >= MAX_KNOWLEDGE_BASE_SEARCH_ITEMS:
                # Hard cap on how many accessible knowledge bases are scanned.
                break

        # Sort by score descending (best first) and fetch KB details
        sorted_results = sorted(top_results_heap, key=lambda x: x[0], reverse=True)

        matching_knowledge_bases = []
        for distance, knowledge_base_id in sorted_results:
            knowledge_base = Knowledges.get_knowledge_by_id(knowledge_base_id)
            # An id with no matching record is silently skipped.
            if knowledge_base:
                matching_knowledge_bases.append(
                    {
                        "id": knowledge_base.id,
                        "name": knowledge_base.name,
                        "description": knowledge_base.description or "",
                        "similarity": round(distance, 4),
                    }
                )

        return json.dumps(matching_knowledge_bases, ensure_ascii=False)

    except Exception as e:
        log.exception(f"query_knowledge_bases error: {e}")
        return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ def get_citation_source_from_tool_result(
|
||||
- document: list of document contents
|
||||
- metadata: list of metadata objects with source, file_id, name fields
|
||||
|
||||
Returns a list of sources (usually one, but query_knowledge_bases may return multiple).
|
||||
Returns a list of sources (usually one, but query_knowledge_files may return multiple).
|
||||
"""
|
||||
try:
|
||||
if tool_name == "search_web":
|
||||
@@ -217,7 +217,7 @@ def get_citation_source_from_tool_result(
|
||||
}
|
||||
]
|
||||
|
||||
elif tool_name == "query_knowledge_bases":
|
||||
elif tool_name == "query_knowledge_files":
|
||||
chunks = json.loads(tool_result)
|
||||
|
||||
# Group chunks by source for better citation display
|
||||
@@ -3343,7 +3343,7 @@ async def process_chat_response(
|
||||
in [
|
||||
"search_web",
|
||||
"view_knowledge_file",
|
||||
"query_knowledge_bases",
|
||||
"query_knowledge_files",
|
||||
]
|
||||
and tool_result
|
||||
):
|
||||
|
||||
@@ -68,9 +68,10 @@ from open_webui.tools.builtin import (
|
||||
write_note,
|
||||
list_knowledge_bases,
|
||||
search_knowledge_bases,
|
||||
search_knowledge_files,
|
||||
view_knowledge_file,
|
||||
query_knowledge_bases,
|
||||
search_knowledge_files,
|
||||
query_knowledge_files,
|
||||
view_knowledge_file,
|
||||
)
|
||||
|
||||
import copy
|
||||
@@ -406,21 +407,22 @@ def get_builtin_tools(
|
||||
builtin_functions.extend([get_current_timestamp, calculate_timestamp])
|
||||
|
||||
# Knowledge base tools - conditional injection based on model knowledge
|
||||
# If model has attached knowledge (any type), only provide query_knowledge_bases
|
||||
# If model has attached knowledge (any type), only provide query_knowledge_files
|
||||
# Otherwise, provide all KB browsing tools
|
||||
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", [])
|
||||
if model_knowledge:
|
||||
# Model has attached knowledge - only allow semantic search within it
|
||||
builtin_functions.append(query_knowledge_bases)
|
||||
builtin_functions.append(query_knowledge_files)
|
||||
else:
|
||||
# No model knowledge - allow full KB browsing
|
||||
builtin_functions.extend(
|
||||
[
|
||||
list_knowledge_bases,
|
||||
search_knowledge_bases,
|
||||
search_knowledge_files,
|
||||
view_knowledge_file,
|
||||
query_knowledge_bases,
|
||||
search_knowledge_files,
|
||||
query_knowledge_files,
|
||||
view_knowledge_file,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user