From 4370f233a1106270968c210eb60bd16261c199e5 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 20 Jun 2024 20:37:04 -0700 Subject: [PATCH] feat: pipe async support --- backend/main.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/backend/main.py b/backend/main.py index bfba361ab..bd24c369b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -843,7 +843,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u "role": user.role, } - def job(): + async def job(): pipe_id = form_data["model"] if "." in pipe_id: pipe_id, sub_pipe_id = pipe_id.split(".", 1) @@ -852,8 +852,11 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u pipe = webui_app.state.FUNCTIONS[pipe_id].pipe if form_data["stream"]: - def stream_content(): - res = pipe(body=form_data) + async def stream_content(): + if inspect.iscoroutinefunction(pipe): + res = await pipe(body=form_data) + else: + res = pipe(body=form_data) if isinstance(res, str): message = stream_message_template(form_data["model"], res) @@ -898,7 +901,10 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u stream_content(), media_type="text/event-stream" ) else: - res = pipe(body=form_data) + if inspect.iscoroutinefunction(pipe): + res = await pipe(body=form_data) + else: + res = pipe(body=form_data) if isinstance(res, dict): return res @@ -930,7 +936,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ], } - return await run_in_threadpool(job) + return await job() if model["owned_by"] == "ollama": return await generate_ollama_chat_completion(form_data, user=user) else: