Merge pull request #11089 from tupe2009/kleqon.feat-add-elasticsearch-support

feat: Elasticsearch as a vector store support

commit 6f8c1a8f0d
@@ -1541,6 +1541,15 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)

# ElasticSearch
ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200")
ELASTICSEARCH_CA_CERTS = os.environ.get("ELASTICSEARCH_CA_CERTS", None)
ELASTICSEARCH_API_KEY = os.environ.get("ELASTICSEARCH_API_KEY", None)
ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", None)
ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None)
ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID", None)
SSL_ASSERT_FINGERPRINT = os.environ.get("SSL_ASSERT_FINGERPRINT", None)

# Pgvector
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
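A quick way to sanity-check these settings is to build a client from the same environment variables directly. This is a minimal sketch, not part of the commit; it assumes the elasticsearch package pinned in the requirements hunk below and a reachable cluster:

import os

from elasticsearch import Elasticsearch

# Read the same variables the config above defines and ping the cluster.
client = Elasticsearch(
    hosts=[os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200")],
    ca_certs=os.environ.get("ELASTICSEARCH_CA_CERTS"),
    api_key=os.environ.get("ELASTICSEARCH_API_KEY"),
)
print(client.info())  # raises if the cluster is unreachable or auth fails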
@@ -16,6 +16,10 @@ elif VECTOR_DB == "pgvector":
    from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient

    VECTOR_DB_CLIENT = PgvectorClient()
elif VECTOR_DB == "elasticsearch":
    from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient

    VECTOR_DB_CLIENT = ElasticsearchClient()
else:
    from open_webui.retrieval.vector.dbs.chroma import ChromaClient
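With that branch in place, selecting the new backend should just be a matter of setting VECTOR_DB before Open WebUI loads its config. A minimal sketch (the connector module path is assumed from the imports in the hunk above; everything else is illustrative):

import os

# Hypothetical bootstrap: choose the backend before open_webui reads its config.
os.environ["VECTOR_DB"] = "elasticsearch"
os.environ["ELASTICSEARCH_URL"] = "https://localhost:9200"

from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT  # path assumed

print(type(VECTOR_DB_CLIENT).__name__)  # expected: ElasticsearchClient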
backend/open_webui/retrieval/vector/dbs/elasticsearch.py (new file, 283 lines)
@@ -0,0 +1,283 @@
from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
    ELASTICSEARCH_URL,
    ELASTICSEARCH_CA_CERTS,
    ELASTICSEARCH_API_KEY,
    ELASTICSEARCH_USERNAME,
    ELASTICSEARCH_PASSWORD,
    ELASTICSEARCH_CLOUD_ID,
    SSL_ASSERT_FINGERPRINT,
)

class ElasticsearchClient:
    """
    Important:
    To reduce the number of indexes, and since the embedding vector length is
    fixed, we avoid creating an index per file. Instead, the text is stored as a
    field and documents are separated into different indexes based on the
    embedding length.
    """
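
    # Illustrative layout (example values, not part of the commit):
    #   384-dim embeddings  -> index "open_webui_collections_d384"
    #   1536-dim embeddings -> index "open_webui_collections_d1536"
    # Every document also carries a "collection" keyword field, so one
    # per-dimension index can hold many logical collections.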
    def __init__(self):
        self.index_prefix = "open_webui_collections"
        self.client = Elasticsearch(
            hosts=[ELASTICSEARCH_URL],
            ca_certs=ELASTICSEARCH_CA_CERTS,
            api_key=ELASTICSEARCH_API_KEY,
            cloud_id=ELASTICSEARCH_CLOUD_ID,
            # Basic auth is only used when both credentials are set; otherwise
            # authentication falls back to the API key, if one is configured.
            basic_auth=(
                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
                else None
            ),
            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
        )

    # Status: works
    def _get_index_name(self, dimension: int) -> str:
        return f"{self.index_prefix}_d{dimension}"

    # Status: works
    def _scan_result_to_get_result(self, result) -> GetResult:
        if not result:
            return None
        ids = []
        documents = []
        metadatas = []

        for hit in result:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        # The result types expect one list per query, hence the extra nesting.
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_get_result(self, result) -> GetResult:
        if not result["hits"]["hits"]:
            return None
        ids = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    # Status: works
    def _result_to_search_result(self, result) -> SearchResult:
        ids = []
        distances = []
        documents = []
        metadatas = []

        for hit in result["hits"]["hits"]:
            ids.append(hit["_id"])
            distances.append(hit["_score"])
            documents.append(hit["_source"].get("text"))
            metadatas.append(hit["_source"].get("metadata"))

        return SearchResult(
            ids=[ids],
            distances=[distances],
            documents=[documents],
            metadatas=[metadatas],
        )

    # Status: works
    def _create_index(self, dimension: int):
        body = {
            "mappings": {
                "properties": {
                    "collection": {"type": "keyword"},
                    "id": {"type": "keyword"},
                    "vector": {
                        "type": "dense_vector",
                        "dims": dimension,  # one index per embedding dimension
                        "index": True,
                        "similarity": "cosine",
                    },
                    "text": {"type": "text"},
                    "metadata": {"type": "object"},
                }
            }
        }
        self.client.indices.create(index=self._get_index_name(dimension), body=body)
    # Status: works
    def _create_batches(self, items: list[VectorItem], batch_size=100):
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]

    # Status: works
    def has_collection(self, collection_name) -> bool:
        query_body = {
            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
        }

        try:
            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
            return result.body["count"] > 0
        except Exception:
            # Treat any lookup failure as "collection not found".
            return False

    def delete_collection(self, collection_name: str):
        # A "collection" is a set of documents inside the shared per-dimension
        # indexes, so deleting one means deleting its documents rather than
        # dropping an index; this adapts to the norms of the other DB backends.
        self.client.delete_by_query(
            index=f"{self.index_prefix}*",
            body={"query": {"term": {"collection": collection_name}}},
        )

    # Status: works
    def search(
        self, collection_name: str, vectors: list[list[float]], limit: int
    ) -> Optional[SearchResult]:
        query = {
            "size": limit,
            "_source": ["text", "metadata"],
            "query": {
                "script_score": {
                    "query": {
                        "bool": {
                            "filter": [{"term": {"collection": collection_name}}]
                        }
                    },
                    "script": {
                        # cosineSimilarity ranges over [-1, 1]; adding 1.0 keeps
                        # the score non-negative, as Elasticsearch requires.
                        "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                        # Assuming a single query vector.
                        "params": {"vector": vectors[0]},
                    },
                }
            },
        }

        result = self.client.search(
            index=self._get_index_name(len(vectors[0])), body=query
        )

        return self._result_to_search_result(result)
    # Status: only tested halfway
    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        if not self.has_collection(collection_name):
            return None

        query_body = {
            "query": {"bool": {"filter": []}},
            "_source": ["text", "metadata"],
        }

        for field, value in filter.items():
            query_body["query"]["bool"]["filter"].append({"term": {field: value}})
        query_body["query"]["bool"]["filter"].append(
            {"term": {"collection": collection_name}}
        )
        size = limit if limit else 10

        try:
            result = self.client.search(
                index=f"{self.index_prefix}*",
                body=query_body,
                size=size,
            )
            return self._result_to_get_result(result)
        except Exception:
            return None
    # Status: works
    def _has_index(self, dimension: int):
        return self.client.indices.exists(index=self._get_index_name(dimension=dimension))

    def get_or_create_index(self, dimension: int):
        if not self._has_index(dimension=dimension):
            self._create_index(dimension=dimension)
    # Status: works
    def get(self, collection_name: str) -> Optional[GetResult]:
        # Get all the items in the collection.
        query = {
            "query": {
                "bool": {"filter": [{"term": {"collection": collection_name}}]}
            },
            "_source": ["text", "metadata"],
        }
        results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))

        return self._scan_result_to_get_result(results)

    # Status: works
    def insert(self, collection_name: str, items: list[VectorItem]):
        if not self._has_index(dimension=len(items[0]["vector"])):
            self._create_index(dimension=len(items[0]["vector"]))

        for batch in self._create_batches(items):
            actions = [
                {
                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)
    # Status: should work
    def upsert(self, collection_name: str, items: list[VectorItem]):
        if not self._has_index(dimension=len(items[0]["vector"])):
            self._create_index(dimension=len(items[0]["vector"]))

        for batch in self._create_batches(items):
            # Indexing with an explicit _id overwrites any existing document,
            # which gives upsert semantics; the collection tag is kept so the
            # document stays findable.
            actions = [
                {
                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
                    "_id": item["id"],
                    "_source": {
                        "collection": collection_name,
                        "vector": item["vector"],
                        "text": item["text"],
                        "metadata": item["metadata"],
                    },
                }
                for item in batch
            ]
            bulk(self.client, actions)

    def delete(self, collection_name: str, ids: list[str]):
        # The bulk API does not accept wildcard index names, so delete by query
        # across the per-dimension indexes, scoped to the given collection.
        self.client.delete_by_query(
            index=f"{self.index_prefix}*",
            body={
                "query": {
                    "bool": {
                        "filter": [
                            {"term": {"collection": collection_name}},
                            {"ids": {"values": ids}},
                        ]
                    }
                }
            },
        )

    def reset(self):
        # Drop every per-dimension index created by this client.
        indices = self.client.indices.get(index=f"{self.index_prefix}*")
        for index in indices:
            self.client.indices.delete(index=index)
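Taken together, a round trip through the new client might look like the sketch below. It is illustrative only and not part of the commit: the collection name, vector size, and item values are made up, items are passed as plain dicts (matching the item["..."] access above), and a reachable cluster configured via the environment variables from the first hunk is assumed.

from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient

client = ElasticsearchClient()

# One 384-dimensional item; the "_d384" index is created on first insert.
items = [
    {
        "id": "doc-1",
        "text": "hello world",
        "vector": [0.1] * 384,
        "metadata": {"source": "example.txt"},
    }
]
client.insert("demo_collection", items)

# Similarity search against the same collection (single query vector).
result = client.search("demo_collection", vectors=[[0.1] * 384], limit=5)
if result:
    print(result.ids[0], result.distances[0])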
@@ -49,6 +49,8 @@ pymilvus==2.5.0
qdrant-client~=1.12.0
opensearch-py==2.8.0
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
elasticsearch==8.17.1

transformers
sentence-transformers==3.3.1