mirror of
https://github.com/open-webui/open-webui
synced 2025-04-25 08:48:21 +00:00
chore: format
This commit is contained in:
parent
768b7e139c
commit
9936583477
@ -110,8 +110,7 @@ class ChromaClient:
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
@ -131,8 +130,7 @@ class ChromaClient:
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
|
@ -9,6 +9,7 @@ from open_webui.config import QDRANT_URI
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
|
||||
|
||||
class QdrantClient:
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
@ -38,15 +39,15 @@ class QdrantClient:
|
||||
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name_with_prefix,
|
||||
vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
|
||||
vectors_config=models.VectorParams(
|
||||
size=dimension, distance=models.Distance.COSINE
|
||||
),
|
||||
)
|
||||
|
||||
print(f"collection {collection_name_with_prefix} successfully created!")
|
||||
|
||||
def _create_collection_if_not_exists(self, collection_name, dimension):
|
||||
if not self.has_collection(
|
||||
collection_name=collection_name
|
||||
):
|
||||
if not self.has_collection(collection_name=collection_name):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=dimension
|
||||
)
|
||||
@ -56,19 +57,20 @@ class QdrantClient:
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=item["vector"],
|
||||
payload={
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"]
|
||||
},
|
||||
payload={"text": item["text"], "metadata": item["metadata"]},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}")
|
||||
return self.client.collection_exists(
|
||||
f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}")
|
||||
return self.client.delete_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self, collection_name: str, vectors: list[list[float | int]], limit: int
|
||||
@ -87,7 +89,7 @@ class QdrantClient:
|
||||
ids=get_result.ids,
|
||||
documents=get_result.documents,
|
||||
metadatas=get_result.metadatas,
|
||||
distances=[[point.score for point in query_response.points]]
|
||||
distances=[[point.score for point in query_response.points]],
|
||||
)
|
||||
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
|
||||
@ -101,7 +103,10 @@ class QdrantClient:
|
||||
field_conditions = []
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value)))
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}", match=models.MatchValue(value=value)
|
||||
)
|
||||
)
|
||||
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
@ -117,7 +122,7 @@ class QdrantClient:
|
||||
# Get all the items in the collection.
|
||||
points = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
limit=NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
|
||||
)
|
||||
return self._result_to_get_result(points.points)
|
||||
|
||||
@ -162,9 +167,7 @@ class QdrantClient:
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(
|
||||
must=field_conditions
|
||||
)
|
||||
filter=models.Filter(must=field_conditions)
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -409,7 +409,10 @@ OAUTH_ROLES_CLAIM = PersistentConfig(
|
||||
OAUTH_ALLOWED_ROLES = PersistentConfig(
|
||||
"OAUTH_ALLOWED_ROLES",
|
||||
"oauth.allowed_roles",
|
||||
[role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")],
|
||||
[
|
||||
role.strip()
|
||||
for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")
|
||||
],
|
||||
)
|
||||
|
||||
OAUTH_ADMIN_ROLES = PersistentConfig(
|
||||
@ -418,6 +421,7 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
|
||||
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
|
||||
)
|
||||
|
||||
|
||||
def load_oauth_providers():
|
||||
OAUTH_PROVIDERS.clear()
|
||||
if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
|
||||
|
@ -208,8 +208,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
app.state.MODELS = {}
|
||||
|
||||
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# ChatCompletion Middleware
|
||||
|
@ -25,7 +25,10 @@ from open_webui.config import (
|
||||
OAUTH_PICTURE_CLAIM,
|
||||
OAUTH_USERNAME_CLAIM,
|
||||
OAUTH_ALLOWED_ROLES,
|
||||
OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig,
|
||||
OAUTH_ADMIN_ROLES,
|
||||
WEBHOOK_URL,
|
||||
JWT_EXPIRES_IN,
|
||||
AppConfig,
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE
|
||||
@ -170,7 +173,9 @@ class OAuthManager:
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP.value:
|
||||
# Check if an existing user with the same email already exists
|
||||
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
|
||||
existing_user = Users.get_user_by_email(
|
||||
user_data.get("email", "").lower()
|
||||
)
|
||||
if existing_user:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||
|
||||
@ -182,16 +187,18 @@ class OAuthManager:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(picture_url) as resp:
|
||||
picture = await resp.read()
|
||||
base64_encoded_picture = base64.b64encode(picture).decode(
|
||||
"utf-8"
|
||||
)
|
||||
base64_encoded_picture = base64.b64encode(
|
||||
picture
|
||||
).decode("utf-8")
|
||||
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
|
||||
if guessed_mime_type is None:
|
||||
# assume JPG, browsers are tolerant enough of image formats
|
||||
guessed_mime_type = "image/jpeg"
|
||||
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
|
||||
except Exception as e:
|
||||
log.error(f"Error downloading profile image '{picture_url}': {e}")
|
||||
log.error(
|
||||
f"Error downloading profile image '{picture_url}': {e}"
|
||||
)
|
||||
picture_url = ""
|
||||
if not picture_url:
|
||||
picture_url = "/user.png"
|
||||
@ -216,7 +223,9 @@ class OAuthManager:
|
||||
auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
{
|
||||
"action": "signup",
|
||||
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
|
||||
"message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(
|
||||
user.name
|
||||
),
|
||||
"user": user.model_dump_json(exclude_none=True),
|
||||
},
|
||||
)
|
||||
@ -243,4 +252,5 @@ class OAuthManager:
|
||||
redirect_url = f"{request.base_url}auth#token={jwt_token}"
|
||||
return RedirectResponse(url=redirect_url)
|
||||
|
||||
|
||||
oauth_manager = OAuthManager()
|
Loading…
Reference in New Issue
Block a user