enh: kb metadata search

This commit is contained in:
Timothy Jaeryang Baek
2026-01-09 22:21:00 +04:00
parent eff772562b
commit 3c986adeda
18 changed files with 257 additions and 26 deletions

View File

@@ -69,7 +69,7 @@ class ChromaClient(VectorDBBase):
return self.client.delete_collection(name=collection_name)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
try:
@@ -78,6 +78,7 @@ class ChromaClient(VectorDBBase):
result = collection.query(
query_embeddings=vectors,
n_results=limit,
where=filter,
)
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1

View File

@@ -153,7 +153,7 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def search(
self, collection_name: str, vectors: list[list[float]], limit: int
self, collection_name: str, vectors: list[list[float]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
query = {
"size": limit,

View File

@@ -179,7 +179,7 @@ class MilvusClient(VectorDBBase):
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_")

View File

@@ -157,7 +157,7 @@ class MilvusClient(VectorDBBase):
collection.insert(entities)
def search(
self, collection_name: str, vectors: List[List[float]], limit: int
self, collection_name: str, vectors: List[List[float]], filter: Optional[Dict] = None, limit: int = 10
) -> Optional[SearchResult]:
if not vectors:
return None

View File

@@ -233,7 +233,8 @@ class OpenGaussClient(VectorDBBase):
self,
collection_name: str,
vectors: List[List[float]],
limit: Optional[int] = None,
filter: Optional[Dict[str, Any]] = None,
limit: int = 10,
) -> Optional[SearchResult]:
try:
if not vectors:

View File

@@ -113,7 +113,7 @@ class OpenSearchClient(VectorDBBase):
self.client.indices.delete(index=self._get_index_name(collection_name))
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
try:
if not self.has_collection(collection_name):

View File

@@ -521,7 +521,7 @@ class Oracle23aiClient(VectorDBBase):
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
"""
Search for similar vectors in the database.

View File

@@ -427,7 +427,8 @@ class PgvectorClient(VectorDBBase):
self,
collection_name: str,
vectors: List[List[float]],
limit: Optional[int] = None,
filter: Optional[Dict[str, Any]] = None,
limit: int = 10,
) -> Optional[SearchResult]:
try:
if not vectors:
@@ -475,9 +476,40 @@ class PgvectorClient(VectorDBBase):
)
# Build the lateral subquery for each query vector
where_clauses = [DocumentChunk.collection_name == collection_name]
# Apply metadata filter if provided
if filter:
for key, value in filter.items():
if isinstance(value, dict) and "$in" in value:
# Handle $in operator: {"field": {"$in": [values]}}
in_values = value["$in"]
if PGVECTOR_PGCRYPTO:
where_clauses.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext.in_([str(v) for v in in_values])
)
else:
where_clauses.append(
DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values])
)
else:
# Handle simple equality: {"field": "value"}
if PGVECTOR_PGCRYPTO:
where_clauses.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext == str(value)
)
else:
where_clauses.append(
DocumentChunk.vmetadata[key].astext == str(value)
)
subq = (
select(*result_fields)
.where(DocumentChunk.collection_name == collection_name)
.where(*where_clauses)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
)

View File

@@ -391,7 +391,7 @@ class PineconeClient(VectorDBBase):
)
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
if not vectors or not vectors[0]:

View File

@@ -145,7 +145,7 @@ class QdrantClient(VectorDBBase):
)
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
if limit is None:

View File

@@ -254,7 +254,7 @@ class QdrantClient(VectorDBBase):
)
def search(
self, collection_name: str, vectors: List[List[float | int]], limit: int
self, collection_name: str, vectors: List[List[float | int]], filter: Optional[Dict] = None, limit: int = 10
) -> Optional[SearchResult]:
"""
Search for the nearest neighbor items based on the vectors with tenant isolation.

View File

@@ -295,7 +295,7 @@ class S3VectorClient(VectorDBBase):
raise
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
"""
Search for similar vectors in a collection using multiple query vectors.

View File

@@ -159,7 +159,7 @@ class WeaviateClient(VectorDBBase):
)
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10
) -> Optional[SearchResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):

View File

@@ -53,7 +53,11 @@ class VectorDBBase(ABC):
@abstractmethod
def search(
self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
self,
collection_name: str,
vectors: List[List[Union[float, int]]],
filter: Optional[Dict] = None,
limit: int = 10,
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
pass

View File

@@ -46,6 +46,54 @@ router = APIRouter()
PAGE_ITEM_COUNT = 30
############################
# Knowledge Base Embedding
############################
KNOWLEDGE_BASES_COLLECTION = "knowledge-bases"
async def embed_knowledge_base_metadata(
request: Request,
knowledge_base_id: str,
name: str,
description: str,
) -> bool:
"""Generate and store embedding for knowledge base."""
try:
content = f"{name}\n\n{description}" if description else name
embedding = await request.app.state.EMBEDDING_FUNCTION(content)
VECTOR_DB_CLIENT.upsert(
collection_name=KNOWLEDGE_BASES_COLLECTION,
items=[
{
"id": knowledge_base_id,
"text": content,
"vector": embedding,
"metadata": {
"knowledge_base_id": knowledge_base_id,
},
}
],
)
return True
except Exception as e:
log.error(f"Failed to embed knowledge base {knowledge_base_id}: {e}")
return False
def remove_knowledge_base_metadata_embedding(knowledge_base_id: str) -> bool:
"""Remove knowledge base embedding."""
try:
VECTOR_DB_CLIENT.delete(
collection_name=KNOWLEDGE_BASES_COLLECTION,
ids=[knowledge_base_id],
)
return True
except Exception as e:
log.debug(f"Failed to remove embedding for {knowledge_base_id}: {e}")
return False
class KnowledgeAccessResponse(KnowledgeUserResponse):
write_access: Optional[bool] = False
@@ -205,6 +253,13 @@ async def create_new_knowledge(
knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db)
if knowledge:
# Embed knowledge base for semantic search
await embed_knowledge_base_metadata(
request,
knowledge.id,
knowledge.name,
knowledge.description,
)
return knowledge
else:
raise HTTPException(
@@ -281,6 +336,30 @@ async def reindex_knowledge_files(
return True
############################
# ReindexKnowledgeBases
############################
@router.post("/metadata/reindex", response_model=dict)
async def reindex_knowledge_base_metadata_embeddings(
request: Request,
user=Depends(get_admin_user),
db: Session = Depends(get_session),
):
"""Batch embed all existing knowledge bases. Admin only."""
knowledge_bases = Knowledges.get_knowledge_bases(db=db)
log.info(f"Reindexing embeddings for {len(knowledge_bases)} knowledge bases")
success_count = 0
for kb in knowledge_bases:
if await embed_knowledge_base_metadata(request, kb.id, kb.name, kb.description):
success_count += 1
log.info(f"Embedding reindex complete: {success_count}/{len(knowledge_bases)}")
return {"total": len(knowledge_bases), "success": success_count}
############################
# GetKnowledgeById
############################
@@ -369,6 +448,13 @@ async def update_knowledge_by_id(
knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db)
if knowledge:
# Re-embed knowledge base for semantic search
await embed_knowledge_base_metadata(
request,
knowledge.id,
knowledge.name,
knowledge.description,
)
return KnowledgeFilesResponse(
**knowledge.model_dump(),
files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db),
@@ -718,6 +804,10 @@ async def delete_knowledge_by_id(
except Exception as e:
log.debug(e)
pass
# Remove knowledge base embedding
remove_knowledge_base_metadata_embedding(id)
result = Knowledges.delete_knowledge_by_id(id=id, db=db)
return result

View File

@@ -39,6 +39,8 @@ from open_webui.models.groups import Groups
log = logging.getLogger(__name__)
MAX_KNOWLEDGE_BASE_SEARCH_ITEMS = 10_000
# =============================================================================
# TIME UTILITIES
# =============================================================================
@@ -1413,7 +1415,7 @@ async def view_knowledge_file(
return json.dumps({"error": str(e)})
async def query_knowledge_bases(
async def query_knowledge_files(
query: str,
knowledge_ids: Optional[list[str]] = None,
count: int = 5,
@@ -1422,7 +1424,7 @@ async def query_knowledge_bases(
__model_knowledge__: list[dict] = None,
) -> str:
"""
Search internal knowledge bases using semantic/vector search. This should be your first
Search knowledge base files using semantic/vector search. This should be your first
choice for finding information before searching the web. Searches across collections (KBs),
individual files, and notes that the user has access to.
@@ -1558,6 +1560,105 @@ async def query_knowledge_bases(
chunks = chunks[:count]
return json.dumps(chunks, ensure_ascii=False)
except Exception as e:
log.exception(f"query_knowledge_files error: {e}")
return json.dumps({"error": str(e)})
async def query_knowledge_bases(
query: str,
count: int = 5,
__request__: Request = None,
__user__: dict = None,
) -> str:
"""
Search knowledge bases by semantic similarity to query.
Finds KBs whose name/description match the meaning of your query.
Use this to discover relevant knowledge bases before querying their files.
:param query: Natural language query describing what you're looking for
:param count: Maximum results (default: 5)
:return: JSON with matching KBs (id, name, description, similarity)
"""
if __request__ is None:
return json.dumps({"error": "Request context not available"})
if not __user__:
return json.dumps({"error": "User context not available"})
try:
import heapq
from open_webui.models.knowledge import Knowledges
from open_webui.routers.knowledge import KNOWLEDGE_BASES_COLLECTION
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
user_id = __user__.get("id")
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)]
query_embedding = await __request__.app.state.EMBEDDING_FUNCTION(query)
# Min-heap of (distance, knowledge_base_id) - only holds top `count` results
top_results_heap = []
seen_ids = set()
page_offset = 0
page_size = 100
while True:
accessible_knowledge_bases = Knowledges.search_knowledge_bases(
user_id,
filter={"user_id": user_id, "group_ids": user_group_ids},
skip=page_offset,
limit=page_size,
)
if not accessible_knowledge_bases.items:
break
accessible_ids = [kb.id for kb in accessible_knowledge_bases.items]
search_results = VECTOR_DB_CLIENT.search(
collection_name=KNOWLEDGE_BASES_COLLECTION,
vectors=[query_embedding],
filter={"knowledge_base_id": {"$in": accessible_ids}},
limit=count,
)
if search_results and search_results.ids and search_results.ids[0]:
result_ids = search_results.ids[0]
result_distances = search_results.distances[0] if search_results.distances else [0] * len(result_ids)
for knowledge_base_id, distance in zip(result_ids, result_distances):
if knowledge_base_id in seen_ids:
continue
seen_ids.add(knowledge_base_id)
if len(top_results_heap) < count:
heapq.heappush(top_results_heap, (distance, knowledge_base_id))
elif distance > top_results_heap[0][0]:
heapq.heapreplace(top_results_heap, (distance, knowledge_base_id))
page_offset += page_size
if len(accessible_knowledge_bases.items) < page_size:
break
if page_offset >= MAX_KNOWLEDGE_BASE_SEARCH_ITEMS:
break
# Sort by distance descending (best first) and fetch KB details
sorted_results = sorted(top_results_heap, key=lambda x: x[0], reverse=True)
matching_knowledge_bases = []
for distance, knowledge_base_id in sorted_results:
knowledge_base = Knowledges.get_knowledge_by_id(knowledge_base_id)
if knowledge_base:
matching_knowledge_bases.append({
"id": knowledge_base.id,
"name": knowledge_base.name,
"description": knowledge_base.description or "",
"similarity": round(distance, 4),
})
return json.dumps(matching_knowledge_bases, ensure_ascii=False)
except Exception as e:
log.exception(f"query_knowledge_bases error: {e}")
return json.dumps({"error": str(e)})

View File

@@ -157,7 +157,7 @@ def get_citation_source_from_tool_result(
- document: list of document contents
- metadata: list of metadata objects with source, file_id, name fields
Returns a list of sources (usually one, but query_knowledge_bases may return multiple).
Returns a list of sources (usually one, but query_knowledge_files may return multiple).
"""
try:
if tool_name == "search_web":
@@ -217,7 +217,7 @@ def get_citation_source_from_tool_result(
}
]
elif tool_name == "query_knowledge_bases":
elif tool_name == "query_knowledge_files":
chunks = json.loads(tool_result)
# Group chunks by source for better citation display
@@ -3343,7 +3343,7 @@ async def process_chat_response(
in [
"search_web",
"view_knowledge_file",
"query_knowledge_bases",
"query_knowledge_files",
]
and tool_result
):

View File

@@ -68,9 +68,10 @@ from open_webui.tools.builtin import (
write_note,
list_knowledge_bases,
search_knowledge_bases,
search_knowledge_files,
view_knowledge_file,
query_knowledge_bases,
search_knowledge_files,
query_knowledge_files,
view_knowledge_file,
)
import copy
@@ -406,21 +407,22 @@ def get_builtin_tools(
builtin_functions.extend([get_current_timestamp, calculate_timestamp])
# Knowledge base tools - conditional injection based on model knowledge
# If model has attached knowledge (any type), only provide query_knowledge_bases
# If model has attached knowledge (any type), only provide query_knowledge_files
# Otherwise, provide all KB browsing tools
model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", [])
if model_knowledge:
# Model has attached knowledge - only allow semantic search within it
builtin_functions.append(query_knowledge_bases)
builtin_functions.append(query_knowledge_files)
else:
# No model knowledge - allow full KB browsing
builtin_functions.extend(
[
list_knowledge_bases,
search_knowledge_bases,
search_knowledge_files,
view_knowledge_file,
query_knowledge_bases,
search_knowledge_files,
query_knowledge_files,
view_knowledge_file,
]
)