From a8a451344cd5970c939a1744aff571db57b74e6d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 22 Jun 2024 01:42:28 -0700 Subject: [PATCH] refac --- backend/main.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/backend/main.py b/backend/main.py index 9eb7d2071..0b7f23bde 100644 --- a/backend/main.py +++ b/backend/main.py @@ -858,13 +858,28 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u print(pipe_id) pipe = webui_app.state.FUNCTIONS[pipe_id].pipe - if form_data["stream"]: + # Get the signature of the function + sig = inspect.signature(pipe) + param = {"body": form_data} + + if "__user__" in sig.parameters: + param = { + **param, + "__user__": { + "id": user.id, + "email": user.email, + "name": user.name, + "role": user.role, + }, + } + + if form_data["stream"]: async def stream_content(): if inspect.iscoroutinefunction(pipe): - res = await pipe(body=form_data) + res = await pipe(**param) else: - res = pipe(body=form_data) + res = pipe(**param) if isinstance(res, str): message = stream_message_template(form_data["model"], res) @@ -910,9 +925,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ) else: if inspect.iscoroutinefunction(pipe): - res = await pipe(body=form_data) + res = await pipe(**param) else: - res = pipe(body=form_data) + res = pipe(**param) if isinstance(res, dict): return res