From ae567796ee304a114f5ec97966b947b19d9b5ac4 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 22 Jun 2024 01:39:53 -0700 Subject: [PATCH] refac --- backend/apps/webui/main.py | 4 +- backend/main.py | 83 +++++++++++++++++++++----------------- 2 files changed, 50 insertions(+), 37 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index ce58047ed..a9f7fb286 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -130,7 +130,9 @@ async def get_pipe_models(): manifold_pipe_name = p["name"] if hasattr(function_module, "name"): - manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}" + manifold_pipe_name = ( + f"{function_module.name}{manifold_pipe_name}" + ) pipe_models.append( { diff --git a/backend/main.py b/backend/main.py index 737641e2d..9eb7d2071 100644 --- a/backend/main.py +++ b/backend/main.py @@ -389,26 +389,31 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if hasattr(function_module, "inlet"): inlet = function_module.inlet + # Get the signature of the function + sig = inspect.signature(inlet) + param = {"body": data} + + if "__user__" in sig.parameters: + param = { + **param, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + } + + if "__id__" in sig.parameters: + param = { + **param, + "__id__": filter_id, + } + if inspect.iscoroutinefunction(inlet): - data = await inlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - ) + data = await inlet(**param) else: - data = inlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - ) + data = inlet(**param) except Exception as e: print(f"Error: {e}") @@ -1031,26 +1036,32 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): try: if hasattr(function_module, "outlet"): outlet = function_module.outlet + + # Get the signature of the function + sig = inspect.signature(outlet) + param = {"body": data} + + if "__user__" in sig.parameters: + param = { + **param, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + } + + if "__id__" in sig.parameters: + param = { + **param, + "__id__": filter_id, + } + if inspect.iscoroutinefunction(outlet): - data = await outlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - ) + data = await outlet(**param) else: - data = outlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - ) + data = outlet(**param) except Exception as e: print(f"Error: {e}")