mirror of
https://github.com/open-webui/open-webui
synced 2025-04-08 14:49:46 +00:00
added elasticsearch support
This commit is contained in:
parent
c7d4d2e41f
commit
737dfd2763
@ -1541,6 +1541,15 @@ OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", False)
|
|||||||
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
|
OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
|
||||||
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
|
OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
|
||||||
|
|
||||||
|
# ElasticSearch
|
||||||
|
ELASTICSEARCH_URL = os.environ.get("ELASTICSEARCH_URL", "https://localhost:9200")
|
||||||
|
ELASTICSEARCH_CA_CERTS = os.environ.get("ELASTICSEARCH_CA_CERTS", None)
|
||||||
|
ELASTICSEARCH_API_KEY = os.environ.get("ELASTICSEARCH_API_KEY", None)
|
||||||
|
ELASTICSEARCH_USERNAME = os.environ.get("ELASTICSEARCH_USERNAME", None)
|
||||||
|
ELASTICSEARCH_PASSWORD = os.environ.get("ELASTICSEARCH_PASSWORD", None)
|
||||||
|
ELASTICSEARCH_CLOUD_ID = os.environ.get("ELASTICSEARCH_CLOUD_ID", None)
|
||||||
|
SSL_ASSERT_FINGERPRINT = os.environ.get("SSL_ASSERT_FINGERPRINT", None)
|
||||||
|
|
||||||
# Pgvector
|
# Pgvector
|
||||||
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
|
PGVECTOR_DB_URL = os.environ.get("PGVECTOR_DB_URL", DATABASE_URL)
|
||||||
if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
|
if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
|
||||||
|
@ -16,6 +16,10 @@ elif VECTOR_DB == "pgvector":
|
|||||||
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
|
from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient
|
||||||
|
|
||||||
VECTOR_DB_CLIENT = PgvectorClient()
|
VECTOR_DB_CLIENT = PgvectorClient()
|
||||||
|
elif VECTOR_DB == "elasticsearch":
|
||||||
|
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
|
||||||
|
|
||||||
|
VECTOR_DB_CLIENT = ElasticsearchClient()
|
||||||
else:
|
else:
|
||||||
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
|
from open_webui.retrieval.vector.dbs.chroma import ChromaClient
|
||||||
|
|
||||||
|
283
backend/open_webui/retrieval/vector/dbs/elasticsearch.py
Normal file
283
backend/open_webui/retrieval/vector/dbs/elasticsearch.py
Normal file
@ -0,0 +1,283 @@
|
|||||||
|
from elasticsearch import Elasticsearch, BadRequestError
|
||||||
|
from typing import Optional
|
||||||
|
import ssl
|
||||||
|
from elasticsearch.helpers import bulk,scan
|
||||||
|
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
|
||||||
|
from open_webui.config import (
|
||||||
|
ELASTICSEARCH_URL,
|
||||||
|
ELASTICSEARCH_CA_CERTS,
|
||||||
|
ELASTICSEARCH_API_KEY,
|
||||||
|
ELASTICSEARCH_USERNAME,
|
||||||
|
ELASTICSEARCH_PASSWORD,
|
||||||
|
ELASTICSEARCH_CLOUD_ID,
|
||||||
|
SSL_ASSERT_FINGERPRINT
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticsearchClient:
|
||||||
|
"""
|
||||||
|
Important:
|
||||||
|
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
|
||||||
|
an index for each file but store it as a text field, while seperating to different index
|
||||||
|
baesd on the embedding length.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
self.index_prefix = "open_webui_collections"
|
||||||
|
self.client = Elasticsearch(
|
||||||
|
hosts=[ELASTICSEARCH_URL],
|
||||||
|
ca_certs=ELASTICSEARCH_CA_CERTS,
|
||||||
|
api_key=ELASTICSEARCH_API_KEY,
|
||||||
|
cloud_id=ELASTICSEARCH_CLOUD_ID,
|
||||||
|
basic_auth=(ELASTICSEARCH_USERNAME,ELASTICSEARCH_PASSWORD) if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD else None,
|
||||||
|
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT
|
||||||
|
|
||||||
|
)
|
||||||
|
#Status: works
|
||||||
|
def _get_index_name(self,dimension:int)->str:
|
||||||
|
return f"{self.index_prefix}_d{str(dimension)}"
|
||||||
|
|
||||||
|
#Status: works
|
||||||
|
def _scan_result_to_get_result(self, result) -> GetResult:
|
||||||
|
if not result:
|
||||||
|
return None
|
||||||
|
ids = []
|
||||||
|
documents = []
|
||||||
|
metadatas = []
|
||||||
|
|
||||||
|
for hit in result:
|
||||||
|
ids.append(hit["_id"])
|
||||||
|
documents.append(hit["_source"].get("text"))
|
||||||
|
metadatas.append(hit["_source"].get("metadata"))
|
||||||
|
|
||||||
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||||
|
|
||||||
|
#Status: works
|
||||||
|
def _result_to_get_result(self, result) -> GetResult:
|
||||||
|
if not result["hits"]["hits"]:
|
||||||
|
return None
|
||||||
|
ids = []
|
||||||
|
documents = []
|
||||||
|
metadatas = []
|
||||||
|
|
||||||
|
for hit in result["hits"]["hits"]:
|
||||||
|
ids.append(hit["_id"])
|
||||||
|
documents.append(hit["_source"].get("text"))
|
||||||
|
metadatas.append(hit["_source"].get("metadata"))
|
||||||
|
|
||||||
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||||
|
|
||||||
|
#Status: works
|
||||||
|
def _result_to_search_result(self, result) -> SearchResult:
|
||||||
|
ids = []
|
||||||
|
distances = []
|
||||||
|
documents = []
|
||||||
|
metadatas = []
|
||||||
|
|
||||||
|
for hit in result["hits"]["hits"]:
|
||||||
|
ids.append(hit["_id"])
|
||||||
|
distances.append(hit["_score"])
|
||||||
|
documents.append(hit["_source"].get("text"))
|
||||||
|
metadatas.append(hit["_source"].get("metadata"))
|
||||||
|
|
||||||
|
return SearchResult(
|
||||||
|
ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
|
||||||
|
)
|
||||||
|
#Status: works
|
||||||
|
def _create_index(self, dimension: int):
|
||||||
|
body = {
|
||||||
|
"mappings": {
|
||||||
|
"properties": {
|
||||||
|
"collection": {"type": "keyword"},
|
||||||
|
"id": {"type": "keyword"},
|
||||||
|
"vector": {
|
||||||
|
"type": "dense_vector",
|
||||||
|
"dims": dimension, # Adjust based on your vector dimensions
|
||||||
|
"index": True,
|
||||||
|
"similarity": "cosine",
|
||||||
|
},
|
||||||
|
"text": {"type": "text"},
|
||||||
|
"metadata": {"type": "object"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.client.indices.create(index=self._get_index_name(dimension), body=body)
|
||||||
|
#Status: works
|
||||||
|
|
||||||
|
def _create_batches(self, items: list[VectorItem], batch_size=100):
|
||||||
|
for i in range(0, len(items), batch_size):
|
||||||
|
yield items[i : min(i + batch_size,len(items))]
|
||||||
|
|
||||||
|
#Status: works
|
||||||
|
def has_collection(self,collection_name) -> bool:
|
||||||
|
query_body = {"query": {"bool": {"filter": []}}}
|
||||||
|
query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.client.count(
|
||||||
|
index=f"{self.index_prefix}*",
|
||||||
|
body=query_body
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.body["count"]>0
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#@TODO: Make this delete a collection and not an index
|
||||||
|
def delete_colleciton(self, collection_name: str):
|
||||||
|
# TODO: fix this to include the dimension or a * prefix
|
||||||
|
# delete_collection here means delete a bunch of documents for an index.
|
||||||
|
# We are simply adapting to the norms of the other DBs.
|
||||||
|
self.client.indices.delete(index=self._get_collection_name(collection_name))
|
||||||
|
#Status: works
|
||||||
|
def search(
|
||||||
|
self, collection_name: str, vectors: list[list[float]], limit: int
|
||||||
|
) -> Optional[SearchResult]:
|
||||||
|
query = {
|
||||||
|
"size": limit,
|
||||||
|
"_source": [
|
||||||
|
"text",
|
||||||
|
"metadata"
|
||||||
|
],
|
||||||
|
"query": {
|
||||||
|
"script_score": {
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"filter": [
|
||||||
|
{
|
||||||
|
"term": {
|
||||||
|
"collection": collection_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"script": {
|
||||||
|
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||||
|
"params": {
|
||||||
|
"vector": vectors[0]
|
||||||
|
}, # Assuming single query vector
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = self.client.search(
|
||||||
|
index=self._get_index_name(len(vectors[0])), body=query
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._result_to_search_result(result)
|
||||||
|
#Status: only tested halfwat
|
||||||
|
def query(
|
||||||
|
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||||
|
) -> Optional[GetResult]:
|
||||||
|
if not self.has_collection(collection_name):
|
||||||
|
return None
|
||||||
|
|
||||||
|
query_body = {
|
||||||
|
"query": {"bool": {"filter": []}},
|
||||||
|
"_source": ["text", "metadata"],
|
||||||
|
}
|
||||||
|
|
||||||
|
for field, value in filter.items():
|
||||||
|
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
|
||||||
|
query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
|
||||||
|
size = limit if limit else 10
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.client.search(
|
||||||
|
index=f"{self.index_prefix}*",
|
||||||
|
body=query_body,
|
||||||
|
size=size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._result_to_get_result(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
#Status: works
|
||||||
|
def _has_index(self,dimension:int):
|
||||||
|
return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_index(self, dimension: int):
|
||||||
|
if not self._has_index(dimension=dimension):
|
||||||
|
self._create_index(dimension=dimension)
|
||||||
|
#Status: works
|
||||||
|
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||||
|
# Get all the items in the collection.
|
||||||
|
query = {
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"filter": [
|
||||||
|
{
|
||||||
|
"term": {
|
||||||
|
"collection": collection_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}, "_source": ["text", "metadata"]}
|
||||||
|
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
|
||||||
|
|
||||||
|
return self._scan_result_to_get_result(results)
|
||||||
|
|
||||||
|
#Status: works
|
||||||
|
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||||
|
if not self._has_index(dimension=len(items[0]["vector"])):
|
||||||
|
self._create_index(dimension=len(items[0]["vector"]))
|
||||||
|
|
||||||
|
|
||||||
|
for batch in self._create_batches(items):
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"_index":self._get_index_name(dimension=len(items[0]["vector"])),
|
||||||
|
"_id": item["id"],
|
||||||
|
"_source": {
|
||||||
|
"collection": collection_name,
|
||||||
|
"vector": item["vector"],
|
||||||
|
"text": item["text"],
|
||||||
|
"metadata": item["metadata"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for item in batch
|
||||||
|
]
|
||||||
|
bulk(self.client,actions)
|
||||||
|
# Status: should work
|
||||||
|
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||||
|
if not self._has_index(dimension=len(items[0]["vector"])):
|
||||||
|
self._create_index(collection_name, dimension=len(items[0]["vector"]))
|
||||||
|
|
||||||
|
for batch in self._create_batches(items):
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
|
||||||
|
"_id": item["id"],
|
||||||
|
"_source": {
|
||||||
|
"vector": item["vector"],
|
||||||
|
"text": item["text"],
|
||||||
|
"metadata": item["metadata"],
|
||||||
|
},
|
||||||
|
|
||||||
|
}
|
||||||
|
for item in batch
|
||||||
|
]
|
||||||
|
self.client.bulk(actions)
|
||||||
|
|
||||||
|
#TODO: This currently deletes by * which is not always supported in ElasticSearch.
|
||||||
|
# Need to read a bit before changing. Also, need to delete from a specific collection
|
||||||
|
def delete(self, collection_name: str, ids: list[str]):
|
||||||
|
#Assuming ID is unique across collections and indexes
|
||||||
|
actions = [
|
||||||
|
{"delete": {"_index": f"{self.index_prefix}*", "_id": id}}
|
||||||
|
for id in ids
|
||||||
|
]
|
||||||
|
self.client.bulk(body=actions)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
indices = self.client.indices.get(index=f"{self.index_prefix}*")
|
||||||
|
for index in indices:
|
||||||
|
self.client.indices.delete(index=index)
|
@ -49,6 +49,8 @@ pymilvus==2.5.0
|
|||||||
qdrant-client~=1.12.0
|
qdrant-client~=1.12.0
|
||||||
opensearch-py==2.8.0
|
opensearch-py==2.8.0
|
||||||
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
|
playwright==1.49.1 # Caution: version must match docker-compose.playwright.yaml
|
||||||
|
elasticsearch==8.17.1
|
||||||
|
|
||||||
|
|
||||||
transformers
|
transformers
|
||||||
sentence-transformers==3.3.1
|
sentence-transformers==3.3.1
|
||||||
|
Loading…
Reference in New Issue
Block a user