From 63ba8145b98b4fdf2346b72c019f455b26ea3a73 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 22 Aug 2024 16:02:29 +0200 Subject: [PATCH] refac --- backend/apps/webui/routers/files.py | 14 +++---- backend/main.py | 63 +++++++++++++++-------------- 2 files changed, 40 insertions(+), 37 deletions(-) diff --git a/backend/apps/webui/routers/files.py b/backend/apps/webui/routers/files.py index ba571fc71..ae01df1ae 100644 --- a/backend/apps/webui/routers/files.py +++ b/backend/apps/webui/routers/files.py @@ -26,7 +26,7 @@ from apps.webui.models.files import ( FileModel, FileModelResponse, ) -from utils.utils import get_verified_user, get_admin_user +from utils.utils import get_current_user, get_admin_user from constants import ERROR_MESSAGES from importlib import util @@ -50,7 +50,7 @@ router = APIRouter() @router.post("/") -def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): +def upload_file(file: UploadFile = File(...), user=Depends(get_current_user)): log.info(f"file.content_type: {file.content_type}") try: unsanitized_filename = file.filename @@ -105,7 +105,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)): @router.get("/", response_model=list[FileModel]) -async def list_files(user=Depends(get_verified_user)): +async def list_files(user=Depends(get_current_user)): files = Files.get_files() return files @@ -153,7 +153,7 @@ async def delete_all_files(user=Depends(get_admin_user)): @router.get("/{id}", response_model=Optional[FileModel]) -async def get_file_by_id(id: str, user=Depends(get_verified_user)): +async def get_file_by_id(id: str, user=Depends(get_current_user)): file = Files.get_file_by_id(id) if file: @@ -171,7 +171,7 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/content", response_model=Optional[FileModel]) -async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): +async def get_file_content_by_id(id: str, user=Depends(get_current_user)): file = Files.get_file_by_id(id) if file: @@ -194,7 +194,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): @router.get("/{id}/content/{file_name}", response_model=Optional[FileModel]) -async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): +async def get_file_content_by_id(id: str, user=Depends(get_current_user)): file = Files.get_file_by_id(id) if file: @@ -222,7 +222,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): @router.delete("/{id}") -async def delete_file_by_id(id: str, user=Depends(get_verified_user)): +async def delete_file_by_id(id: str, user=Depends(get_current_user)): file = Files.get_file_by_id(id) if file: diff --git a/backend/main.py b/backend/main.py index c59f631b3..ab557bea1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -299,24 +299,26 @@ async def chat_completion_filter_functions_handler(body, model, extra_params): # Get the signature of the function sig = inspect.signature(inlet) - params = {"body": body} + params = {"body": body} | { + k: v + for k, v in { + **extra_params, + "__model__": model, + "__id__": filter_id, + }.items() + if k in sig.parameters + } - # Extra parameters to be passed to the function - custom_params = {**extra_params, "__model__": model, "__id__": filter_id} - if hasattr(function_module, "UserValves") and "__user__" in sig.parameters: + if "__user__" in params and hasattr(function_module, "UserValves"): try: - uid = custom_params["__user__"]["id"] - custom_params["__user__"]["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id(filter_id, uid) + params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id( + filter_id, params["__user__"]["id"] + ) ) except Exception as e: print(e) - # Add extra params in contained in function signature - for key, value in custom_params.items(): - if key in sig.parameters: - params[key] = value - if inspect.iscoroutinefunction(inlet): body = await inlet(**params) else: @@ -372,7 +374,9 @@ async def chat_completion_tools_handler( ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) + tool_ids = metadata.get("tool_ids", None) + log.debug(f"{tool_ids=}") if not tool_ids: return body, {} @@ -381,16 +385,17 @@ async def chat_completion_tools_handler( citations = [] task_model_id = get_task_model_id(body["model"]) - - log.debug(f"{tool_ids=}") - - custom_params = { - **extra_params, - "__model__": app.state.MODELS[task_model_id], - "__messages__": body["messages"], - "__files__": metadata.get("files", []), - } - tools = get_tools(webui_app, tool_ids, user, custom_params) + tools = get_tools( + webui_app, + tool_ids, + user, + { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__messages__": body["messages"], + "__files__": metadata.get("files", []), + }, + ) log.info(f"{tools=}") specs = [tool["spec"] for tool in tools.values()] @@ -530,17 +535,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): } body["metadata"] = metadata - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - extra_params = { - "__user__": __user__, "__event_emitter__": get_event_emitter(metadata), "__event_call__": get_event_call(metadata), + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, } # Initialize data_items to store additional data to be sent to the client