From 5621025c122a783a4912f320f4fd24c710c275ac Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 20:26:28 -0700 Subject: [PATCH] feat: async filter support --- backend/main.py | 63 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/backend/main.py b/backend/main.py index 346902de6..bfba361ab 100644 --- a/backend/main.py +++ b/backend/main.py @@ -384,15 +384,29 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): try: if hasattr(function_module, "inlet"): - data = function_module.inlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - ) + inlet = function_module.inlet + + if inspect.iscoroutinefunction(inlet): + data = await inlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + else: + data = inlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: print(f"Error: {e}") return JSONResponse( @@ -1007,15 +1021,28 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): try: if hasattr(function_module, "outlet"): - data = function_module.outlet( - data, - { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, - ) + outlet = function_module.outlet + if inspect.iscoroutinefunction(outlet): + data = await outlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + else: + data = outlet( + data, + { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + ) + except Exception as e: print(f"Error: {e}") return JSONResponse(