From ff9d899f9c5d8318cacc17c23b1df64a64213eea Mon Sep 17 00:00:00 2001 From: Michael Poluektov Date: Sun, 11 Aug 2024 08:31:40 +0100 Subject: [PATCH] fix more LSP errors --- backend/main.py | 130 +++++++++++++++++++----------------------------- 1 file changed, 51 insertions(+), 79 deletions(-) diff --git a/backend/main.py b/backend/main.py index 512f3d006..0aa6bf167 100644 --- a/backend/main.py +++ b/backend/main.py @@ -261,6 +261,7 @@ def get_filter_function_ids(model): def get_priority(function_id): function = Functions.get_function_by_id(function_id) if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel return (function.valves if function.valves else {}).get("priority", 0) return 0 @@ -322,14 +323,7 @@ async def call_tool_from_completion( async def get_function_call_response( - messages, - files, - tool_id, - template, - task_model_id, - user, - __event_emitter__=None, - __event_call__=None, + messages, files, tool_id, template, task_model_id, user, extra_params ) -> tuple[Optional[str], Optional[dict], bool]: tool = Tools.get_tool_by_id(tool_id) if tool is None: @@ -373,32 +367,22 @@ async def get_function_call_response( toolkit_module, _ = load_toolkit_module_by_id(tool_id) webui_app.state.TOOLS[tool_id] = toolkit_module - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, + custom_params = { + **extra_params, + "__model__": app.state.MODELS[task_model_id], + "__id__": tool_id, + "__messages__": messages, + "__files__": files, } - try: if hasattr(toolkit_module, "UserValves"): - __user__["valves"] = toolkit_module.UserValves( + custom_params["__user__"]["valves"] = toolkit_module.UserValves( **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id) ) except Exception as e: print(e) - extra_params = { - "__model__": app.state.MODELS[task_model_id], - "__id__": tool_id, - "__messages__": messages, - "__files__": files, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - "__user__": __user__, - } - file_handler = hasattr(toolkit_module, "file_handler") if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"): @@ -417,7 +401,7 @@ async def get_function_call_response( result = json.loads(content) function_result = await call_tool_from_completion( - result, extra_params, toolkit_module + result, custom_params, toolkit_module ) if hasattr(toolkit_module, "citation") and toolkit_module.citation: @@ -438,9 +422,7 @@ async def get_function_call_response( return None, None, False -async def chat_completion_inlets_handler( - body, model, user, __event_emitter__, __event_call__ -): +async def chat_completion_inlets_handler(body, model, extra_params): skip_files = None filter_ids = get_filter_function_ids(model) @@ -476,38 +458,18 @@ async def chat_completion_inlets_handler( params = {"body": body} # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } + custom_params = {**extra_params, "__model__": model, "__id__": filter_id} + if hasattr(function_module, "UserValves") and "__user__" in sig.parameters: + uid = custom_params["__user__"]["id"] + custom_params["__user__"]["valves"] = function_module.UserValves( + **Functions.get_user_valves_by_id_and_user_id(filter_id, uid) + ) # Add extra params in contained in function signature - for key, value in extra_params.items(): + for key, value in custom_params.items(): if key in sig.parameters: params[key] = value - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - if inspect.iscoroutinefunction(inlet): body = await inlet(**params) else: @@ -524,7 +486,7 @@ async def chat_completion_inlets_handler( return body, {} -async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__): +async def chat_completion_tools_handler(body, user, extra_params): skip_files = None contexts = [] @@ -547,8 +509,7 @@ async def chat_completion_tools_handler(body, user, __event_emitter__, __event_c template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, task_model_id=task_model_id, user=user, - __event_emitter__=__event_emitter__, - __event_call__=__event_call__, + extra_params=extra_params, ) print(file_handler) @@ -584,10 +545,7 @@ async def chat_completion_files_handler(body): contexts = [] citations = None - if "files" in body: - files = body["files"] - del body["files"] - + if files := body.pop("files", None): contexts, citations = get_rag_context( files=files, messages=body["messages"], @@ -634,8 +592,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): "valves": body.pop("valves", None), } - __event_emitter__ = get_event_emitter(metadata) - __event_call__ = get_event_call(metadata) + __user__ = { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + } + + extra_params = { + "__user__": __user__, + "__event_emitter__": get_event_emitter(metadata), + "__event_call__": get_event_call(metadata), + } # Initialize data_items to store additional data to be sent to the client # Initalize contexts and citation @@ -645,7 +613,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: body, flags = await chat_completion_inlets_handler( - body, model, user, __event_emitter__, __event_call__ + body, model, extra_params ) except Exception as e: return JSONResponse( @@ -654,10 +622,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) try: - body, flags = await chat_completion_tools_handler( - body, user, __event_emitter__, __event_call__ - ) - + body, flags = await chat_completion_tools_handler(body, user, extra_params) contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: @@ -666,7 +631,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: body, flags = await chat_completion_files_handler(body) - contexts.extend(flags.get("contexts", [])) citations.extend(flags.get("citations", [])) except Exception as e: @@ -713,7 +677,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): response = await call_next(request) if isinstance(response, StreamingResponse): # If it's a streaming response, inject it as SSE event or NDJSON line - content_type = response.headers.get("Content-Type") + content_type = response.headers["Content-Type"] if "text/event-stream" in content_type: return StreamingResponse( self.openai_stream_wrapper(response.body_iterator, data_items), @@ -832,7 +796,7 @@ class PipelineMiddleware(BaseHTTPMiddleware): user = get_current_user( request, - get_http_authorization_cred(request.headers.get("Authorization")), + get_http_authorization_cred(request.headers["Authorization"]), ) try: @@ -1015,6 +979,8 @@ async def get_all_models(): model["actions"] = [] for action_id in action_ids: action = Functions.get_function_by_id(action_id) + if action is None: + raise Exception(f"Action not found: {action_id}") if action_id in webui_app.state.FUNCTIONS: function_module = webui_app.state.FUNCTIONS[action_id] @@ -1022,6 +988,10 @@ async def get_all_models(): function_module, _, _ = load_function_module_by_id(action_id) webui_app.state.FUNCTIONS[action_id] = function_module + icon_url = None + if action.meta.manifest is not None: + icon_url = action.meta.manifest.get("icon_url", None) + if hasattr(function_module, "actions"): actions = function_module.actions model["actions"].extend( @@ -1032,9 +1002,7 @@ async def get_all_models(): "name", f"{action.name} ({_action['id']})" ), "description": action.meta.description, - "icon_url": _action.get( - "icon_url", action.meta.manifest.get("icon_url", None) - ), + "icon_url": _action.get("icon_url", icon_url), } for _action in actions ] @@ -1045,7 +1013,7 @@ async def get_all_models(): "id": action_id, "name": action.name, "description": action.meta.description, - "icon_url": action.meta.manifest.get("icon_url", None), + "icon_url": icon_url, } ) @@ -1175,6 +1143,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): def get_priority(function_id): function = Functions.get_function_by_id(function_id) if function is not None and hasattr(function, "valves"): + # TODO: Fix FunctionModel to include vavles return (function.valves if function.valves else {}).get("priority", 0) return 0 @@ -1631,7 +1600,7 @@ async def upload_pipeline( ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file - if not file.filename.endswith(".py"): + if not (file.filename and file.filename.endswith(".py")): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Only Python (.py) files are allowed.", @@ -2080,7 +2049,10 @@ async def oauth_login(provider: str, request: Request): redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( "oauth_callback", provider=provider ) - return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) + client = oauth.create_client(provider) + if client is None: + raise HTTPException(404) + return await client.authorize_redirect(request, redirect_uri) # OAuth login logic is as follows: