diff --git a/backend/open_webui/apps/rag/vector/connector.py b/backend/open_webui/apps/rag/vector/connector.py
new file mode 100644
index 000000000..9b9b6089e
--- /dev/null
+++ b/backend/open_webui/apps/rag/vector/connector.py
@@ -0,0 +1,12 @@
+from open_webui.apps.rag.vector.dbs.chroma import ChromaClient
+from open_webui.apps.rag.vector.dbs.milvus import MilvusClient
+from open_webui.apps.rag.vector.dbs.opensearch import OpenSearchClient
+
+from open_webui.config import VECTOR_DB
+
+if VECTOR_DB == "milvus":
+    VECTOR_DB_CLIENT = MilvusClient()
+elif VECTOR_DB == "opensearch":
+    VECTOR_DB_CLIENT = OpenSearchClient()
+else:
+    VECTOR_DB_CLIENT = ChromaClient()
diff --git a/backend/open_webui/apps/rag/vector/dbs/opensearch.py b/backend/open_webui/apps/rag/vector/dbs/opensearch.py
new file mode 100644
index 000000000..24ed1324a
--- /dev/null
+++ b/backend/open_webui/apps/rag/vector/dbs/opensearch.py
@@ -0,0 +1,188 @@
+from opensearchpy import OpenSearch
+from typing import Optional
+
+from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.config import (
+    OPENSEARCH_URI,
+    OPENSEARCH_SSL,
+    OPENSEARCH_CERT_VERIFY,
+    OPENSEARCH_USERNAME,
+    OPENSEARCH_PASSWORD,
+)
+
+
+class OpenSearchClient:
+    """Vector-database adapter backed by an OpenSearch cluster.
+
+    Each "collection" is stored as one OpenSearch index named
+    "{index_prefix}_{collection_name}".
+    """
+
+    def __init__(self):
+        self.index_prefix = "open_webui"
+        self.client = OpenSearch(
+            hosts=[OPENSEARCH_URI],
+            use_ssl=OPENSEARCH_SSL,
+            verify_certs=OPENSEARCH_CERT_VERIFY,
+            http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
+        )
+
+    def _result_to_get_result(self, result) -> GetResult:
+        # Convert a raw search response into the shared GetResult shape.
+        ids = []
+        documents = []
+        metadatas = []
+
+        for hit in result["hits"]["hits"]:
+            ids.append(hit["_id"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+
+    def _result_to_search_result(self, result) -> SearchResult:
+        # Convert a raw search response into the shared SearchResult shape.
+        ids = []
+        distances = []
+        documents = []
+        metadatas = []
+
+        for hit in result["hits"]["hits"]:
+            ids.append(hit["_id"])
+            distances.append(hit["_score"])
+            documents.append(hit["_source"].get("text"))
+            metadatas.append(hit["_source"].get("metadata"))
+
+        return SearchResult(
+            ids=ids, distances=distances, documents=documents, metadatas=metadatas
+        )
+
+    def _create_index(self, index_name: str, dimension: int):
+        # "dims" must match the embedding size produced upstream.
+        # NOTE(review): "dense_vector" is the Elasticsearch mapping type;
+        # OpenSearch k-NN normally uses "knn_vector" with index.knn enabled —
+        # confirm against the target cluster version.
+        body = {
+            "mappings": {
+                "properties": {
+                    "id": {"type": "keyword"},
+                    "vector": {
+                        "type": "dense_vector",
+                        "dims": dimension,
+                        "index": True,
+                        "similarity": "faiss",
+                        "method": {
+                            "name": "hnsw",
+                            "space_type": "ip",  # inner product approximates cosine for normalized vectors
+                            "engine": "faiss",
+                            "ef_construction": 128,
+                            "m": 16,
+                        },
+                    },
+                    "text": {"type": "text"},
+                    "metadata": {"type": "object"},
+                }
+            }
+        }
+        self.client.indices.create(
+            index=f"{self.index_prefix}_{index_name}", body=body
+        )
+
+    def _create_batches(self, items: list[VectorItem], batch_size=100):
+        # Yield successive slices so bulk request bodies stay bounded.
+        for i in range(0, len(items), batch_size):
+            yield items[i : i + batch_size]
+
+    def _bulk_index_actions(self, items: list[VectorItem]) -> list[dict]:
+        # The bulk API expects alternating action and document entries.
+        actions = []
+        for item in items:
+            actions.append({"index": {"_id": item["id"]}})
+            actions.append(
+                {
+                    "vector": item["vector"],
+                    "text": item["text"],
+                    "metadata": item["metadata"],
+                }
+            )
+        return actions
+
+    def has_collection(self, index_name: str) -> bool:
+        # has_collection here means has index.
+        # We are simply adapting to the norms of the other DBs.
+        return self.client.indices.exists(index=f"{self.index_prefix}_{index_name}")
+
+    def delete_collection(self, index_name: str):
+        # delete_collection here means delete index.
+        # We are simply adapting to the norms of the other DBs.
+        self.client.indices.delete(index=f"{self.index_prefix}_{index_name}")
+
+    def search(
+        self, index_name: str, vectors: list[list[float]], limit: int
+    ) -> Optional[SearchResult]:
+        # Scores every document against the first query vector only.
+        # NOTE(review): cosineSimilarity is an Elasticsearch painless function;
+        # the OpenSearch k-NN plugin exposes "knn_score" instead — confirm.
+        query = {
+            "size": limit,
+            "_source": ["text", "metadata"],
+            "query": {
+                "script_score": {
+                    "query": {"match_all": {}},
+                    "script": {
+                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
+                        "params": {"vector": vectors[0]},
+                    },
+                }
+            },
+        }
+
+        result = self.client.search(
+            index=f"{self.index_prefix}_{index_name}", body=query
+        )
+
+        return self._result_to_search_result(result)
+
+    def get_or_create_index(self, index_name: str, dimension: int):
+        # Idempotent index creation keyed on the collection name.
+        if not self.has_collection(index_name):
+            self._create_index(index_name, dimension)
+
+    def get(self, index_name: str) -> Optional[GetResult]:
+        # Fetch documents (default page size) with their text and metadata.
+        query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
+
+        result = self.client.search(
+            index=f"{self.index_prefix}_{index_name}", body=query
+        )
+        return self._result_to_get_result(result)
+
+    def insert(self, index_name: str, items: list[VectorItem]):
+        # Lazily create the index using the first item's vector length.
+        if not self.has_collection(index_name):
+            self._create_index(index_name, dimension=len(items[0]["vector"]))
+
+        for batch in self._create_batches(items):
+            self.client.bulk(
+                index=f"{self.index_prefix}_{index_name}",
+                body=self._bulk_index_actions(batch),
+            )
+
+    def upsert(self, index_name: str, items: list[VectorItem]):
+        # "index" actions overwrite documents with the same _id, so insert
+        # already has upsert semantics.
+        self.insert(index_name, items)
+
+    def delete(self, index_name: str, ids: list[str]):
+        # Bulk delete by document id.
+        actions = [
+            {"delete": {"_index": f"{self.index_prefix}_{index_name}", "_id": id}}
+            for id in ids
+        ]
+        self.client.bulk(body=actions)
+
+    def reset(self):
+        # Drop every index created by this adapter.
+        indices = self.client.indices.get(index=f"{self.index_prefix}_*")
+        for index in indices:
+            self.client.indices.delete(index=index)
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index 06ac258f9..dd6215692 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -957,6 +957,13 @@ MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")
 # Qdrant
 QDRANT_URI = os.environ.get("QDRANT_URI", None)
 
+# OpenSearch
+OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
+OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
+OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
+OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
+OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
+
 ####################################
 # Information Retrieval (RAG)
 ####################################