This commit is contained in:
dominik 2025-04-13 16:17:30 -04:00 committed by GitHub
commit 05395415d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

117
main.py
View File

@ -1,3 +1,4 @@
import inspect
from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
@ -27,7 +28,7 @@ import json
import uuid import uuid
import sys import sys
import subprocess import subprocess
import inspect
from config import API_KEY, PIPELINES_DIR, LOG_LEVELS from config import API_KEY, PIPELINES_DIR, LOG_LEVELS
@ -670,8 +671,19 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=f"Pipeline {form_data.model} not found", detail=f"Pipeline {form_data.model} not found",
) )
async def execute_pipe(pipe, *args, **kwargs):
if inspect.isasyncgenfunction(pipe):
async for res in pipe(*args, **kwargs):
yield res
elif inspect.iscoroutinefunction(pipe):
for item in await pipe(*args, **kwargs):
yield item
else:
for item in await run_in_threadpool(pipe, *args, **kwargs):
yield item
def job(): async def job():
print(form_data.model) print(form_data.model)
pipeline = app.state.PIPELINES[form_data.model] pipeline = app.state.PIPELINES[form_data.model]
@ -687,39 +699,30 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
if form_data.stream: if form_data.stream:
def stream_content(): async def stream_content():
res = pipe( res = execute_pipe(pipe,
user_message=user_message, user_message=user_message,
model_id=pipeline_id, model_id=pipeline_id,
messages=messages, messages=messages,
body=form_data.model_dump(), body=form_data.model_dump(),
) )
async for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
logging.info(f"stream:true:{res}") try:
line = line.decode("utf-8")
except:
pass
if isinstance(res, str): logging.info(f"stream_content:Generator:{line}")
message = stream_message_template(form_data.model, res)
logging.info(f"stream_content:str:{message}")
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator): if line.startswith("data:"):
for line in res: yield f"{line}\n\n"
if isinstance(line, BaseModel): else:
line = line.model_dump_json() line = stream_message_template(form_data.model, line)
line = f"data: {line}" yield f"data: {json.dumps(line)}\n\n"
try:
line = line.decode("utf-8")
except:
pass
logging.info(f"stream_content:Generator:{line}")
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data.model, line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator): if isinstance(res, str) or isinstance(res, Generator):
finish_message = { finish_message = {
@ -736,52 +739,40 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
} }
], ],
} }
yield f"data: {json.dumps(finish_message)}\n\n" yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]" yield f"data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream") return StreamingResponse(stream_content(), media_type="text/event-stream")
else: else:
res = pipe( res = execute_pipe(pipe,
user_message=user_message, user_message=user_message,
model_id=pipeline_id, model_id=pipeline_id,
messages=messages, messages=messages,
body=form_data.model_dump(), body=form_data.model_dump(),
) )
logging.info(f"stream:false:{res}")
if isinstance(res, dict): message = ""
return res async for stream in res:
elif isinstance(res, BaseModel): message = f"{message}{stream}"
return res.model_dump()
else:
message = "" logging.info(f"stream:false:{message}")
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
if isinstance(res, str): return await job()
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
logging.info(f"stream:false:{message}")
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await run_in_threadpool(job)