fix: opensearch vector db query structures, result mapping, filters, bulk query actions, knn_vector usage

This commit is contained in:
Katharina 2025-03-06 23:49:54 +01:00
parent 3b70cd64d7
commit 6cb0c0339a

View File

@@ -1,4 +1,5 @@
from opensearchpy import OpenSearch from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional from typing import Optional
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
@@ -21,7 +22,13 @@ class OpenSearchClient:
http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD), http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
) )
def _get_index_name(self, collection_name: str) -> str:
return f"{self.index_prefix}_{collection_name}"
def _result_to_get_result(self, result) -> GetResult: def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
return None
ids = [] ids = []
documents = [] documents = []
metadatas = [] metadatas = []
@@ -31,9 +38,12 @@ class OpenSearchClient:
documents.append(hit["_source"].get("text")) documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata")) metadatas.append(hit["_source"].get("metadata"))
return GetResult(ids=ids, documents=documents, metadatas=metadatas) return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def _result_to_search_result(self, result) -> SearchResult: def _result_to_search_result(self, result) -> SearchResult:
if not result["hits"]["hits"]:
return None
ids = [] ids = []
distances = [] distances = []
documents = [] documents = []
@@ -46,25 +56,32 @@ class OpenSearchClient:
metadatas.append(hit["_source"].get("metadata")) metadatas.append(hit["_source"].get("metadata"))
return SearchResult( return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
) )
def _create_index(self, collection_name: str, dimension: int): def _create_index(self, collection_name: str, dimension: int):
body = { body = {
"settings": {
"index": {
"knn": True
}
},
"mappings": { "mappings": {
"properties": { "properties": {
"id": {"type": "keyword"}, "id": {"type": "keyword"},
"vector": { "vector": {
"type": "dense_vector", "type": "knn_vector",
"dims": dimension, # Adjust based on your vector dimensions "dimension": dimension, # Adjust based on your vector dimensions
"index": true, "index": True,
"similarity": "faiss", "similarity": "faiss",
"method": { "method": {
"name": "hnsw", "name": "hnsw",
"space_type": "ip", # Use inner product to approximate cosine similarity "space_type": "innerproduct", # Use inner product to approximate cosine similarity
"engine": "faiss", "engine": "faiss",
"parameters": {
"ef_construction": 128, "ef_construction": 128,
"m": 16, "m": 16,
}
}, },
}, },
"text": {"type": "text"}, "text": {"type": "text"},
@@ -73,7 +90,7 @@ class OpenSearchClient:
} }
} }
self.client.indices.create( self.client.indices.create(
index=f"{self.index_prefix}_{collection_name}", body=body index=self._get_index_name(collection_name), body=body
) )
def _create_batches(self, items: list[VectorItem], batch_size=100): def _create_batches(self, items: list[VectorItem], batch_size=100):
@@ -84,27 +101,34 @@ class OpenSearchClient:
# has_collection here means has index. # has_collection here means has index.
# We are simply adapting to the norms of the other DBs. # We are simply adapting to the norms of the other DBs.
return self.client.indices.exists( return self.client.indices.exists(
index=f"{self.index_prefix}_{collection_name}" index=self._get_index_name(collection_name)
) )
def delete_colleciton(self, collection_name: str): def delete_collection(self, collection_name: str):
# delete_collection here means delete index. # delete_collection here means delete index.
# We are simply adapting to the norms of the other DBs. # We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=f"{self.index_prefix}_{collection_name}") self.client.indices.delete(index=self._get_index_name(collection_name))
def search( def search(
self, collection_name: str, vectors: list[list[float]], limit: int self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
try:
if not self.has_collection(collection_name):
return None
query = { query = {
"size": limit, "size": limit,
"_source": ["text", "metadata"], "_source": ["text", "metadata"],
"query": { "query": {
"script_score": { "script_score": {
"query": {"match_all": {}}, "query": {
"match_all": {}
},
"script": { "script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0", "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
"params": { "params": {
"vector": vectors[0] "field": "vector",
"query_value": vectors[0]
}, # Assuming single query vector }, # Assuming single query vector
}, },
} }
@@ -112,11 +136,15 @@ class OpenSearchClient:
} }
result = self.client.search( result = self.client.search(
index=f"{self.index_prefix}_{collection_name}", body=query index=self._get_index_name(collection_name),
body=query
) )
return self._result_to_search_result(result) return self._result_to_search_result(result)
except Exception as e:
return None
def query( def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]: ) -> Optional[GetResult]:
@@ -124,18 +152,26 @@ class OpenSearchClient:
return None return None
query_body = { query_body = {
"query": {"bool": {"filter": []}}, "query": {
"bool": {
"filter": []
}
},
"_source": ["text", "metadata"], "_source": ["text", "metadata"],
} }
for field, value in filter.items(): for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}}) query_body["query"]["bool"]["filter"].append({
"match": {
"metadata." + str(field): value
}
})
size = limit if limit else 10 size = limit if limit else 10
try: try:
result = self.client.search( result = self.client.search(
index=f"{self.index_prefix}_{collection_name}", index=self._get_index_name(collection_name),
body=query_body, body=query_body,
size=size, size=size,
) )
@@ -146,14 +182,14 @@ class OpenSearchClient:
return None return None
def _create_index_if_not_exists(self, collection_name: str, dimension: int): def _create_index_if_not_exists(self, collection_name: str, dimension: int):
if not self.has_index(collection_name): if not self.has_collection(collection_name):
self._create_index(collection_name, dimension) self._create_index(collection_name, dimension)
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]} query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
result = self.client.search( result = self.client.search(
index=f"{self.index_prefix}_{collection_name}", body=query index=self._get_index_name(collection_name), body=query
) )
return self._result_to_get_result(result) return self._result_to_get_result(result)
@@ -165,7 +201,8 @@ class OpenSearchClient:
for batch in self._create_batches(items): for batch in self._create_batches(items):
actions = [ actions = [
{ {
"index": { "_op_type": "index",
"_index": self._get_index_name(collection_name),
"_id": item["id"], "_id": item["id"],
"_source": { "_source": {
"vector": item["vector"], "vector": item["vector"],
@@ -173,10 +210,9 @@ class OpenSearchClient:
"metadata": item["metadata"], "metadata": item["metadata"],
}, },
} }
}
for item in batch for item in batch
] ]
self.client.bulk(actions) bulk(self.client, actions)
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists( self._create_index_if_not_exists(
@@ -186,26 +222,46 @@ class OpenSearchClient:
for batch in self._create_batches(items): for batch in self._create_batches(items):
actions = [ actions = [
{ {
"index": { "_op_type": "update",
"_index": self._get_index_name(collection_name),
"_id": item["id"], "_id": item["id"],
"_index": f"{self.index_prefix}_{collection_name}", "doc": {
"_source": {
"vector": item["vector"], "vector": item["vector"],
"text": item["text"], "text": item["text"],
"metadata": item["metadata"], "metadata": item["metadata"],
}, },
} "doc_as_upsert": True,
} }
for item in batch for item in batch
] ]
self.client.bulk(actions) bulk(self.client, actions)
def delete(self, collection_name: str, ids: list[str]): def delete(self, collection_name: str, ids: Optional[list[str]] = None, filter: Optional[dict] = None):
if ids:
actions = [ actions = [
{"delete": {"_index": f"{self.index_prefix}_{collection_name}", "_id": id}} {
"_op_type": "delete",
"_index": self._get_index_name(collection_name),
"_id": id,
}
for id in ids for id in ids
] ]
self.client.bulk(body=actions) bulk(self.client, actions)
elif filter:
query_body = {
"query": {
"bool": {
"filter": []
}
},
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({
"match": {
"metadata." + str(field): value
}
})
self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body)
def reset(self): def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*") indices = self.client.indices.get(index=f"{self.index_prefix}_*")