feat: pipe async support

This commit is contained in:
Timothy J. Baek 2024-06-20 20:37:04 -07:00
parent 5621025c12
commit 4370f233a1

View File

@@ -843,7 +843,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
"role": user.role,
}
def job():
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
@@ -852,8 +852,11 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
if form_data["stream"]:
def stream_content():
res = pipe(body=form_data)
async def stream_content():
if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data)
else:
res = pipe(body=form_data)
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
@@ -898,7 +901,10 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
stream_content(), media_type="text/event-stream"
)
else:
res = pipe(body=form_data)
if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data)
else:
res = pipe(body=form_data)
if isinstance(res, dict):
return res
@@ -930,7 +936,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
],
}
return await run_in_threadpool(job)
return await job()
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else: