feat: pipe async support

This commit is contained in:
Timothy J. Baek 2024-06-20 20:37:04 -07:00
parent 5621025c12
commit 4370f233a1

View File

@ -843,7 +843,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
"role": user.role, "role": user.role,
} }
def job(): async def job():
pipe_id = form_data["model"] pipe_id = form_data["model"]
if "." in pipe_id: if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1) pipe_id, sub_pipe_id = pipe_id.split(".", 1)
@ -852,8 +852,11 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
if form_data["stream"]: if form_data["stream"]:
def stream_content(): async def stream_content():
res = pipe(body=form_data) if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data)
else:
res = pipe(body=form_data)
if isinstance(res, str): if isinstance(res, str):
message = stream_message_template(form_data["model"], res) message = stream_message_template(form_data["model"], res)
@ -898,7 +901,10 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
stream_content(), media_type="text/event-stream" stream_content(), media_type="text/event-stream"
) )
else: else:
res = pipe(body=form_data) if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data)
else:
res = pipe(body=form_data)
if isinstance(res, dict): if isinstance(res, dict):
return res return res
@ -930,7 +936,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
], ],
} }
return await run_in_threadpool(job) return await job()
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user) return await generate_ollama_chat_completion(form_data, user=user)
else: else: