diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index f835e3175..cb38a53eb 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -547,7 +547,7 @@ class GenerateEmbeddingsForm(BaseModel): class GenerateEmbedForm(BaseModel): model: str - input: list[str]|str + input: list[str] | str truncate: Optional[bool] = None options: Optional[dict] = None keep_alive: Optional[Union[int, str]] = None diff --git a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py index c6d95bd52..7782671a2 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/chroma.py @@ -110,9 +110,8 @@ class ChromaClient: def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. collection = self.client.get_or_create_collection( - name=collection_name, - metadata={"hnsw:space": "cosine"} - ) + name=collection_name, metadata={"hnsw:space": "cosine"} + ) ids = [item["id"] for item in items] documents = [item["text"] for item in items] @@ -131,9 +130,8 @@ class ChromaClient: def upsert(self, collection_name: str, items: list[VectorItem]): # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. collection = self.client.get_or_create_collection( - name=collection_name, - metadata={"hnsw:space": "cosine"} - ) + name=collection_name, metadata={"hnsw:space": "cosine"} + ) ids = [item["id"] for item in items] documents = [item["text"] for item in items] diff --git a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py index 70908dc33..c1e06872f 100644 --- a/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/apps/retrieval/vector/dbs/qdrant.py @@ -9,6 +9,7 @@ from open_webui.config import QDRANT_URI NO_LIMIT = 999999999 + class QdrantClient: def __init__(self): self.collection_prefix = "open-webui" @@ -38,15 +39,15 @@ class QdrantClient: collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}" self.client.create_collection( collection_name=collection_name_with_prefix, - vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE), + vectors_config=models.VectorParams( + size=dimension, distance=models.Distance.COSINE + ), ) print(f"collection {collection_name_with_prefix} successfully created!") def _create_collection_if_not_exists(self, collection_name, dimension): - if not self.has_collection( - collection_name=collection_name - ): + if not self.has_collection(collection_name=collection_name): self._create_collection( collection_name=collection_name, dimension=dimension ) @@ -56,22 +57,23 @@ class QdrantClient: PointStruct( id=item["id"], vector=item["vector"], - payload={ - "text": item["text"], - "metadata": item["metadata"] - }, + payload={"text": item["text"], "metadata": item["metadata"]}, ) for item in items ] def has_collection(self, collection_name: str) -> bool: - return self.client.collection_exists(f"{self.collection_prefix}_{collection_name}") + return self.client.collection_exists( + f"{self.collection_prefix}_{collection_name}" + ) def delete_collection(self, collection_name: str): - return self.client.delete_collection(collection_name=f"{self.collection_prefix}_{collection_name}") + return self.client.delete_collection( + collection_name=f"{self.collection_prefix}_{collection_name}" + ) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, collection_name: str, vectors: list[list[float | int]], limit: int ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. if limit is None: @@ -87,7 +89,7 @@ class QdrantClient: ids=get_result.ids, documents=get_result.documents, metadatas=get_result.metadatas, - distances=[[point.score for point in query_response.points]] + distances=[[point.score for point in query_response.points]], ) def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): @@ -101,7 +103,10 @@ class QdrantClient: field_conditions = [] for key, value in filter.items(): field_conditions.append( - models.FieldCondition(key=f"metadata.{key}", match=models.MatchValue(value=value))) + models.FieldCondition( + key=f"metadata.{key}", match=models.MatchValue(value=value) + ) + ) points = self.client.query_points( collection_name=f"{self.collection_prefix}_{collection_name}", @@ -117,7 +122,7 @@ class QdrantClient: # Get all the items in the collection. points = self.client.query_points( collection_name=f"{self.collection_prefix}_{collection_name}", - limit=NO_LIMIT # otherwise qdrant would set limit to 10! + limit=NO_LIMIT, # otherwise qdrant would set limit to 10! ) return self._result_to_get_result(points.points) @@ -134,10 +139,10 @@ class QdrantClient: return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points) def delete( - self, - collection_name: str, - ids: Optional[list[str]] = None, - filter: Optional[dict] = None, + self, + collection_name: str, + ids: Optional[list[str]] = None, + filter: Optional[dict] = None, ): # Delete the items from the collection based on the ids. field_conditions = [] @@ -162,9 +167,7 @@ class QdrantClient: return self.client.delete( collection_name=f"{self.collection_prefix}_{collection_name}", points_selector=models.FilterSelector( - filter=models.Filter( - must=field_conditions - ) + filter=models.Filter(must=field_conditions) ), ) diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index ab8424f59..496f2395f 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -409,7 +409,10 @@ OAUTH_ROLES_CLAIM = PersistentConfig( OAUTH_ALLOWED_ROLES = PersistentConfig( "OAUTH_ALLOWED_ROLES", "oauth.allowed_roles", - [role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",")], + [ + role.strip() + for role in os.environ.get("OAUTH_ALLOWED_ROLES", "user,admin").split(",") + ], ) OAUTH_ADMIN_ROLES = PersistentConfig( @@ -418,6 +421,7 @@ OAUTH_ADMIN_ROLES = PersistentConfig( [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], ) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value: diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 57c89579e..5b3ca7e64 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -208,8 +208,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( app.state.MODELS = {} - - ################################## # # ChatCompletion Middleware @@ -223,14 +221,14 @@ def get_task_model_id(default_model_id): # Check if the user has a custom task model and use that model if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if ( - app.state.config.TASK_MODEL - and app.state.config.TASK_MODEL in app.state.MODELS + app.state.config.TASK_MODEL + and app.state.config.TASK_MODEL in app.state.MODELS ): task_model_id = app.state.config.TASK_MODEL else: if ( - app.state.config.TASK_MODEL_EXTERNAL - and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS + app.state.config.TASK_MODEL_EXTERNAL + and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL @@ -367,7 +365,7 @@ async def get_content_from_response(response) -> Optional[str]: async def chat_completion_tools_handler( - body: dict, user: UserModel, extra_params: dict + body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) @@ -681,15 +679,15 @@ def get_sorted_filters(model_id): model for model in app.state.MODELS.values() if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) ] sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) return sorted_filters @@ -875,8 +873,8 @@ async def update_embedding_function(request: Request, call_next): @app.middleware("http") async def inspect_websocket(request: Request, call_next): if ( - "/ws/socket.io" in request.url.path - and request.query_params.get("transport") == "websocket" + "/ws/socket.io" in request.url.path + and request.query_params.get("transport") == "websocket" ): upgrade = (request.headers.get("Upgrade") or "").lower() connection = (request.headers.get("Connection") or "").lower().split(",") @@ -945,8 +943,8 @@ async def get_all_models(): if custom_model.base_model_id is None: for model in models: if ( - custom_model.id == model["id"] - or custom_model.id == model["id"].split(":")[0] + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] ): model["name"] = custom_model.name model["info"] = custom_model.model_dump() @@ -963,8 +961,8 @@ async def get_all_models(): for model in models: if ( - custom_model.base_model_id == model["id"] - or custom_model.base_model_id == model["id"].split(":")[0] + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] ): owned_by = model["owned_by"] if "pipe" in model: @@ -1840,7 +1838,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)): @app.post("/api/pipelines/upload") async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) + urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file @@ -2017,9 +2015,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use @app.get("/api/pipelines/{pipeline_id}/valves") async def get_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), ): r = None try: @@ -2055,9 +2053,9 @@ async def get_pipeline_valves( @app.get("/api/pipelines/{pipeline_id}/valves/spec") async def get_pipeline_valves_spec( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), ): r = None try: @@ -2092,10 +2090,10 @@ async def get_pipeline_valves_spec( @app.post("/api/pipelines/{pipeline_id}/valves/update") async def update_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - form_data: dict, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + form_data: dict, + user=Depends(get_admin_user), ): r = None try: @@ -2219,7 +2217,7 @@ class ModelFilterConfigForm(BaseModel): @app.post("/api/config/model/filter") async def update_model_filter_config( - form_data: ModelFilterConfigForm, user=Depends(get_admin_user) + form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.MODEL_FILTER_LIST = form_data.models @@ -2274,7 +2272,7 @@ async def get_app_latest_release_version(): timeout = aiohttp.ClientTimeout(total=1) async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - "https://api.github.com/repos/open-webui/open-webui/releases/latest" + "https://api.github.com/repos/open-webui/open-webui/releases/latest" ) as response: response.raise_for_status() data = await response.json() diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index dc3130031..d59e36733 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -25,7 +25,10 @@ from open_webui.config import ( OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, OAUTH_ALLOWED_ROLES, - OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig, + OAUTH_ADMIN_ROLES, + WEBHOOK_URL, + JWT_EXPIRES_IN, + AppConfig, ) from open_webui.constants import ERROR_MESSAGES from open_webui.env import WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE @@ -170,7 +173,9 @@ class OAuthManager: # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP.value: # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) + existing_user = Users.get_user_by_email( + user_data.get("email", "").lower() + ) if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) @@ -182,16 +187,18 @@ class OAuthManager: async with aiohttp.ClientSession() as session: async with session.get(picture_url) as resp: picture = await resp.read() - base64_encoded_picture = base64.b64encode(picture).decode( - "utf-8" - ) + base64_encoded_picture = base64.b64encode( + picture + ).decode("utf-8") guessed_mime_type = mimetypes.guess_type(picture_url)[0] if guessed_mime_type is None: # assume JPG, browsers are tolerant enough of image formats guessed_mime_type = "image/jpeg" picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" except Exception as e: - log.error(f"Error downloading profile image '{picture_url}': {e}") + log.error( + f"Error downloading profile image '{picture_url}': {e}" + ) picture_url = "" if not picture_url: picture_url = "/user.png" @@ -216,7 +223,9 @@ class OAuthManager: auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), { "action": "signup", - "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP( + user.name + ), "user": user.model_dump_json(exclude_none=True), }, ) @@ -243,4 +252,5 @@ class OAuthManager: redirect_url = f"{request.base_url}auth#token={jwt_token}" return RedirectResponse(url=redirect_url) -oauth_manager = OAuthManager() \ No newline at end of file + +oauth_manager = OAuthManager()