chore: format

This commit is contained in:
Timothy Jaeryang Baek
2025-03-04 00:32:27 -08:00
parent 8697f72068
commit 39ea59edc8
56 changed files with 449 additions and 112 deletions

View File

@@ -659,11 +659,7 @@ if CUSTOM_NAME:
# LICENSE_KEY
####################################
LICENSE_KEY = PersistentConfig(
"LICENSE_KEY",
"license.key",
os.environ.get("LICENSE_KEY", ""),
)
LICENSE_KEY = os.environ.get("LICENSE_KEY", "")
####################################
# STORAGE PROVIDER

View File

@@ -401,8 +401,8 @@ async def lifespan(app: FastAPI):
if RESET_CONFIG_ON_START:
reset_config()
if app.state.config.LICENSE_KEY:
get_license_data(app, app.state.config.LICENSE_KEY)
if LICENSE_KEY:
get_license_data(app, LICENSE_KEY)
asyncio.create_task(periodic_usage_pool_cleanup())
yield
@@ -420,7 +420,7 @@ oauth_manager = OAuthManager(app)
app.state.config = AppConfig()
app.state.WEBUI_NAME = WEBUI_NAME
app.state.config.LICENSE_KEY = LICENSE_KEY
app.state.LICENSE_DATA = None
########################################
#
@@ -1218,6 +1218,7 @@ async def get_app_config(request: Request):
{
"record_count": user_count,
"active_entries": app.state.USER_COUNT,
"license_data": app.state.LICENSE_DATA,
}
if user.role == "admin"
else {}

View File

@@ -17,7 +17,7 @@ elif VECTOR_DB == "pgvector":
VECTOR_DB_CLIENT = PgvectorClient()
elif VECTOR_DB == "elasticsearch":
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient
VECTOR_DB_CLIENT = ElasticsearchClient()
else:

View File

@@ -1,28 +1,27 @@
from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk,scan
from elasticsearch.helpers import bulk, scan
from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
ELASTICSEARCH_URL,
ELASTICSEARCH_CA_CERTS,
ELASTICSEARCH_CA_CERTS,
ELASTICSEARCH_API_KEY,
ELASTICSEARCH_USERNAME,
ELASTICSEARCH_PASSWORD,
ELASTICSEARCH_PASSWORD,
ELASTICSEARCH_CLOUD_ID,
SSL_ASSERT_FINGERPRINT
SSL_ASSERT_FINGERPRINT,
)
class ElasticsearchClient:
"""
Important:
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
an index for each file but store it as a text field, while seperating to different index
in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
an index for each file but store it as a text field, while seperating to different index
baesd on the embedding length.
"""
def __init__(self):
self.index_prefix = "open_webui_collections"
self.client = Elasticsearch(
@@ -30,15 +29,19 @@ class ElasticsearchClient:
ca_certs=ELASTICSEARCH_CA_CERTS,
api_key=ELASTICSEARCH_API_KEY,
cloud_id=ELASTICSEARCH_CLOUD_ID,
basic_auth=(ELASTICSEARCH_USERNAME,ELASTICSEARCH_PASSWORD) if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD else None,
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT
basic_auth=(
(ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
else None
),
ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
)
#Status: works
def _get_index_name(self,dimension:int)->str:
# Status: works
def _get_index_name(self, dimension: int) -> str:
return f"{self.index_prefix}_d{str(dimension)}"
#Status: works
# Status: works
def _scan_result_to_get_result(self, result) -> GetResult:
if not result:
return None
@@ -53,7 +56,7 @@ class ElasticsearchClient:
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
#Status: works
# Status: works
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
return None
@@ -68,7 +71,7 @@ class ElasticsearchClient:
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
#Status: works
# Status: works
def _result_to_search_result(self, result) -> SearchResult:
ids = []
distances = []
@@ -82,9 +85,13 @@ class ElasticsearchClient:
metadatas.append(hit["_source"].get("metadata"))
return SearchResult(
ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
ids=[ids],
distances=[distances],
documents=[documents],
metadatas=[metadatas],
)
#Status: works
# Status: works
def _create_index(self, dimension: int):
body = {
"mappings": {
@@ -103,63 +110,51 @@ class ElasticsearchClient:
}
}
self.client.indices.create(index=self._get_index_name(dimension), body=body)
#Status: works
# Status: works
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
yield items[i : min(i + batch_size,len(items))]
yield items[i : min(i + batch_size, len(items))]
#Status: works
def has_collection(self,collection_name) -> bool:
# Status: works
def has_collection(self, collection_name) -> bool:
query_body = {"query": {"bool": {"filter": []}}}
query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
try:
result = self.client.count(
index=f"{self.index_prefix}*",
body=query_body
)
return result.body["count"]>0
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
return result.body["count"] > 0
except Exception as e:
return None
#@TODO: Make this delete a collection and not an index
# @TODO: Make this delete a collection and not an index
def delete_colleciton(self, collection_name: str):
# TODO: fix this to include the dimension or a * prefix
# delete_collection here means delete a bunch of documents for an index.
# We are simply adapting to the norms of the other DBs.
self.client.indices.delete(index=self._get_collection_name(collection_name))
#Status: works
# Status: works
def search(
self, collection_name: str, vectors: list[list[float]], limit: int
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": [
"text",
"metadata"
],
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {
"bool": {
"filter": [
{
"term": {
"collection": collection_name
}
}
]
}
"bool": {"filter": [{"term": {"collection": collection_name}}]}
},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
}, # Assuming single query vector
},
}
},
@@ -170,7 +165,8 @@ class ElasticsearchClient:
)
return self._result_to_search_result(result)
#Status: only tested halfwat
# Status: only tested halfwat
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
@@ -184,7 +180,9 @@ class ElasticsearchClient:
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
size = limit if limit else 10
try:
@@ -193,59 +191,54 @@ class ElasticsearchClient:
body=query_body,
size=size,
)
return self._result_to_get_result(result)
except Exception as e:
return None
#Status: works
def _has_index(self,dimension:int):
return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
# Status: works
def _has_index(self, dimension: int):
return self.client.indices.exists(
index=self._get_index_name(dimension=dimension)
)
def get_or_create_index(self, dimension: int):
if not self._has_index(dimension=dimension):
self._create_index(dimension=dimension)
#Status: works
# Status: works
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
query = {
"query": {
"bool": {
"filter": [
{
"term": {
"collection": collection_name
}
}
]
}
}, "_source": ["text", "metadata"]}
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
"_source": ["text", "metadata"],
}
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
return self._scan_result_to_get_result(results)
#Status: works
# Status: works
def insert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
for batch in self._create_batches(items):
actions = [
{
"_index":self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
for item in batch
]
bulk(self.client,actions)
bulk(self.client, actions)
# Status: should work
def upsert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
@@ -254,26 +247,24 @@ class ElasticsearchClient:
for batch in self._create_batches(items):
actions = [
{
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
},
}
for item in batch
]
self.client.bulk(actions)
#TODO: This currently deletes by * which is not always supported in ElasticSearch.
# TODO: This currently deletes by * which is not always supported in ElasticSearch.
# Need to read a bit before changing. Also, need to delete from a specific collection
def delete(self, collection_name: str, ids: list[str]):
#Assuming ID is unique across collections and indexes
# Assuming ID is unique across collections and indexes
actions = [
{"delete": {"_index": f"{self.index_prefix}*", "_id": id}}
for id in ids
{"delete": {"_index": f"{self.index_prefix}*", "_id": id}} for id in ids
]
self.client.bulk(body=actions)

View File

@@ -72,7 +72,9 @@ class OpenSearchClient:
}
}
}
self.client.indices.create(index=f"{self.index_prefix}_{collection_name}", body=body)
self.client.indices.create(
index=f"{self.index_prefix}_{collection_name}", body=body
)
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
@@ -81,7 +83,9 @@ class OpenSearchClient:
def has_collection(self, collection_name: str) -> bool:
# has_collection here means has index.
# We are simply adapting to the norms of the other DBs.
return self.client.indices.exists(index=f"{self.index_prefix}_{collection_name}")
return self.client.indices.exists(
index=f"{self.index_prefix}_{collection_name}"
)
def delete_colleciton(self, collection_name: str):
# delete_collection here means delete index.
@@ -154,8 +158,9 @@ class OpenSearchClient:
return self._result_to_get_result(result)
def insert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(collection_name=collection_name,
dimension=len(items[0]["vector"]))
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
for batch in self._create_batches(items):
actions = [
@@ -174,8 +179,9 @@ class OpenSearchClient:
self.client.bulk(actions)
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(collection_name=collection_name,
dimension=len(items[0]["vector"]))
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
for batch in self._create_batches(items):
actions = [

View File

@@ -71,8 +71,9 @@ def override_static(path: str, content: str):
def get_license_data(app, key):
if key:
try:
# https://api.openwebui.com
res = requests.post(
"https://api.openwebui.com/api/v1/license",
"http://localhost:5555/api/v1/license",
json={"key": key, "version": "1"},
timeout=5,
)
@@ -83,11 +84,12 @@ def get_license_data(app, key):
if k == "resources":
for p, c in v.items():
globals().get("override_static", lambda a, b: None)(p, c)
elif k == "user_count":
elif k == "count":
setattr(app.state, "USER_COUNT", v)
elif k == "webui_name":
elif k == "name":
setattr(app.state, "WEBUI_NAME", v)
elif k == "info":
setattr(app.state, "LICENSE_INFO", v)
return True
else:
log.error(