This commit is contained in:
Adrian Ehrsam 2025-03-18 16:54:11 +01:00
parent d2b7077cce
commit dcb74a7ba8

113
main.py
View File

@ -1,3 +1,4 @@
import inspect
from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
@ -666,8 +667,19 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {form_data.model} not found", detail=f"Pipeline {form_data.model} not found",
) )
async def execute_pipe(pipe, *args, **kwargs):
if inspect.isasyncgenfunction(pipe):
async for res in pipe(*args, **kwargs):
yield res
elif inspect.iscoroutinefunction(pipe):
ls = await pipe(*args, **kwargs)
for item in ls:
yield item
else:
for item in await run_in_threadpool(pipe, *args, **kwargs):
yield item
def job(): async def job():
print(form_data.model) print(form_data.model)
pipeline = app.state.PIPELINES[form_data.model] pipeline = app.state.PIPELINES[form_data.model]
@ -683,39 +695,30 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
if form_data.stream: if form_data.stream:
def stream_content(): async def stream_content():
res = pipe( res = execute_pipe(pipe,
user_message=user_message, user_message=user_message,
model_id=pipeline_id, model_id=pipeline_id,
messages=messages, messages=messages,
body=form_data.model_dump(), body=form_data.model_dump(),
) )
async for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
logging.info(f"stream:true:{res}") try:
line = line.decode("utf-8")
except:
pass
if isinstance(res, str): logging.info(f"stream_content:Generator:{line}")
message = stream_message_template(form_data.model, res)
logging.info(f"stream_content:str:{message}")
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator): if line.startswith("data:"):
for line in res: yield f"{line}\n\n"
if isinstance(line, BaseModel): else:
line = line.model_dump_json() line = stream_message_template(form_data.model, line)
line = f"data: {line}" yield f"data: {json.dumps(line)}\n\n"
try:
line = line.decode("utf-8")
except:
pass
logging.info(f"stream_content:Generator:{line}")
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data.model, line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator): if isinstance(res, str) or isinstance(res, Generator):
finish_message = { finish_message = {
@ -738,46 +741,34 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
return StreamingResponse(stream_content(), media_type="text/event-stream") return StreamingResponse(stream_content(), media_type="text/event-stream")
else: else:
res = pipe( res = execute_pipe(pipe,
user_message=user_message, user_message=user_message,
model_id=pipeline_id, model_id=pipeline_id,
messages=messages, messages=messages,
body=form_data.model_dump(), body=form_data.model_dump(),
) )
logging.info(f"stream:false:{res}")
if isinstance(res, dict): message = ""
return res async for stream in res:
elif isinstance(res, BaseModel): message = f"{message}{stream}"
return res.model_dump()
else:
message = "" logging.info(f"stream:false:{message}")
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
if isinstance(res, str): return await job()
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
logging.info(f"stream:false:{message}")
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await run_in_threadpool(job)