From 646832ba8c98febcaa328b270c6fbd9c2b22067c Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 22 Jun 2024 12:23:37 -0700 Subject: [PATCH] refac --- backend/main.py | 65 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/backend/main.py b/backend/main.py index 6135d59a3..732ef96aa 100644 --- a/backend/main.py +++ b/backend/main.py @@ -278,8 +278,16 @@ async def get_function_call_response( "email": user.email, "name": user.name, "role": user.role, - "valves": Tools.get_user_valves_by_id_and_user_id( - tool_id, user.id + **( + { + "valves": toolkit_module.UserValves( + Tools.get_user_valves_by_id_and_user_id( + tool_id, user.id + ) + ) + } + if hasattr(toolkit_module, "UserValves") + else {} ), }, } @@ -404,8 +412,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): "email": user.email, "name": user.name, "role": user.role, - "valves": Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id + **( + { + "valves": function_module.UserValves( + Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id + ) + ) + } + if hasattr( + function_module, "UserValves" + ) + else {} ), }, } @@ -850,12 +868,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe = model.get("pipe") if pipe: - form_data["user"] = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } async def job(): pipe_id = form_data["model"] @@ -863,7 +875,14 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe_id, sub_pipe_id = pipe_id.split(".", 1) print(pipe_id) - pipe = webui_app.state.FUNCTIONS[pipe_id].pipe + # Check if function is already loaded + if pipe_id not in app.state.FUNCTIONS: + function_module, function_type = load_function_module_by_id(pipe_id) + app.state.FUNCTIONS[pipe_id] = function_module + else: + function_module = app.state.FUNCTIONS[pipe_id] + + pipe = function_module.pipe # Get the signature of the function sig = inspect.signature(pipe) @@ -877,8 +896,16 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u "email": user.email, "name": user.name, "role": user.role, - "valves": Functions.get_user_valves_by_id_and_user_id( - pipe_id, user.id + **( + { + "valves": pipe.UserValves( + Functions.get_user_valves_by_id_and_user_id( + pipe_id, user.id + ) + ) + } + if hasattr(function_module, "UserValves") + else {} ), }, } @@ -1079,8 +1106,16 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): "email": user.email, "name": user.name, "role": user.role, - "valves": Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id + **( + { + "valves": function_module.UserValves( + Functions.get_user_valves_by_id_and_user_id( + filter_id, user.id + ) + ) + } + if hasattr(function_module, "UserValves") + else {} ), }, }