This commit is contained in:
dominik 2025-04-13 16:17:30 -04:00 committed by GitHub
commit 05395415d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

115
main.py
View File

@ -1,3 +1,4 @@
import inspect
from fastapi import FastAPI, Request, Depends, status, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
@ -27,7 +28,7 @@ import json
import uuid
import sys
import subprocess
import inspect
from config import API_KEY, PIPELINES_DIR, LOG_LEVELS
@ -671,7 +672,18 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
detail=f"Pipeline {form_data.model} not found",
)
def job():
async def execute_pipe(pipe, *args, **kwargs):
if inspect.isasyncgenfunction(pipe):
async for res in pipe(*args, **kwargs):
yield res
elif inspect.iscoroutinefunction(pipe):
for item in await pipe(*args, **kwargs):
yield item
else:
for item in await run_in_threadpool(pipe, *args, **kwargs):
yield item
async def job():
print(form_data.model)
pipeline = app.state.PIPELINES[form_data.model]
@ -687,39 +699,30 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
if form_data.stream:
def stream_content():
res = pipe(
async def stream_content():
res = execute_pipe(pipe,
user_message=user_message,
model_id=pipeline_id,
messages=messages,
body=form_data.model_dump(),
)
async for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
logging.info(f"stream:true:{res}")
try:
line = line.decode("utf-8")
except:
pass
if isinstance(res, str):
message = stream_message_template(form_data.model, res)
logging.info(f"stream_content:str:{message}")
yield f"data: {json.dumps(message)}\n\n"
logging.info(f"stream_content:Generator:{line}")
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
logging.info(f"stream_content:Generator:{line}")
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data.model, line)
yield f"data: {json.dumps(line)}\n\n"
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data.model, line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
@ -742,46 +745,34 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
res = pipe(
res = execute_pipe(pipe,
user_message=user_message,
model_id=pipeline_id,
messages=messages,
body=form_data.model_dump(),
)
logging.info(f"stream:false:{res}")
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
async for stream in res:
message = f"{message}{stream}"
message = ""
logging.info(f"stream:false:{message}")
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
logging.info(f"stream:false:{message}")
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await run_in_threadpool(job)
return await job()