open-webui/backend/open_webui/retrieval/vector/dbs/opensearch.py

179 lines
6.2 KiB
Python
Raw Normal View History

2024-10-30 00:28:37 +00:00
from opensearchpy import OpenSearch
from typing import Optional
2024-11-20 14:09:48 +00:00
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
2024-10-30 00:28:37 +00:00
from open_webui.config import (
2024-11-04 20:14:53 +00:00
OPENSEARCH_URI,
OPENSEARCH_SSL,
OPENSEARCH_CERT_VERIFY,
OPENSEARCH_USERNAME,
2024-11-17 07:46:12 +00:00
OPENSEARCH_PASSWORD,
2024-10-30 00:28:37 +00:00
)
2024-11-17 07:46:12 +00:00
2024-10-30 00:28:37 +00:00
class OpenSearchClient:
    """Vector-store client that maps Open WebUI "collections" onto OpenSearch indices.

    A collection named ``name`` is stored in the index ``f"{index_prefix}_{name}"``.
    Each document's source holds ``vector``, ``text`` and ``metadata`` fields.

    NOTE(review): this module imports ``VectorItem``/``SearchResult``/``GetResult``
    from ``open_webui.apps.retrieval.vector.main`` while the file itself lives under
    ``open_webui/retrieval/vector/dbs/`` -- confirm that import path still exists.
    """

    def __init__(self):
        # Every index managed by this client shares this prefix, which is what
        # lets reset() find them all with a single wildcard lookup.
        self.index_prefix = "open_webui"
        self.client = OpenSearch(
            hosts=[OPENSEARCH_URI],
            use_ssl=OPENSEARCH_SSL,
            verify_certs=OPENSEARCH_CERT_VERIFY,
            http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
        )

    def _get_index_name(self, index_name: str) -> str:
        """Return the prefixed OpenSearch index name for collection ``index_name``."""
        return f"{self.index_prefix}_{index_name}"

    def _result_to_get_result(self, result) -> GetResult:
        """Convert an OpenSearch search response into a ``GetResult``.

        Args:
            result: A search-response dict; only ``result["hits"]["hits"]`` is read.

        Returns:
            A ``GetResult`` whose ``ids``, ``documents`` and ``metadatas`` are
            flat, parallel lists in hit order.
        """
        hits = result["hits"]["hits"]
        # NOTE(review): other vector-DB clients may return nested per-query
        # lists (e.g. ids=[[...]]); confirm the shape GetResult expects.
        return GetResult(
            ids=[hit["_id"] for hit in hits],
            documents=[hit["_source"].get("text") for hit in hits],
            metadatas=[hit["_source"].get("metadata") for hit in hits],
        )

    def _result_to_search_result(self, result) -> SearchResult:
        """Convert an OpenSearch search response into a ``SearchResult``.

        Args:
            result: A search-response dict; only ``result["hits"]["hits"]`` is read.

        Returns:
            A ``SearchResult`` with flat, parallel lists in hit order; each hit's
            ``_score`` is reported as its distance.
        """
        hits = result["hits"]["hits"]
        return SearchResult(
            ids=[hit["_id"] for hit in hits],
            distances=[hit["_score"] for hit in hits],
            documents=[hit["_source"].get("text") for hit in hits],
            metadatas=[hit["_source"].get("metadata") for hit in hits],
        )

    def _create_index(self, index_name: str, dimension: int):
        """Create the k-NN index that backs collection ``index_name``.

        Args:
            index_name: Unprefixed collection name.
            dimension: Length of the embedding vectors stored in the index.
        """
        body = {
            # k-NN has to be enabled per index for knn_vector fields to work.
            "settings": {"index": {"knn": True}},
            "mappings": {
                "properties": {
                    "id": {"type": "keyword"},
                    "vector": {
                        # OpenSearch's k-NN field type; ``dense_vector``/``dims``
                        # are Elasticsearch mapping names, not OpenSearch ones.
                        "type": "knn_vector",
                        "dimension": dimension,
                        "method": {
                            "name": "hnsw",
                            # Inner product approximates cosine similarity (they
                            # coincide for unit-normalised vectors).
                            "space_type": "innerproduct",
                            "engine": "faiss",
                            "parameters": {"ef_construction": 128, "m": 16},
                        },
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                }
            },
        }
        self.client.indices.create(index=self._get_index_name(index_name), body=body)

    def _create_batches(self, items: list[VectorItem], batch_size=100):
        """Yield consecutive slices of ``items`` holding at most ``batch_size`` each."""
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]

    def has_collection(self, index_name: str) -> bool:
        """Return whether the index backing collection ``index_name`` exists.

        "Collection" is used instead of "index" to match the interface of the
        other vector-DB clients.
        """
        return self.client.indices.exists(index=self._get_index_name(index_name))

    def delete_collection(self, index_name: str):
        """Delete the index backing collection ``index_name``.

        "Collection" is used instead of "index" to match the interface of the
        other vector-DB clients.
        """
        self.client.indices.delete(index=self._get_index_name(index_name))

    # Backward-compatible alias for the original, misspelled method name.
    delete_colleciton = delete_collection

    def search(
        self, index_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        """Return up to ``limit`` documents most similar to the query vector.

        Only ``vectors[0]`` is used; any further query vectors are ignored.
        Scores are cosine similarity shifted by +1.0 so they are never negative.

        Args:
            index_name: Unprefixed collection name.
            vectors: Query vectors (only the first one is used).
            limit: Maximum number of hits to return.

        Returns:
            A ``SearchResult``, or ``None`` if ``vectors`` is empty or the
            collection's index does not exist.
        """
        if not vectors or not self.has_collection(index_name):
            return None

        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        # OpenSearch k-NN Painless functions take the doc-values
                        # accessor doc[field], not a field-name string.
                        "source": (
                            "cosineSimilarity(params.query_value, doc[params.field])"
                            " + 1.0"
                        ),
                        "params": {"field": "vector", "query_value": vectors[0]},
                    },
                }
            },
        }
        result = self.client.search(
            index=self._get_index_name(index_name), body=query
        )
        return self._result_to_search_result(result)

    def get_or_create_index(self, index_name: str, dimension: int):
        """Create the index for ``index_name`` with ``dimension``-length vectors if missing."""
        if not self.has_collection(index_name):
            self._create_index(index_name, dimension)

    def get(self, index_name: str) -> Optional[GetResult]:
        """Return every document stored in collection ``index_name``.

        A plain search returns only 10 hits by default, so the scroll-based
        ``scan`` helper is used to page through the whole index.

        Returns:
            A ``GetResult``, or ``None`` if the collection's index does not exist.
        """
        # Local import: only get() needs the scroll helper.
        from opensearchpy.helpers import scan

        if not self.has_collection(index_name):
            return None

        query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
        hits = list(
            scan(self.client, query=query, index=self._get_index_name(index_name))
        )
        return self._result_to_get_result({"hits": {"hits": hits}})

    def _bulk_index(self, index_name: str, items: list[VectorItem]):
        """Write ``items`` into collection ``index_name`` in batches.

        Creates the index on first use, sizing the vector field from the first
        item. Uses the bulk ``index`` operation, which adds new documents and
        overwrites any existing document with the same ``_id``.
        """
        if not items:
            return
        if not self.has_collection(index_name):
            self._create_index(index_name, dimension=len(items[0]["vector"]))

        full_index_name = self._get_index_name(index_name)
        for batch in self._create_batches(items):
            # The bulk body alternates an action line (target index and _id)
            # with the document-source line that follows it.
            body = []
            for item in batch:
                body.append({"index": {"_index": full_index_name, "_id": item["id"]}})
                body.append(
                    {
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    }
                )
            # NOTE(review): per-item failures reported in the bulk response are
            # not inspected here.
            self.client.bulk(body=body)

    def insert(self, index_name: str, items: list[VectorItem]):
        """Add ``items`` to collection ``index_name``, creating its index if needed."""
        self._bulk_index(index_name, items)

    def upsert(self, index_name: str, items: list[VectorItem]):
        """Insert or overwrite ``items`` (matched by id) in collection ``index_name``."""
        self._bulk_index(index_name, items)

    def delete(self, index_name: str, ids: list[str]):
        """Delete the documents with the given ``ids`` from collection ``index_name``."""
        if not ids:
            # An empty bulk request is rejected by the server; nothing to do.
            return
        full_index_name = self._get_index_name(index_name)
        actions = [
            {"delete": {"_index": full_index_name, "_id": doc_id}} for doc_id in ids
        ]
        self.client.bulk(body=actions)

    def reset(self):
        """Delete every index whose name starts with this client's prefix."""
        indices = self.client.indices.get(index=f"{self.index_prefix}_*")
        # Delete one at a time rather than with a wildcard delete, which some
        # clusters disallow.
        for index in indices:
            self.client.indices.delete(index=index)