diff --git a/backend/open_webui/apps/retrieval/vector/connector.py b/backend/open_webui/apps/retrieval/vector/connector.py index c7f00f5fd..b95b55783 100644 --- a/backend/open_webui/apps/retrieval/vector/connector.py +++ b/backend/open_webui/apps/retrieval/vector/connector.py @@ -8,6 +8,10 @@ elif VECTOR_DB == "qdrant": from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient VECTOR_DB_CLIENT = QdrantClient() +elif VECTOR_DB == "opensearch": + from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient + + VECTOR_DB_CLIENT = OpenSearchClient() else: from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient diff --git a/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py b/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py new file mode 100644 index 000000000..00027dfb9 --- /dev/null +++ b/backend/open_webui/apps/retrieval/vector/dbs/opensearch.py @@ -0,0 +1,152 @@ +from opensearchpy import OpenSearch +from typing import Optional + +from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult +from open_webui.config import ( + OPENSEARCH_URI, + OPENSEARCH_SSL, + OPENSEARCH_CERT_VERIFY, + OPENSEARCH_USERNAME, + OPENSEARCH_PASSWORD +) + +class OpenSearchClient: + def __init__(self): + self.index_prefix = "open_webui" + self.client = OpenSearch( + hosts=[OPENSEARCH_URI], + use_ssl=OPENSEARCH_SSL, + verify_certs=OPENSEARCH_CERT_VERIFY, + http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD), + ) + + def _result_to_get_result(self, result) -> GetResult: + ids = [] + documents = [] + metadatas = [] + + for hit in result['hits']['hits']: + ids.append(hit['_id']) + documents.append(hit['_source'].get("text")) + metadatas.append(hit['_source'].get("metadata")) + + return GetResult(ids=ids, documents=documents, metadatas=metadatas) + + def _result_to_search_result(self, result) -> SearchResult: + ids = [] + distances = [] + documents = [] + metadatas = [] + + for hit in result['hits']['hits']: + ids.append(hit['_id']) + distances.append(hit['_score']) + documents.append(hit['_source'].get("text")) + metadatas.append(hit['_source'].get("metadata")) + + return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas) + + def _create_index(self, index_name: str, dimension: int): + body = { + "mappings": { + "properties": { + "id": {"type": "keyword"}, + "vector": { + "type": "dense_vector", + "dims": dimension, # Adjust based on your vector dimensions + "index": true, + "similarity": "faiss", + "method": { + "name": "hnsw", + "space_type": "ip", # Use inner product to approximate cosine similarity + "engine": "faiss", + "ef_construction": 128, + "m": 16 + } + }, + "text": {"type": "text"}, + "metadata": {"type": "object"} + } + } + } + self.client.indices.create(index=f"{self.index_prefix}_{index_name}", body=body) + + def _create_batches(self, items: list[VectorItem], batch_size=100): + for i in range(0, len(items), batch_size): + yield items[i:i + batch_size] + + def has_collection(self, index_name: str) -> bool: + # has_collection here means has index. + # We are simply adapting to the norms of the other DBs. + return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}") + + def delete_colleciton(self, index_name: str): + # delete_collection here means delete index. + # We are simply adapting to the norms of the other DBs. + self.client.indices.delete(index=f"{self.index_prefix}_{index_name}") + + def search(self, index_name: str, vectors: list[list[float]], limit: int) -> Optional[SearchResult]: + query = { + "size": limit, + "_source": ["text", "metadata"], + "query": { + "script_score": { + "query": {"match_all": {}}, + "script": { + "source": "cosineSimilarity(params.vector, 'vector') + 1.0", + "params": {"vector": vectors[0]} # Assuming single query vector + } + } + } + } + + result = self.client.search( + index=f"{self.index_prefix}_{index_name}", + body=query + ) + + return self._result_to_search_result(result) + + def get_or_create_index(self, index_name: str, dimension: int): + if not self.has_index(index_name): + self._create_index(index_name, dimension) + + def get(self, index_name: str) -> Optional[GetResult]: + query = { + "query": {"match_all": {}}, + "_source": ["text", "metadata"] + } + + result = self.client.search(index=f"{self.index_prefix}_{index_name}", body=query) + return self._result_to_get_result(result) + + def insert(self, index_name: str, items: list[VectorItem]): + if not self.has_index(index_name): + self._create_index(index_name, dimension=len(items[0]["vector"])) + + for batch in self._create_batches(items): + actions = [ + {"index": {"_id": item["id"], "_source": {"vector": item["vector"], "text": item["text"], "metadata": item["metadata"]}}} + for item in batch + ] + self.client.bulk(actions) + + def upsert(self, index_name: str, items: list[VectorItem]): + if not self.has_index(index_name): + self._create_index(index_name, dimension=len(items[0]["vector"])) + + for batch in self._create_batches(items): + actions = [ + {"index": {"_id": item["id"], "_source": {"vector": item["vector"], "text": item["text"], "metadata": item["metadata"]}}} + for item in batch + ] + self.client.bulk(actions) + + def delete(self, index_name: str, ids: list[str]): + actions = [{"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}} for id in ids] + self.client.bulk(body=actions) + + def reset(self): + indices = self.client.indices.get(index=f"{self.index_prefix}_*") + for index in indices: + self.client.indices.delete(index=index) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index f7c00ccde..f7221eeaf 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -957,6 +957,13 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") # Qdrant QDRANT_URI = os.environ.get("QDRANT_URI", None) +# OpenSearch +OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200") +OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", True) +OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False) +OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None) +OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None) + #################################### # Information Retrieval (RAG) #################################### diff --git a/backend/requirements.txt b/backend/requirements.txt index 52de2ba8c..b37b02b81 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -43,6 +43,7 @@ fake-useragent==1.5.1 chromadb==0.5.15 pymilvus==2.4.9 qdrant-client~=1.12.0 +opensearch-py==2.7.1 sentence-transformers==3.2.0 colbert-ai==0.2.21 diff --git a/pyproject.toml b/pyproject.toml index 31fdd3e4a..305ced6eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "fake-useragent==1.5.1", "chromadb==0.5.9", "pymilvus==2.4.7", + "opensearch-py==2.7.1", "sentence-transformers==3.2.0", "colbert-ai==0.2.21",