Refactor pipeline execution to streamline async handling and improve code readability

Dominik Peter 2025-03-18 19:23:16 +01:00
parent 5854eec4a9
commit 2bb66d0c98
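
Before this change, the endpoint inspected the pipe up front and kept three parallel code paths (async generator function, coroutine function, plain function) for both the streaming and the non-streaming case. The new execute_pipe helper collapses that: every pipe flavor is exposed as one async generator that callers simply iterate. A minimal standalone sketch of the dispatch pattern, with asyncio.to_thread standing in for FastAPI's run_in_threadpool and the three example pipes being illustrative only, not from the repo:

    import asyncio
    import inspect

    async def execute_pipe(pipe, *args, **kwargs):
        # Normalize any pipe flavor into a single async stream of chunks.
        if inspect.isasyncgenfunction(pipe):
            # Async generator pipe: forward each yielded chunk as-is.
            async for res in pipe(*args, **kwargs):
                yield res
        elif inspect.iscoroutinefunction(pipe):
            # Coroutine pipe: await it, then iterate the returned value.
            for item in await pipe(*args, **kwargs):
                yield item
        else:
            # Plain sync pipe: run it off the event loop, then iterate.
            for item in await asyncio.to_thread(pipe, *args, **kwargs):
                yield item

    async def gen_pipe():   # illustrative
        yield "a"
        yield "b"

    async def coro_pipe():  # illustrative
        return ["c", "d"]

    def sync_pipe():        # illustrative
        return ["e", "f"]

    async def main():
        for pipe in (gen_pipe, coro_pipe, sync_pipe):
            print(pipe.__name__, [chunk async for chunk in execute_pipe(pipe)])

    asyncio.run(main())

Note that a coroutine pipe returning a plain string is iterated character by character by the `for item in await pipe(...)` loop; the committed helper behaves the same way.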

main.py (228 changed lines)

@@ -1,3 +1,4 @@
+import inspect
 from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
@@ -667,140 +668,50 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             detail=f"Pipeline {form_data.model} not found",
         )
 
+    async def execute_pipe(pipe, *args, **kwargs):
+        if inspect.isasyncgenfunction(pipe):
+            async for res in pipe(*args, **kwargs):
+                yield res
+        elif inspect.iscoroutinefunction(pipe):
+            for item in await pipe(*args, **kwargs):
+                yield item
+        else:
+            for item in await run_in_threadpool(pipe, *args, **kwargs):
+                yield item
+
+    async def job():
+        print(form_data.model)
+
         pipeline = app.state.PIPELINES[form_data.model]
         pipeline_id = form_data.model
+        print(pipeline_id)
 
         if pipeline["type"] == "manifold":
             manifold_id, pipeline_id = pipeline_id.split(".", 1)
             pipe = PIPELINE_MODULES[manifold_id].pipe
         else:
             pipe = PIPELINE_MODULES[pipeline_id].pipe
 
-    is_async = inspect.iscoroutinefunction(pipe)
-    is_async_gen = inspect.isasyncgenfunction(pipe)
-
-    # Helper function to ensure line is a string
-    def ensure_string(line):
-        if isinstance(line, bytes):
-            return line.decode("utf-8")
-        return str(line)
-
         if form_data.stream:
 
             async def stream_content():
-                if is_async_gen:
-                    pipe_gen = pipe(
+                res = execute_pipe(pipe,
                     user_message=user_message,
                     model_id=pipeline_id,
                     messages=messages,
                     body=form_data.model_dump(),
                 )
 
-                async for line in pipe_gen:
-                    if isinstance(line, BaseModel):
-                        line = line.model_dump_json()
-                        line = f"data: {line}"
-
-                    line = ensure_string(line)
-                    logging.info(f"stream_content:AsyncGeneratorFunction:{line}")
-
-                    if line.startswith("data:"):
-                        yield f"{line}\n\n"
-                    else:
-                        line = stream_message_template(form_data.model, line)
-                        yield f"data: {json.dumps(line)}\n\n"
-
-                finish_message = {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion.chunk",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "delta": {},
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-                yield f"data: {json.dumps(finish_message)}\n\n"
-                yield f"data: [DONE]"
-
-                elif is_async:
-                    res = await pipe(
-                        user_message=user_message,
-                        model_id=pipeline_id,
-                        messages=messages,
-                        body=form_data.model_dump(),
-                    )
-
-                    logging.info(f"stream:true:async:{res}")
-
-                    if isinstance(res, str):
-                        message = stream_message_template(form_data.model, res)
-                        logging.info(f"stream_content:str:async:{message}")
-                        yield f"data: {json.dumps(message)}\n\n"
-
-                    elif inspect.isasyncgen(res):
                 async for line in res:
                     if isinstance(line, BaseModel):
                         line = line.model_dump_json()
                         line = f"data: {line}"
 
-                        line = ensure_string(line)
-                        logging.info(f"stream_content:AsyncGenerator:{line}")
+                    try:
+                        line = line.decode("utf-8")
+                    except:
+                        pass
 
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data.model, line)
-                            yield f"data: {json.dumps(line)}\n\n"
-
-                if isinstance(res, str) or inspect.isasyncgen(res):
-                    finish_message = {
-                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                        "object": "chat.completion.chunk",
-                        "created": int(time.time()),
-                        "model": form_data.model,
-                        "choices": [
-                            {
-                                "index": 0,
-                                "delta": {},
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-                    yield f"data: {json.dumps(finish_message)}\n\n"
-                    yield f"data: [DONE]"
-
-            else:
-
-                def sync_job():
-                    res = pipe(
-                        user_message=user_message,
-                        model_id=pipeline_id,
-                        messages=messages,
-                        body=form_data.model_dump(),
-                    )
-                    return res
-
-                res = await run_in_threadpool(sync_job)
-                logging.info(f"stream:true:sync:{res}")
-
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-
-                if isinstance(res, Iterator):
-                    for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-
-                        line = ensure_string(line)
                     logging.info(f"stream_content:Generator:{line}")
 
                     if line.startswith("data:"):
@@ -830,8 +741,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             return StreamingResponse(stream_content(), media_type="text/event-stream")
 
         else:
-            if is_async_gen:
-                pipe_gen = pipe(
+            res = execute_pipe(pipe,
                 user_message=user_message,
                 model_id=pipeline_id,
                 messages=messages,
@@ -839,96 +749,10 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             )
 
             message = ""
-                async for stream in pipe_gen:
-                    stream = ensure_string(stream)
-                    message = f"{message}{stream}"
-                    logging.info(f"stream:false:async_gen_function:{message}")
-
-                return {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-
-            elif is_async:
-                res = await pipe(
-                    user_message=user_message,
-                    model_id=pipeline_id,
-                    messages=messages,
-                    body=form_data.model_dump(),
-                )
-
-                logging.info(f"stream:false:async:{res}")
-
-                if isinstance(res, dict):
-                    return res
-                elif isinstance(res, BaseModel):
-                    return res.model_dump()
-                else:
-                    message = ""
-
-                    if isinstance(res, str):
-                        message = res
-                    elif inspect.isasyncgen(res):
             async for stream in res:
-                stream = ensure_string(stream)
                 message = f"{message}{stream}"
-                logging.info(f"stream:false:async:{message}")
+                logging.info(f"stream:false:{message}")
-
-                    return {
-                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                        "object": "chat.completion",
-                        "created": int(time.time()),
-                        "model": form_data.model,
-                        "choices": [
-                            {
-                                "index": 0,
-                                "message": {
-                                    "role": "assistant",
-                                    "content": message,
-                                },
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-
-            else:
-
-                def job():
-                    res = pipe(
-                        user_message=user_message,
-                        model_id=pipeline_id,
-                        messages=messages,
-                        body=form_data.model_dump(),
-                    )
-
-                    logging.info(f"stream:false:sync:{res}")
-
-                    if isinstance(res, dict):
-                        return res
-                    elif isinstance(res, BaseModel):
-                        return res.model_dump()
-                    else:
-                        message = ""
-
-                        if isinstance(res, str):
-                            message = res
-
-                        if isinstance(res, Generator):
-                            for stream in res:
-                                stream = ensure_string(stream)
-                                message = f"{message}{stream}"
-                                logging.info(f"stream:false:sync:{message}")
 
             return {
                 "id": f"{form_data.model}-{str(uuid.uuid4())}",
                 "object": "chat.completion",
@@ -947,4 +771,4 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             ],
         }
 
-    return await run_in_threadpool(job)
+    return await job()
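
With every pipe funneled through execute_pipe, job() can be a coroutine awaited directly, as in the final hunk above; the threadpool hop now happens only inside the helper, and only for plain sync pipes. The streaming branch likewise reduces to one normalization loop over the unified stream. A condensed sketch of that per-chunk normalization, assuming a simplified stand-in for the repo's stream_message_template (the committed code uses a bare except; Exception is caught here for the same effect):

    import json
    from pydantic import BaseModel

    def stream_message_template(model: str, content: str) -> dict:
        # Simplified stand-in for the repo's helper, which wraps the
        # content delta in a chat.completion.chunk payload.
        return {"model": model, "choices": [{"delta": {"content": content}}]}

    async def sse_lines(model: str, stream):
        # Mirrors stream_content after this commit: coerce each chunk to a
        # string, pass through anything already SSE-framed, and wrap
        # everything else in a chunk template.
        async for line in stream:
            if isinstance(line, BaseModel):
                line = f"data: {line.model_dump_json()}"
            try:
                line = line.decode("utf-8")  # bytes -> str; str has no .decode
            except Exception:
                pass
            if line.startswith("data:"):
                yield f"{line}\n\n"
            else:
                yield f"data: {json.dumps(stream_message_template(model, line))}\n\n"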