Refactor pipeline execution to streamline async handling and improve code readability

Dominik Peter 2025-03-18 19:23:16 +01:00
parent 5854eec4a9
commit 2bb66d0c98
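
The refactor collapses the old `is_async` / `is_async_gen` / threadpool branches into a single `execute_pipe` async generator that inspects the pipe callable and streams its output the same way in every case, with the rest of the endpoint logic wrapped in a nested `job()` coroutine. As a rough, self-contained sketch of that dispatch idea: the `execute_pipe` body below mirrors the helper added in this commit, while the three example pipes and the driver are hypothetical and only assume the `pipe(user_message, model_id, messages, body)` calling convention visible in the diff.

```python
import asyncio
import inspect

from fastapi.concurrency import run_in_threadpool


async def execute_pipe(pipe, *args, **kwargs):
    # Mirrors the helper added in this commit: normalize every pipe style
    # into one async stream of items.
    if inspect.isasyncgenfunction(pipe):
        # Async generator: forward items as they are yielded.
        async for res in pipe(*args, **kwargs):
            yield res
    elif inspect.iscoroutinefunction(pipe):
        # Coroutine: await it, then iterate the (assumed iterable) result.
        for item in await pipe(*args, **kwargs):
            yield item
    else:
        # Plain function: run it in a worker thread, then iterate the result.
        for item in await run_in_threadpool(pipe, *args, **kwargs):
            yield item


# Hypothetical pipes, one per supported style (not part of the commit).
async def async_gen_pipe(user_message, model_id, messages, body):
    for chunk in ("Hello", " ", "world"):
        yield chunk


async def coroutine_pipe(user_message, model_id, messages, body):
    return ["Hello", " ", "world"]


def sync_pipe(user_message, model_id, messages, body):
    return ["Hello", " ", "world"]


async def main():
    for pipe in (async_gen_pipe, coroutine_pipe, sync_pipe):
        chunks = [
            chunk
            async for chunk in execute_pipe(
                pipe, user_message="hi", model_id="demo", messages=[], body={}
            )
        ]
        print(pipe.__name__, chunks)


asyncio.run(main())
```

All three styles come out as the same async stream, which is what lets the streaming and non-streaming branches in the diff below consume a pipe without caring how it was written.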

main.py (254 lines changed)

@@ -1,3 +1,4 @@
+import inspect
 from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
@@ -667,41 +668,51 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             detail=f"Pipeline {form_data.model} not found",
         )
 
-    pipeline = app.state.PIPELINES[form_data.model]
-    pipeline_id = form_data.model
-
-    if pipeline["type"] == "manifold":
-        manifold_id, pipeline_id = pipeline_id.split(".", 1)
-        pipe = PIPELINE_MODULES[manifold_id].pipe
-    else:
-        pipe = PIPELINE_MODULES[pipeline_id].pipe
-
-    is_async = inspect.iscoroutinefunction(pipe)
-    is_async_gen = inspect.isasyncgenfunction(pipe)
-
-    # Helper function to ensure line is a string
-    def ensure_string(line):
-        if isinstance(line, bytes):
-            return line.decode("utf-8")
-        return str(line)
-
-    if form_data.stream:
-        async def stream_content():
-            if is_async_gen:
-                pipe_gen = pipe(
+    async def execute_pipe(pipe, *args, **kwargs):
+        if inspect.isasyncgenfunction(pipe):
+            async for res in pipe(*args, **kwargs):
+                yield res
+        elif inspect.iscoroutinefunction(pipe):
+            for item in await pipe(*args, **kwargs):
+                yield item
+        else:
+            for item in await run_in_threadpool(pipe, *args, **kwargs):
+                yield item
+
+    async def job():
+        print(form_data.model)
+
+        pipeline = app.state.PIPELINES[form_data.model]
+        pipeline_id = form_data.model
+        print(pipeline_id)
+
+        if pipeline["type"] == "manifold":
+            manifold_id, pipeline_id = pipeline_id.split(".", 1)
+            pipe = PIPELINE_MODULES[manifold_id].pipe
+        else:
+            pipe = PIPELINE_MODULES[pipeline_id].pipe
+
+        if form_data.stream:
+            async def stream_content():
+                res = execute_pipe(pipe,
                     user_message=user_message,
                     model_id=pipeline_id,
                     messages=messages,
                     body=form_data.model_dump(),
                 )
-
-                async for line in pipe_gen:
+                async for line in res:
                     if isinstance(line, BaseModel):
                         line = line.model_dump_json()
                         line = f"data: {line}"
-                    line = ensure_string(line)
-                    logging.info(f"stream_content:AsyncGeneratorFunction:{line}")
+                    try:
+                        line = line.decode("utf-8")
+                    except:
+                        pass
+                    logging.info(f"stream_content:Generator:{line}")
                     if line.startswith("data:"):
                         yield f"{line}\n\n"
@@ -709,106 +720,6 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                         line = stream_message_template(form_data.model, line)
                         yield f"data: {json.dumps(line)}\n\n"
 
-                finish_message = {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion.chunk",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "delta": {},
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-                yield f"data: {json.dumps(finish_message)}\n\n"
-                yield f"data: [DONE]"
-            elif is_async:
-                res = await pipe(
-                    user_message=user_message,
-                    model_id=pipeline_id,
-                    messages=messages,
-                    body=form_data.model_dump(),
-                )
-                logging.info(f"stream:true:async:{res}")
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:async:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-                elif inspect.isasyncgen(res):
-                    async for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        line = ensure_string(line)
-                        logging.info(f"stream_content:AsyncGenerator:{line}")
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data.model, line)
-                            yield f"data: {json.dumps(line)}\n\n"
-                if isinstance(res, str) or inspect.isasyncgen(res):
-                    finish_message = {
-                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                        "object": "chat.completion.chunk",
-                        "created": int(time.time()),
-                        "model": form_data.model,
-                        "choices": [
-                            {
-                                "index": 0,
-                                "delta": {},
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-                    yield f"data: {json.dumps(finish_message)}\n\n"
-                    yield f"data: [DONE]"
-            else:
-                def sync_job():
-                    res = pipe(
-                        user_message=user_message,
-                        model_id=pipeline_id,
-                        messages=messages,
-                        body=form_data.model_dump(),
-                    )
-                    return res
-
-                res = await run_in_threadpool(sync_job)
-                logging.info(f"stream:true:sync:{res}")
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-                if isinstance(res, Iterator):
-                    for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        line = ensure_string(line)
-                        logging.info(f"stream_content:Generator:{line}")
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data.model, line)
-                            yield f"data: {json.dumps(line)}\n\n"
                 if isinstance(res, str) or isinstance(res, Generator):
                     finish_message = {
                         "id": f"{form_data.model}-{str(uuid.uuid4())}",
@@ -828,10 +739,9 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                     yield f"data: {json.dumps(finish_message)}\n\n"
                     yield f"data: [DONE]"
 
-        return StreamingResponse(stream_content(), media_type="text/event-stream")
-    else:
-        if is_async_gen:
-            pipe_gen = pipe(
+            return StreamingResponse(stream_content(), media_type="text/event-stream")
+        else:
+            res = execute_pipe(pipe,
                 user_message=user_message,
                 model_id=pipeline_id,
                 messages=messages,
@@ -839,11 +749,10 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             )
 
             message = ""
-            async for stream in pipe_gen:
-                stream = ensure_string(stream)
+            async for stream in res:
                 message = f"{message}{stream}"
 
-            logging.info(f"stream:false:async_gen_function:{message}")
+            logging.info(f"stream:false:{message}")
             return {
                 "id": f"{form_data.model}-{str(uuid.uuid4())}",
                 "object": "chat.completion",
@@ -861,90 +770,5 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                     }
                 ],
             }
-        elif is_async:
-            res = await pipe(
-                user_message=user_message,
-                model_id=pipeline_id,
-                messages=messages,
-                body=form_data.model_dump(),
-            )
-            logging.info(f"stream:false:async:{res}")
-            if isinstance(res, dict):
-                return res
-            elif isinstance(res, BaseModel):
-                return res.model_dump()
-            else:
-                message = ""
-                if isinstance(res, str):
-                    message = res
-                elif inspect.isasyncgen(res):
-                    async for stream in res:
-                        stream = ensure_string(stream)
-                        message = f"{message}{stream}"
-                logging.info(f"stream:false:async:{message}")
-                return {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-        else:
-            def job():
-                res = pipe(
-                    user_message=user_message,
-                    model_id=pipeline_id,
-                    messages=messages,
-                    body=form_data.model_dump(),
-                )
-                logging.info(f"stream:false:sync:{res}")
-                if isinstance(res, dict):
-                    return res
-                elif isinstance(res, BaseModel):
-                    return res.model_dump()
-                else:
-                    message = ""
-                    if isinstance(res, str):
-                        message = res
-                    if isinstance(res, Generator):
-                        for stream in res:
-                            stream = ensure_string(stream)
-                            message = f"{message}{stream}"
-                    logging.info(f"stream:false:sync:{message}")
-                    return {
-                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                        "object": "chat.completion",
-                        "created": int(time.time()),
-                        "model": form_data.model,
-                        "choices": [
-                            {
-                                "index": 0,
-                                "message": {
-                                    "role": "assistant",
-                                    "content": message,
-                                },
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-
-        return await run_in_threadpool(job)
+
+    return await job()
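
For the non-streaming branch, the new `job()` drains that same stream into a single string and wraps it in an OpenAI-style `chat.completion` payload. A minimal sketch of that shape, assuming `chunks` is any async iterable (for example the `execute_pipe(...)` generator from the sketch above); the function name here is illustrative, not part of the commit:

```python
import time
import uuid


async def to_chat_completion(chunks, model):
    # Concatenate everything the pipe stream yields, as the refactored
    # non-streaming branch does.
    message = ""
    async for chunk in chunks:
        message = f"{message}{chunk}"

    # Response template matching the one kept in the diff.
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": message},
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }
```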