Adrian Ehrsam 2025-03-18 16:54:11 +01:00
parent d2b7077cce
commit dcb74a7ba8

main.py

@@ -1,3 +1,4 @@
+import inspect
 from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.concurrency import run_in_threadpool
@@ -666,8 +667,19 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             status_code=status.HTTP_404_NOT_FOUND,
             detail=f"Pipeline {form_data.model} not found",
         )
+    async def execute_pipe(pipe, *args, **kwargs):
+        if inspect.isasyncgenfunction(pipe):
+            async for res in pipe(*args, **kwargs):
+                yield res
+        elif inspect.iscoroutinefunction(pipe):
+            ls = await pipe(*args, **kwargs)
+            for item in ls:
+                yield item
+        else:
+            for item in await run_in_threadpool(pipe, *args, **kwargs):
+                yield item
-    def job():
+    async def job():
         print(form_data.model)
         pipeline = app.state.PIPELINES[form_data.model]
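The new execute_pipe helper is the core of this change: it normalizes the three kinds of pipe callables a pipeline module may expose (async generator function, coroutine, plain sync function) into one async generator that callers can always drive with async for. Below is a minimal standalone sketch of that dispatch pattern; the toy pipes agen_pipe, coro_pipe, and sync_pipe are hypothetical and only exercise the three branches, they are not part of the repository.

import asyncio
import inspect
from fastapi.concurrency import run_in_threadpool

async def execute_pipe(pipe, *args, **kwargs):
    if inspect.isasyncgenfunction(pipe):
        # Async generator pipe: re-yield its chunks as they arrive.
        async for res in pipe(*args, **kwargs):
            yield res
    elif inspect.iscoroutinefunction(pipe):
        # Coroutine pipe: await it once, then yield each item of the result.
        for item in await pipe(*args, **kwargs):
            yield item
    else:
        # Plain sync pipe: run it in a worker thread so it cannot block the event loop.
        for item in await run_in_threadpool(pipe, *args, **kwargs):
            yield item

# Hypothetical toy pipes, one per branch.
async def agen_pipe(body):
    yield "a"
    yield "b"

async def coro_pipe(body):
    return ["c", "d"]

def sync_pipe(body):
    return ["e", "f"]

async def main():
    for p in (agen_pipe, coro_pipe, sync_pipe):
        print([chunk async for chunk in execute_pipe(p, body={})])

asyncio.run(main())

Yielding from every branch keeps the calling code identical no matter how a given pipeline implements pipe.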
@@ -683,23 +695,14 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
         if form_data.stream:
-            def stream_content():
-                res = pipe(
+            async def stream_content():
+                res = execute_pipe(pipe,
                     user_message=user_message,
                     model_id=pipeline_id,
                     messages=messages,
                     body=form_data.model_dump(),
                 )
-                logging.info(f"stream:true:{res}")
-                if isinstance(res, str):
-                    message = stream_message_template(form_data.model, res)
-                    logging.info(f"stream_content:str:{message}")
-                    yield f"data: {json.dumps(message)}\n\n"
-                if isinstance(res, Iterator):
-                    for line in res:
+                async for line in res:
                     if isinstance(line, BaseModel):
                         line = line.model_dump_json()
                         line = f"data: {line}"
@@ -738,27 +741,15 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             return StreamingResponse(stream_content(), media_type="text/event-stream")
         else:
-            res = pipe(
+            res = execute_pipe(pipe,
                 user_message=user_message,
                 model_id=pipeline_id,
                 messages=messages,
                 body=form_data.model_dump(),
             )
-            logging.info(f"stream:false:{res}")
-            if isinstance(res, dict):
-                return res
-            elif isinstance(res, BaseModel):
-                return res.model_dump()
-            else:
             message = ""
-                if isinstance(res, str):
-                    message = res
-                if isinstance(res, Generator):
-                    for stream in res:
+            async for stream in res:
                 message = f"{message}{stream}"
             logging.info(f"stream:false:{message}")
@@ -780,4 +771,4 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
             ],
         }
-    return await run_in_threadpool(job)
+    return await job()
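Finally, because job is now a coroutine, the endpoint awaits it directly rather than dispatching it through run_in_threadpool; only the synchronous pipe branch inside execute_pipe still touches the threadpool. A brief sketch of the difference between the two call styles, with placeholder job bodies:

import asyncio
from fastapi.concurrency import run_in_threadpool

def sync_job():
    return "ran in a worker thread"

async def async_job():
    return "ran on the event loop"

async def endpoint():
    # Old pattern: a plain function is offloaded to the threadpool and awaited.
    before = await run_in_threadpool(sync_job)
    # New pattern: the coroutine is awaited directly; run_in_threadpool(async_job)
    # would only hand back an un-awaited coroutine object.
    after = await async_job()
    return before, after

print(asyncio.run(endpoint()))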