Refactor pipeline execution to streamline async handling and improve code readability

Dominik Peter 2025-03-18 19:23:16 +01:00
parent 5854eec4a9
commit 2bb66d0c98
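The centerpiece of this refactor is the new execute_pipe helper, which replaces three hand-rolled branches (async generator, coroutine, plain function) with a single async iterator. Below is a minimal standalone sketch of the same dispatch pattern, runnable outside the app; sync_pipe and demo are hypothetical names, while run_in_threadpool is FastAPI's re-export of Starlette's helper for running blocking calls in a worker thread:

import asyncio
import inspect

from fastapi.concurrency import run_in_threadpool


async def execute_pipe(pipe, *args, **kwargs):
    # Async generator pipes stream natively: forward each chunk as it arrives.
    if inspect.isasyncgenfunction(pipe):
        async for chunk in pipe(*args, **kwargs):
            yield chunk
    # Coroutine pipes are awaited once; the awaited result must be iterable.
    elif inspect.iscoroutinefunction(pipe):
        for chunk in await pipe(*args, **kwargs):
            yield chunk
    # Plain functions run in a thread pool so they never block the event loop.
    else:
        for chunk in await run_in_threadpool(pipe, *args, **kwargs):
            yield chunk


def sync_pipe(user_message, **_):
    # Hypothetical blocking pipe returning an iterable of chunks.
    return [f"echo: {user_message}"]


async def demo():
    # The caller no longer cares which flavor of pipe it was handed.
    async for chunk in execute_pipe(sync_pipe, user_message="hello"):
        print(chunk)


asyncio.run(demo())

Note the trade-off: the coroutine and thread-pool branches only start yielding after the pipe has produced its full result, so true incremental streaming is available only from async generator pipes.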

main.py

@@ -1,3 +1,4 @@
+import inspect
 from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
@@ -666,149 +667,59 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=f"Pipeline {form_data.model} not found",
         )
-    pipeline = app.state.PIPELINES[form_data.model]
-    pipeline_id = form_data.model
-
-    if pipeline["type"] == "manifold":
-        manifold_id, pipeline_id = pipeline_id.split(".", 1)
-        pipe = PIPELINE_MODULES[manifold_id].pipe
-    else:
-        pipe = PIPELINE_MODULES[pipeline_id].pipe
-
-    is_async = inspect.iscoroutinefunction(pipe)
-    is_async_gen = inspect.isasyncgenfunction(pipe)
-
-    # Helper function to ensure line is a string
-    def ensure_string(line):
-        if isinstance(line, bytes):
-            return line.decode("utf-8")
-        return str(line)
-
-    if form_data.stream:
-
-        async def stream_content():
-            if is_async_gen:
-                pipe_gen = pipe(
+    async def execute_pipe(pipe, *args, **kwargs):
+        if inspect.isasyncgenfunction(pipe):
+            async for res in pipe(*args, **kwargs):
+                yield res
+        elif inspect.iscoroutinefunction(pipe):
+            for item in await pipe(*args, **kwargs):
+                yield item
+        else:
+            for item in await run_in_threadpool(pipe, *args, **kwargs):
+                yield item
+
+    async def job():
+        print(form_data.model)
+        pipeline = app.state.PIPELINES[form_data.model]
+        pipeline_id = form_data.model
+        print(pipeline_id)
+
+        if pipeline["type"] == "manifold":
+            manifold_id, pipeline_id = pipeline_id.split(".", 1)
+            pipe = PIPELINE_MODULES[manifold_id].pipe
+        else:
+            pipe = PIPELINE_MODULES[pipeline_id].pipe
+
+        if form_data.stream:
+
+            async def stream_content():
+                res = execute_pipe(pipe,
                     user_message=user_message,
                     model_id=pipeline_id,
                     messages=messages,
                     body=form_data.model_dump(),
                 )
 
-                async for line in pipe_gen:
+                async for line in res:
                     if isinstance(line, BaseModel):
                         line = line.model_dump_json()
                         line = f"data: {line}"
-                    line = ensure_string(line)
-                    logging.info(f"stream_content:AsyncGeneratorFunction:{line}")
+                    try:
+                        line = line.decode("utf-8")
+                    except:
+                        pass
+                    logging.info(f"stream_content:Generator:{line}")
 
                     if line.startswith("data:"):
                         yield f"{line}\n\n"
                     else:
                         line = stream_message_template(form_data.model, line)
                         yield f"data: {json.dumps(line)}\n\n"
 
-                finish_message = {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion.chunk",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "delta": {},
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-
-                yield f"data: {json.dumps(finish_message)}\n\n"
-                yield f"data: [DONE]"
-            elif is_async:
-                res = await pipe(
-                    user_message=user_message,
-                    model_id=pipeline_id,
-                    messages=messages,
-                    body=form_data.model_dump(),
-                )
-                logging.info(f"stream:true:async:{res}")
-
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:async:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-                elif inspect.isasyncgen(res):
-                    async for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        line = ensure_string(line)
-                        logging.info(f"stream_content:AsyncGenerator:{line}")
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data.model, line)
-                            yield f"data: {json.dumps(line)}\n\n"
-
-                if isinstance(res, str) or inspect.isasyncgen(res):
-                    finish_message = {
-                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                        "object": "chat.completion.chunk",
-                        "created": int(time.time()),
-                        "model": form_data.model,
-                        "choices": [
-                            {
-                                "index": 0,
-                                "delta": {},
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-
-                    yield f"data: {json.dumps(finish_message)}\n\n"
-                    yield f"data: [DONE]"
-            else:
-
-                def sync_job():
-                    res = pipe(
-                        user_message=user_message,
-                        model_id=pipeline_id,
-                        messages=messages,
-                        body=form_data.model_dump(),
-                    )
-                    return res
-
-                res = await run_in_threadpool(sync_job)
-                logging.info(f"stream:true:sync:{res}")
-
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-
-                if isinstance(res, Iterator):
-                    for line in res:
-                        if isinstance(line, BaseModel):
-                            line = line.model_dump_json()
-                            line = f"data: {line}"
-                        line = ensure_string(line)
-                        logging.info(f"stream_content:Generator:{line}")
-
-                        if line.startswith("data:"):
-                            yield f"{line}\n\n"
-                        else:
-                            line = stream_message_template(form_data.model, line)
-                            yield f"data: {json.dumps(line)}\n\n"
 
                 if isinstance(res, str) or isinstance(res, Generator):
                     finish_message = {
                         "id": f"{form_data.model}-{str(uuid.uuid4())}",
@@ -827,23 +738,21 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                     yield f"data: {json.dumps(finish_message)}\n\n"
                     yield f"data: [DONE]"
 
            return StreamingResponse(stream_content(), media_type="text/event-stream")
        else:
-        if is_async_gen:
-            pipe_gen = pipe(
+            res = execute_pipe(pipe,
                 user_message=user_message,
                 model_id=pipeline_id,
                 messages=messages,
                 body=form_data.model_dump(),
             )
             message = ""
-            async for stream in pipe_gen:
-                stream = ensure_string(stream)
+            async for stream in res:
                 message = f"{message}{stream}"
-            logging.info(f"stream:false:async_gen_function:{message}")
+            logging.info(f"stream:false:{message}")
 
             return {
                 "id": f"{form_data.model}-{str(uuid.uuid4())}",
                 "object": "chat.completion",
@@ -861,90 +770,5 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
                     }
                 ],
             }
-        elif is_async:
-            res = await pipe(
-                user_message=user_message,
-                model_id=pipeline_id,
-                messages=messages,
-                body=form_data.model_dump(),
-            )
-            logging.info(f"stream:false:async:{res}")
-
-            if isinstance(res, dict):
-                return res
-            elif isinstance(res, BaseModel):
-                return res.model_dump()
-            else:
-                message = ""
-
-                if isinstance(res, str):
-                    message = res
-                elif inspect.isasyncgen(res):
-                    async for stream in res:
-                        stream = ensure_string(stream)
-                        message = f"{message}{stream}"
-
-                logging.info(f"stream:false:async:{message}")
-                return {
-                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                    "object": "chat.completion",
-                    "created": int(time.time()),
-                    "model": form_data.model,
-                    "choices": [
-                        {
-                            "index": 0,
-                            "message": {
-                                "role": "assistant",
-                                "content": message,
-                            },
-                            "logprobs": None,
-                            "finish_reason": "stop",
-                        }
-                    ],
-                }
-        else:
-
-            def job():
-                res = pipe(
-                    user_message=user_message,
-                    model_id=pipeline_id,
-                    messages=messages,
-                    body=form_data.model_dump(),
-                )
-                logging.info(f"stream:false:sync:{res}")
-
-                if isinstance(res, dict):
-                    return res
-                elif isinstance(res, BaseModel):
-                    return res.model_dump()
-                else:
-                    message = ""
-
-                    if isinstance(res, str):
-                        message = res
-
-                    if isinstance(res, Generator):
-                        for stream in res:
-                            stream = ensure_string(stream)
-                            message = f"{message}{stream}"
-
-                    logging.info(f"stream:false:sync:{message}")
-                    return {
-                        "id": f"{form_data.model}-{str(uuid.uuid4())}",
-                        "object": "chat.completion",
-                        "created": int(time.time()),
-                        "model": form_data.model,
-                        "choices": [
-                            {
-                                "index": 0,
-                                "message": {
-                                    "role": "assistant",
-                                    "content": message,
-                                },
-                                "logprobs": None,
-                                "finish_reason": "stop",
-                            }
-                        ],
-                    }
-
-            return await run_in_threadpool(job)
+    return await job()
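For reference, these are the three pipe shapes that execute_pipe now treats uniformly. The toy pipelines below are hypothetical, but they take the same keyword arguments (user_message, model_id, messages, body) that main.py passes in; in a real pipeline module each would simply be named pipe:

# Plain function: run via run_in_threadpool; should return an iterable of chunks.
def pipe_sync(user_message, model_id, messages, body):
    return [f"sync reply to {user_message!r}"]


# Coroutine: awaited once; its return value is then iterated.
async def pipe_coro(user_message, model_id, messages, body):
    return [f"async reply to {user_message!r}"]


# Async generator: chunks are forwarded as they are yielded (true streaming).
async def pipe_stream(user_message, model_id, messages, body):
    for token in ("streamed ", "reply"):
        yield token

Because all three collapse to one async iterator, both the streaming and non-streaming branches of the endpoint consume them with a single async for loop.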